├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── code ├── .DS_Store ├── README.md ├── augmentations │ ├── __init__.py │ └── ctaugment.py ├── config.py ├── configs │ └── swin_tiny_patch4_window7_224_lite.yaml ├── dataloaders │ ├── acdc_data_processing.py │ ├── brats2019.py │ ├── brats_proprecessing.py │ ├── dataset.py │ └── utils.py ├── networks │ ├── VoxResNet.py │ ├── attention.py │ ├── attention_unet.py │ ├── config.py │ ├── discriminator.py │ ├── efficient_encoder.py │ ├── efficientunet.py │ ├── encoder_tool.py │ ├── enet.py │ ├── grid_attention_layer.py │ ├── net_factory.py │ ├── net_factory_3d.py │ ├── networks_other.py │ ├── neural_network.py │ ├── nnunet.py │ ├── pnet.py │ ├── swin_transformer_unet_skip_expand_decoder_sys.py │ ├── unet.py │ ├── unet_3D.py │ ├── unet_3D_dv_semi.py │ ├── utils.py │ ├── vision_transformer.py │ └── vnet.py ├── pretrained_ckpt │ └── readme.txt ├── test_2D_fully.py ├── test_3D.py ├── test_3D_util.py ├── test_acdc_unet_semi_seg.sh ├── test_brats2019_semi_seg.sh ├── test_urpc.py ├── test_urpc_util.py ├── train_acdc_unet_semi_seg.sh ├── train_adversarial_network_2D.py ├── train_adversarial_network_3D.py ├── train_brats2019_semi_seg.sh ├── train_cross_consistency_training_2D.py ├── train_cross_pseudo_supervision_2D.py ├── train_cross_pseudo_supervision_3D.py ├── train_cross_teaching_between_cnn_transformer_2D.py ├── train_deep_co_training_2D.py ├── train_entropy_minimization_2D.py ├── train_entropy_minimization_3D.py ├── train_fixmatch_cta.py ├── train_fixmatch_standard_augs.py ├── train_fully_supervised_2D.py ├── train_fully_supervised_3D.py ├── train_interpolation_consistency_training_2D.py ├── train_interpolation_consistency_training_3D.py ├── train_mean_teacher_2D.py ├── train_mean_teacher_3D.py ├── train_regularized_dropout_2D.py ├── train_regularized_dropout_3D.py ├── train_uncertainty_aware_mean_teacher_2D.py ├── train_uncertainty_aware_mean_teacher_3D.py ├── train_uncertainty_rectified_pyramid_consistency_2D.py ├── train_uncertainty_rectified_pyramid_consistency_3D.py ├── utils │ ├── losses.py │ ├── metrics.py │ ├── ramps.py │ └── util.py ├── val_2D.py ├── val_3D.py └── val_urpc_util.py ├── data ├── ACDC │ ├── README.md │ ├── test.list │ ├── train.list │ ├── train_slices.list │ └── val.list └── BraTS2019 │ ├── README.md │ ├── test.txt │ ├── train.txt │ └── val.txt └── environment.yml /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SSL4MIS/06df6047a59aba9988ced8331998b5957ecb356b/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | .vscode/ 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Data 132 | data/ 133 | 134 | # Models 135 | model/ 136 | 137 | # test file 138 | code/test.py 139 | test.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 xdluo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /code/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiLab-git/SSL4MIS/06df6047a59aba9988ced8331998b5957ecb356b/code/.DS_Store -------------------------------------------------------------------------------- /code/README.md: -------------------------------------------------------------------------------- 1 | ## Semi-supervised Learning for Medical Image Segmentation (**SSL4MIS**) 2 | 3 | ## Requirements 4 | Some important required packages include: 5 | * [Pytorch][torch_link] version >=0.4.1. 6 | * TensorBoardX 7 | * Python == 3.6 8 | * Efficientnet-Pytorch `pip install efficientnet_pytorch` 9 | * Some basic Python packages such as Numpy, Scikit-image, SimpleITK, and Scipy. 10 | 11 | Follow the official guidance to install [Pytorch][torch_link]. 12 | 13 | [torch_link]:https://pytorch.org/ 14 | 15 | ## Usage 16 | 17 | 1. Clone the repo: 18 | ``` 19 | git clone https://github.com/HiLab-git/SSL4MIS.git 20 | cd SSL4MIS 21 | ``` 22 | 2. Download the processed data and put it in `../data/BraTS2019` or `../data/ACDC`; please read and follow the [README](https://github.com/Luoxd1996/SSL4MIS/tree/master/data/). 23 | 24 | 3. Train the model 25 | ``` 26 | cd code 27 | python train_XXXXX_3D.py or python train_XXXXX_2D.py or bash train_acdc_XXXXX.sh 28 | ``` 29 | 30 | 4. Test the model 31 | ``` 32 | python test_XXXXX.py 33 | ``` 34 | ## Reimplemented methods 35 | * [Mean Teacher](https://papers.nips.cc/paper/6719-mean-teachers-are-better-role-models-weight-averaged-consistency-targets-improve-semi-supervised-deep-learning-results.pdf)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_mean_teacher_2D.py)/[3D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_mean_teacher_3D.py)] 36 | * [Entropy Minimization](https://openaccess.thecvf.com/content_CVPR_2019/papers/Vu_ADVENT_Adversarial_Entropy_Minimization_for_Domain_Adaptation_in_Semantic_Segmentation_CVPR_2019_paper.pdf)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_entropy_minimization_2D.py)/[3D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_entropy_minimization_3D.py)] 37 | * [Deep Adversarial Networks](https://link.springer.com/chapter/10.1007/978-3-319-66179-7_47)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_adversarial_network_2D.py)/[3D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_adversarial_network_3D.py)] 38 | * [Uncertainty Aware Mean Teacher](https://arxiv.org/pdf/1907.07034.pdf)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_uncertainty_aware_mean_teacher_2D.py)/[3D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_uncertainty_aware_mean_teacher_3D.py)] 39 | * [Interpolation Consistency Training](https://arxiv.org/pdf/1903.03825.pdf)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_interpolation_consistency_training_2D.py)/[3D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_interpolation_consistency_training_3D.py)] 40 | * [Uncertainty Rectified Pyramid Consistency](https://arxiv.org/pdf/2012.07042.pdf)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_uncertainty_rectified_pyramid_consistency_2D.py)/[3D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_uncertainty_rectified_pyramid_consistency_3D.py)] 41 | * [Cross Pseudo
Supervision](https://arxiv.org/abs/2106.01226)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_cross_pseudo_supervision_2D.py)/[3D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_cross_pseudo_supervision_3D.py)] 42 | * [Cross Consistency Training](https://openaccess.thecvf.com/content_CVPR_2020/papers/Ouali_Semi-Supervised_Semantic_Segmentation_With_Cross-Consistency_Training_CVPR_2020_paper.pdf)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_cross_consistency_training_2D.py)] 43 | * [Deep Co-Training](https://openaccess.thecvf.com/content_ECCV_2018/papers/Siyuan_Qiao_Deep_Co-Training_for_ECCV_2018_paper.pdf)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_deep_co_training_2D.py)] 44 | * [Cross Teaching between CNN and Transformer](https://arxiv.org/pdf/2112.04894.pdf)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_cross_teaching_between_cnn_transformer_2D.py)] 45 | * [Regularized Dropout](https://proceedings.neurips.cc/paper/2021/file/5a66b9200f29ac3fa0ae244cc2a51b39-Paper.pdf)[[2D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_regularized_dropout_2D.py)/[3D](https://github.com/HiLab-git/SSL4MIS/blob/master/code/train_regularized_dropout_3D.py)] 46 | ## Acknowledgement 47 | * Part of the code is adapted from open-source codebases and the original implementations of these algorithms; we thank the authors for their fantastic and efficient codebases, such as [UA-MT](https://github.com/yulequan/UA-MT), [Attention-Gated-Networks](https://github.com/ozan-oktay/Attention-Gated-Networks) and [segmentation_models.pytorch](https://github.com/qubvel/segmentation_models.pytorch). 48 | -------------------------------------------------------------------------------- /code/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | 4 | from augmentations.ctaugment import * 5 | 6 | 7 | class StorableCTAugment(CTAugment): 8 | def load_state_dict(self, state): 9 | for k in ["decay", "depth", "th", "rates"]: 10 | assert k in state, "{} not in {}".format(k, state.keys()) 11 | setattr(self, k, state[k]) 12 | 13 | def state_dict(self): 14 | return OrderedDict( 15 | [(k, getattr(self, k)) for k in ["decay", "depth", "th", "rates"]] 16 | ) 17 | 18 | 19 | def get_default_cta(): 20 | return StorableCTAugment() 21 | 22 | 23 | def cta_apply(pil_img, ops): 24 | if ops is None: 25 | return pil_img 26 | for op, args in ops: 27 | pil_img = OPS[op].f(pil_img, *args) 28 | return pil_img 29 | 30 | 31 | def deserialize(policy_str): 32 | return [OP(f=x[0], bins=x[1]) for x in json.loads(policy_str)] 33 | 34 | 35 | def stats(cta): 36 | return "\n".join( 37 | "%-16s %s" 38 | % ( 39 | k, 40 | " / ".join( 41 | " ".join("%.2f" % x for x in cta.rate_to_p(rate)) 42 | for rate in cta.rates[k] 43 | ), 44 | ) 45 | for k in sorted(OPS.keys()) 46 | ) 47 | 48 | 49 | def interleave(x, batch, inverse=False): 50 | """ 51 | TF code 52 | def interleave(x, batch): 53 | s = x.get_shape().as_list() 54 | return tf.reshape(tf.transpose(tf.reshape(x, [-1, batch] + s[1:]), [1, 0] + list(range(2, 1+len(s)))), [-1] + s[1:]) 55 | """ 56 | shape = x.shape 57 | axes = [batch, -1] if inverse else [-1, batch] 58 | return x.reshape(*axes, *shape[1:]).transpose(0, 1).reshape(-1, *shape[1:]) 59 | 60 | 61 | def deinterleave(x, batch): 62 | return interleave(x, batch, inverse=True) 63 |
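The helpers above (`get_default_cta`, `cta_apply`, `stats`, and the `state_dict`/`load_state_dict` pair on `StorableCTAugment`) are the pieces a CTAugment-based training script would wire together. The following minimal sketch is illustrative only and is not part of the repository; it assumes it is run from the `code/` directory (so the `augmentations` package is importable) and it substitutes a dummy grayscale PIL image for a real 2D slice. The `policy` method used here is inherited from `CTAugment`, defined in `augmentations/ctaugment.py` (the next file).

```python
# Minimal usage sketch (illustrative, not part of the repository).
# Assumptions: working directory is `code/` so `augmentations` is importable;
# `img` is a dummy grayscale PIL image standing in for a real slice.
import numpy as np
from PIL import Image

from augmentations import cta_apply, get_default_cta, stats

cta = get_default_cta()                                      # StorableCTAugment with default rates
img = Image.fromarray(np.zeros((224, 224), dtype=np.uint8))  # dummy grayscale slice

ops = cta.policy(probe=False, weak=False)                    # sample a "strong" augmentation policy
augmented = cta_apply(img, ops)                              # apply each (op_name, args) pair in order

state = cta.state_dict()                                     # OrderedDict of decay / depth / th / rates
restored = get_default_cta()
restored.load_state_dict(state)                              # resume the learned policy elsewhere
print(stats(restored))                                       # per-operation bin weights, as logged during training
```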
-------------------------------------------------------------------------------- /code/augmentations/ctaugment.py: -------------------------------------------------------------------------------- 1 | # https://raw.githubusercontent.com/google-research/fixmatch/master/libml/ctaugment.py 2 | # 3 | # Copyright 2019 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Control Theory based self-augmentation, modified from https://github.com/vfdev-5/FixMatch-pytorch""" 17 | import random 18 | import torch 19 | from collections import namedtuple 20 | 21 | import numpy as np 22 | from scipy.ndimage.interpolation import zoom 23 | from PIL import Image, ImageOps, ImageEnhance, ImageFilter 24 | 25 | 26 | OPS = {} 27 | OP = namedtuple("OP", ("f", "bins")) 28 | Sample = namedtuple("Sample", ("train", "probe")) 29 | 30 | 31 | def register(*bins): 32 | def wrap(f): 33 | OPS[f.__name__] = OP(f, bins) 34 | return f 35 | 36 | return wrap 37 | 38 | 39 | class CTAugment(object): 40 | def __init__(self, depth=2, th=0.85, decay=0.99): 41 | self.decay = decay 42 | self.depth = depth 43 | self.th = th 44 | self.rates = {} 45 | for k, op in OPS.items(): 46 | self.rates[k] = tuple([np.ones(x, "f") for x in op.bins]) 47 | 48 | def rate_to_p(self, rate): 49 | p = rate + (1 - self.decay) # Avoid to have all zero. 
50 | p = p / p.max() 51 | p[p < self.th] = 0 52 | return p 53 | 54 | def policy(self, probe, weak): 55 | num_strong_ops = 11 56 | kl_weak = list(OPS.keys())[num_strong_ops:] 57 | kl_strong = list(OPS.keys())[:num_strong_ops] 58 | 59 | if weak: 60 | kl = kl_weak 61 | else: 62 | kl = kl_strong 63 | 64 | v = [] 65 | if probe: 66 | for _ in range(self.depth): 67 | k = random.choice(kl) 68 | bins = self.rates[k] 69 | rnd = np.random.uniform(0, 1, len(bins)) 70 | v.append(OP(k, rnd.tolist())) 71 | return v 72 | for _ in range(self.depth): 73 | vt = [] 74 | k = random.choice(kl) 75 | bins = self.rates[k] 76 | rnd = np.random.uniform(0, 1, len(bins)) 77 | for r, bin in zip(rnd, bins): 78 | p = self.rate_to_p(bin) 79 | value = np.random.choice(p.shape[0], p=p / p.sum()) 80 | vt.append((value + r) / p.shape[0]) 81 | v.append(OP(k, vt)) 82 | return v 83 | 84 | def update_rates(self, policy, proximity): 85 | for k, bins in policy: 86 | for p, rate in zip(bins, self.rates[k]): 87 | p = int(p * len(rate) * 0.999) 88 | rate[p] = rate[p] * self.decay + proximity * (1 - self.decay) 89 | print(f"\t {k} weights updated") 90 | 91 | def stats(self): 92 | return "\n".join( 93 | "%-16s %s" 94 | % ( 95 | k, 96 | " / ".join( 97 | " ".join("%.2f" % x for x in self.rate_to_p(rate)) 98 | for rate in self.rates[k] 99 | ), 100 | ) 101 | for k in sorted(OPS.keys()) 102 | ) 103 | 104 | 105 | def _enhance(x, op, level): 106 | return op(x).enhance(0.1 + 1.9 * level) 107 | 108 | 109 | def _imageop(x, op, level): 110 | return Image.blend(x, op(x), level) 111 | 112 | 113 | def _filter(x, op, level): 114 | return Image.blend(x, x.filter(op), level) 115 | 116 | 117 | @register(17) 118 | def autocontrast(x, level): 119 | return _imageop(x, ImageOps.autocontrast, level) 120 | 121 | 122 | @register(17) 123 | def brightness(x, brightness): 124 | return _enhance(x, ImageEnhance.Brightness, brightness) 125 | 126 | 127 | @register(17) 128 | def color(x, color): 129 | return _enhance(x, ImageEnhance.Color, color) 130 | 131 | 132 | @register(17) 133 | def contrast(x, contrast): 134 | return _enhance(x, ImageEnhance.Contrast, contrast) 135 | 136 | 137 | @register(17) 138 | def equalize(x, level): 139 | return _imageop(x, ImageOps.equalize, level) 140 | 141 | 142 | @register(17) 143 | def invert(x, level): 144 | return _imageop(x, ImageOps.invert, level) 145 | 146 | 147 | @register(8) 148 | def posterize(x, level): 149 | level = 1 + int(level * 7.999) 150 | return ImageOps.posterize(x, level) 151 | 152 | 153 | @register(17) 154 | def solarize(x, th): 155 | th = int(th * 255.999) 156 | return ImageOps.solarize(x, th) 157 | 158 | 159 | @register(17) 160 | def smooth(x, level): 161 | return _filter(x, ImageFilter.SMOOTH, level) 162 | 163 | 164 | @register(17) 165 | def blur(x, level): 166 | return _filter(x, ImageFilter.BLUR, level) 167 | 168 | 169 | @register(17) 170 | def sharpness(x, sharpness): 171 | return _enhance(x, ImageEnhance.Sharpness, sharpness) 172 | 173 | 174 | # weak after here 175 | 176 | 177 | @register(17) 178 | def cutout(x, level): 179 | """Apply cutout to pil_img at the specified level.""" 180 | size = 1 + int(level * min(x.size) * 0.499) 181 | img_height, img_width = x.size 182 | height_loc = np.random.randint(low=img_height // 2, high=img_height) 183 | width_loc = np.random.randint(low=img_height // 2, high=img_width) 184 | upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2)) 185 | lower_coord = ( 186 | min(img_height, height_loc + size // 2), 187 | min(img_width, width_loc + size // 2), 188 | 
) 189 | pixels = x.load() # create the pixel map 190 | for i in range(upper_coord[0], lower_coord[0]): # for every col: 191 | for j in range(upper_coord[1], lower_coord[1]): # For every row 192 | x.putpixel((i, j), 0) # set the color accordingly 193 | return x 194 | 195 | 196 | @register() 197 | def identity(x): 198 | return x 199 | 200 | 201 | @register(17, 6) 202 | def rescale(x, scale, method): 203 | s = x.size 204 | scale *= 0.25 205 | crop = (scale * s[0], scale * s[1], s[0] * (1 - scale), s[1] * (1 - scale)) 206 | methods = ( 207 | Image.ANTIALIAS, 208 | Image.BICUBIC, 209 | Image.BILINEAR, 210 | Image.BOX, 211 | Image.HAMMING, 212 | Image.NEAREST, 213 | ) 214 | method = methods[int(method * 5.99)] 215 | return x.crop(crop).resize(x.size, method) 216 | 217 | 218 | @register(17) 219 | def rotate(x, angle): 220 | angle = int(np.round((2 * angle - 1) * 45)) 221 | return x.rotate(angle) 222 | 223 | 224 | @register(17) 225 | def shear_x(x, shear): 226 | shear = (2 * shear - 1) * 0.3 227 | return x.transform(x.size, Image.AFFINE, (1, shear, 0, 0, 1, 0)) 228 | 229 | 230 | @register(17) 231 | def shear_y(x, shear): 232 | shear = (2 * shear - 1) * 0.3 233 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, shear, 1, 0)) 234 | 235 | 236 | @register(17) 237 | def translate_x(x, delta): 238 | delta = (2 * delta - 1) * 0.3 239 | return x.transform(x.size, Image.AFFINE, (1, 0, delta, 0, 1, 0)) 240 | 241 | 242 | @register(17) 243 | def translate_y(x, delta): 244 | delta = (2 * delta - 1) * 0.3 245 | return x.transform(x.size, Image.AFFINE, (1, 0, 0, 0, 1, delta)) 246 | -------------------------------------------------------------------------------- /code/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' 51 | _C.MODEL.RESUME = '' 52 | # Number of classes, overwritten in data preparation 53 | _C.MODEL.NUM_CLASSES = 1000 54 | # Dropout rate 55 | _C.MODEL.DROP_RATE = 0.0 56 | # Drop path rate 57 | _C.MODEL.DROP_PATH_RATE = 0.1 58 | # Label Smoothing 59 | _C.MODEL.LABEL_SMOOTHING = 0.1 60 | 61 | # Swin Transformer parameters 62 | _C.MODEL.SWIN = CN() 63 | _C.MODEL.SWIN.PATCH_SIZE = 4 64 | _C.MODEL.SWIN.IN_CHANS = 3 65 | _C.MODEL.SWIN.EMBED_DIM = 96 66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 67 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] 68 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 69 | _C.MODEL.SWIN.WINDOW_SIZE = 7 70 | _C.MODEL.SWIN.MLP_RATIO = 4. 71 | _C.MODEL.SWIN.QKV_BIAS = True 72 | _C.MODEL.SWIN.QK_SCALE = None 73 | _C.MODEL.SWIN.APE = False 74 | _C.MODEL.SWIN.PATCH_NORM = True 75 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" 76 | 77 | # ----------------------------------------------------------------------------- 78 | # Training settings 79 | # ----------------------------------------------------------------------------- 80 | _C.TRAIN = CN() 81 | _C.TRAIN.START_EPOCH = 0 82 | _C.TRAIN.EPOCHS = 300 83 | _C.TRAIN.WARMUP_EPOCHS = 20 84 | _C.TRAIN.WEIGHT_DECAY = 0.05 85 | _C.TRAIN.BASE_LR = 5e-4 86 | _C.TRAIN.WARMUP_LR = 5e-7 87 | _C.TRAIN.MIN_LR = 5e-6 88 | # Clip gradient norm 89 | _C.TRAIN.CLIP_GRAD = 5.0 90 | # Auto resume from latest checkpoint 91 | _C.TRAIN.AUTO_RESUME = True 92 | # Gradient accumulation steps 93 | # could be overwritten by command line argument 94 | _C.TRAIN.ACCUMULATION_STEPS = 0 95 | # Whether to use gradient checkpointing to save memory 96 | # could be overwritten by command line argument 97 | _C.TRAIN.USE_CHECKPOINT = False 98 | 99 | # LR scheduler 100 | _C.TRAIN.LR_SCHEDULER = CN() 101 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 102 | # Epoch interval to decay LR, used in StepLRScheduler 103 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 104 | # LR decay rate, used in StepLRScheduler 105 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 106 | 107 | # Optimizer 108 | _C.TRAIN.OPTIMIZER = CN() 109 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 110 | # Optimizer Epsilon 111 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 112 | # Optimizer Betas 113 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 114 | # SGD momentum 115 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 116 | 117 | # ----------------------------------------------------------------------------- 118 | # Augmentation settings 119 | # ----------------------------------------------------------------------------- 120 | _C.AUG = CN() 121 | # Color jitter factor 122 | _C.AUG.COLOR_JITTER = 0.4 123 | # Use AutoAugment policy. 
"v0" or "original" 124 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 125 | # Random erase prob 126 | _C.AUG.REPROB = 0.25 127 | # Random erase mode 128 | _C.AUG.REMODE = 'pixel' 129 | # Random erase count 130 | _C.AUG.RECOUNT = 1 131 | # Mixup alpha, mixup enabled if > 0 132 | _C.AUG.MIXUP = 0.8 133 | # Cutmix alpha, cutmix enabled if > 0 134 | _C.AUG.CUTMIX = 1.0 135 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 136 | _C.AUG.CUTMIX_MINMAX = None 137 | # Probability of performing mixup or cutmix when either/both is enabled 138 | _C.AUG.MIXUP_PROB = 1.0 139 | # Probability of switching to cutmix when both mixup and cutmix enabled 140 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 141 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 142 | _C.AUG.MIXUP_MODE = 'batch' 143 | 144 | # ----------------------------------------------------------------------------- 145 | # Testing settings 146 | # ----------------------------------------------------------------------------- 147 | _C.TEST = CN() 148 | # Whether to use center crop when testing 149 | _C.TEST.CROP = True 150 | 151 | # ----------------------------------------------------------------------------- 152 | # Misc 153 | # ----------------------------------------------------------------------------- 154 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 155 | # overwritten by command line argument 156 | _C.AMP_OPT_LEVEL = '' 157 | # Path to output folder, overwritten by command line argument 158 | _C.OUTPUT = '' 159 | # Tag of experiment, overwritten by command line argument 160 | _C.TAG = 'default' 161 | # Frequency to save checkpoint 162 | _C.SAVE_FREQ = 1 163 | # Frequency to logging info 164 | _C.PRINT_FREQ = 10 165 | # Fixed random seed 166 | _C.SEED = 0 167 | # Perform evaluation only, overwritten by command line argument 168 | _C.EVAL_MODE = False 169 | # Test throughput only, overwritten by command line argument 170 | _C.THROUGHPUT_MODE = False 171 | # local rank for DistributedDataParallel, given by command line argument 172 | _C.LOCAL_RANK = 0 173 | 174 | 175 | def _update_config_from_file(config, cfg_file): 176 | config.defrost() 177 | with open(cfg_file, 'r') as f: 178 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 179 | 180 | for cfg in yaml_cfg.setdefault('BASE', ['']): 181 | if cfg: 182 | _update_config_from_file( 183 | config, os.path.join(os.path.dirname(cfg_file), cfg) 184 | ) 185 | print('=> merge config from {}'.format(cfg_file)) 186 | config.merge_from_file(cfg_file) 187 | config.freeze() 188 | 189 | 190 | def update_config(config, args): 191 | _update_config_from_file(config, args.cfg) 192 | 193 | config.defrost() 194 | if args.opts: 195 | config.merge_from_list(args.opts) 196 | 197 | # merge from specific arguments 198 | if args.batch_size: 199 | config.DATA.BATCH_SIZE = args.batch_size 200 | if args.zip: 201 | config.DATA.ZIP_MODE = True 202 | if args.cache_mode: 203 | config.DATA.CACHE_MODE = args.cache_mode 204 | if args.resume: 205 | config.MODEL.RESUME = args.resume 206 | if args.accumulation_steps: 207 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 208 | if args.use_checkpoint: 209 | config.TRAIN.USE_CHECKPOINT = True 210 | if args.amp_opt_level: 211 | config.AMP_OPT_LEVEL = args.amp_opt_level 212 | if args.tag: 213 | config.TAG = args.tag 214 | if args.eval: 215 | config.EVAL_MODE = True 216 | if args.throughput: 217 | config.THROUGHPUT_MODE = True 218 | 219 | config.freeze() 220 | 221 | 222 | def get_config(args): 223 | """Get a yacs CfgNode object 
with default values.""" 224 | # Return a clone so that the defaults will not be altered 225 | # This is for the "local variable" use pattern 226 | config = _C.clone() 227 | update_config(config, args) 228 | 229 | return config 230 | -------------------------------------------------------------------------------- /code/configs/swin_tiny_patch4_window7_224_lite.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | PRETRAIN_CKPT: "../code/pretrained_ckpt/swin_tiny_patch4_window7_224.pth" 6 | SWIN: 7 | FINAL_UPSAMPLE: "expand_first" 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 2, 2 ] 10 | DECODER_DEPTHS: [ 2, 2, 2, 1] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /code/dataloaders/acdc_data_processing.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import h5py 5 | import numpy as np 6 | import SimpleITK as sitk 7 | 8 | slice_num = 0 9 | mask_path = sorted(glob.glob("/home/xdluo/data/ACDC/image/*.nii.gz")) 10 | for case in mask_path: 11 | img_itk = sitk.ReadImage(case) 12 | origin = img_itk.GetOrigin() 13 | spacing = img_itk.GetSpacing() 14 | direction = img_itk.GetDirection() 15 | image = sitk.GetArrayFromImage(img_itk) 16 | msk_path = case.replace("image", "label").replace(".nii.gz", "_gt.nii.gz") 17 | if os.path.exists(msk_path): 18 | print(msk_path) 19 | msk_itk = sitk.ReadImage(msk_path) 20 | mask = sitk.GetArrayFromImage(msk_itk) 21 | image = (image - image.min()) / (image.max() - image.min()) 22 | print(image.shape) 23 | image = image.astype(np.float32) 24 | item = case.split("/")[-1].split(".")[0] 25 | if image.shape != mask.shape: 26 | print("Error") 27 | print(item) 28 | for slice_ind in range(image.shape[0]): 29 | f = h5py.File( 30 | '/home/xdluo/data/ACDC/data/{}_slice_{}.h5'.format(item, slice_ind), 'w') 31 | f.create_dataset( 32 | 'image', data=image[slice_ind], compression="gzip") 33 | f.create_dataset('label', data=mask[slice_ind], compression="gzip") 34 | f.close() 35 | slice_num += 1 36 | print("Converted all ACDC volumes to 2D slices") 37 | print("Total {} slices".format(slice_num)) 38 | -------------------------------------------------------------------------------- /code/dataloaders/brats2019.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from glob import glob 5 | from torch.utils.data import Dataset 6 | import h5py 7 | import itertools 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class BraTS2019(Dataset): 12 | """ BraTS2019 Dataset """ 13 | 14 | def __init__(self, base_dir=None, split='train', num=None, transform=None): 15 | self._base_dir = base_dir 16 | self.transform = transform 17 | self.sample_list = [] 18 | 19 | train_path = self._base_dir+'/train.txt' 20 | test_path = self._base_dir+'/val.txt' 21 | 22 | if split == 'train': 23 | with open(train_path, 'r') as f: 24 | self.image_list = f.readlines() 25 | elif split == 'test': 26 | with open(test_path, 'r') as f: 27 | self.image_list = f.readlines() 28 | 29 | self.image_list = [item.replace('\n', '').split(",")[0] for item in self.image_list] 30 | if num is not None: 31 | self.image_list = self.image_list[:num] 32 | print("total {} samples".format(len(self.image_list))) 33 | 34 | def __len__(self): 35 | return len(self.image_list) 
36 | 37 | def __getitem__(self, idx): 38 | image_name = self.image_list[idx] 39 | h5f = h5py.File(self._base_dir + "/data/{}.h5".format(image_name), 'r') 40 | image = h5f['image'][:] 41 | label = h5f['label'][:] 42 | sample = {'image': image, 'label': label.astype(np.uint8)} 43 | if self.transform: 44 | sample = self.transform(sample) 45 | return sample 46 | 47 | 48 | class CenterCrop(object): 49 | def __init__(self, output_size): 50 | self.output_size = output_size 51 | 52 | def __call__(self, sample): 53 | image, label = sample['image'], sample['label'] 54 | 55 | # pad the sample if necessary 56 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ 57 | self.output_size[2]: 58 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 59 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 60 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 61 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], 62 | mode='constant', constant_values=0) 63 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], 64 | mode='constant', constant_values=0) 65 | 66 | (w, h, d) = image.shape 67 | 68 | w1 = int(round((w - self.output_size[0]) / 2.)) 69 | h1 = int(round((h - self.output_size[1]) / 2.)) 70 | d1 = int(round((d - self.output_size[2]) / 2.)) 71 | 72 | label = label[w1:w1 + self.output_size[0], h1:h1 + 73 | self.output_size[1], d1:d1 + self.output_size[2]] 74 | image = image[w1:w1 + self.output_size[0], h1:h1 + 75 | self.output_size[1], d1:d1 + self.output_size[2]] 76 | 77 | return {'image': image, 'label': label} 78 | 79 | 80 | class RandomCrop(object): 81 | """ 82 | Crop randomly the image in a sample 83 | Args: 84 | output_size (int): Desired output size 85 | """ 86 | 87 | def __init__(self, output_size, with_sdf=False): 88 | self.output_size = output_size 89 | self.with_sdf = with_sdf 90 | 91 | def __call__(self, sample): 92 | image, label = sample['image'], sample['label'] 93 | if self.with_sdf: 94 | sdf = sample['sdf'] 95 | 96 | # pad the sample if necessary 97 | if label.shape[0] <= self.output_size[0] or label.shape[1] <= self.output_size[1] or label.shape[2] <= \ 98 | self.output_size[2]: 99 | pw = max((self.output_size[0] - label.shape[0]) // 2 + 3, 0) 100 | ph = max((self.output_size[1] - label.shape[1]) // 2 + 3, 0) 101 | pd = max((self.output_size[2] - label.shape[2]) // 2 + 3, 0) 102 | image = np.pad(image, [(pw, pw), (ph, ph), (pd, pd)], 103 | mode='constant', constant_values=0) 104 | label = np.pad(label, [(pw, pw), (ph, ph), (pd, pd)], 105 | mode='constant', constant_values=0) 106 | if self.with_sdf: 107 | sdf = np.pad(sdf, [(pw, pw), (ph, ph), (pd, pd)], 108 | mode='constant', constant_values=0) 109 | 110 | (w, h, d) = image.shape 111 | # if np.random.uniform() > 0.33: 112 | # w1 = np.random.randint((w - self.output_size[0])//4, 3*(w - self.output_size[0])//4) 113 | # h1 = np.random.randint((h - self.output_size[1])//4, 3*(h - self.output_size[1])//4) 114 | # else: 115 | w1 = np.random.randint(0, w - self.output_size[0]) 116 | h1 = np.random.randint(0, h - self.output_size[1]) 117 | d1 = np.random.randint(0, d - self.output_size[2]) 118 | 119 | label = label[w1:w1 + self.output_size[0], h1:h1 + 120 | self.output_size[1], d1:d1 + self.output_size[2]] 121 | image = image[w1:w1 + self.output_size[0], h1:h1 + 122 | self.output_size[1], d1:d1 + self.output_size[2]] 123 | if self.with_sdf: 124 | sdf = sdf[w1:w1 + self.output_size[0], h1:h1 + 125 | self.output_size[1], d1:d1 + self.output_size[2]] 126 | 
return {'image': image, 'label': label, 'sdf': sdf} 127 | else: 128 | return {'image': image, 'label': label} 129 | 130 | 131 | class RandomRotFlip(object): 132 | """ 133 | Crop randomly flip the dataset in a sample 134 | Args: 135 | output_size (int): Desired output size 136 | """ 137 | 138 | def __call__(self, sample): 139 | image, label = sample['image'], sample['label'] 140 | k = np.random.randint(0, 4) 141 | image = np.rot90(image, k) 142 | label = np.rot90(label, k) 143 | axis = np.random.randint(0, 2) 144 | image = np.flip(image, axis=axis).copy() 145 | label = np.flip(label, axis=axis).copy() 146 | 147 | return {'image': image, 'label': label} 148 | 149 | 150 | class RandomNoise(object): 151 | def __init__(self, mu=0, sigma=0.1): 152 | self.mu = mu 153 | self.sigma = sigma 154 | 155 | def __call__(self, sample): 156 | image, label = sample['image'], sample['label'] 157 | noise = np.clip(self.sigma * np.random.randn( 158 | image.shape[0], image.shape[1], image.shape[2]), -2*self.sigma, 2*self.sigma) 159 | noise = noise + self.mu 160 | image = image + noise 161 | return {'image': image, 'label': label} 162 | 163 | 164 | class CreateOnehotLabel(object): 165 | def __init__(self, num_classes): 166 | self.num_classes = num_classes 167 | 168 | def __call__(self, sample): 169 | image, label = sample['image'], sample['label'] 170 | onehot_label = np.zeros( 171 | (self.num_classes, label.shape[0], label.shape[1], label.shape[2]), dtype=np.float32) 172 | for i in range(self.num_classes): 173 | onehot_label[i, :, :, :] = (label == i).astype(np.float32) 174 | return {'image': image, 'label': label, 'onehot_label': onehot_label} 175 | 176 | 177 | class ToTensor(object): 178 | """Convert ndarrays in sample to Tensors.""" 179 | 180 | def __call__(self, sample): 181 | image = sample['image'] 182 | image = image.reshape( 183 | 1, image.shape[0], image.shape[1], image.shape[2]).astype(np.float32) 184 | if 'onehot_label' in sample: 185 | return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long(), 186 | 'onehot_label': torch.from_numpy(sample['onehot_label']).long()} 187 | else: 188 | return {'image': torch.from_numpy(image), 'label': torch.from_numpy(sample['label']).long()} 189 | 190 | 191 | class TwoStreamBatchSampler(Sampler): 192 | """Iterate two sets of indices 193 | 194 | An 'epoch' is one iteration through the primary indices. 195 | During the epoch, the secondary indices are iterated through 196 | as many times as needed. 
197 | """ 198 | 199 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 200 | self.primary_indices = primary_indices 201 | self.secondary_indices = secondary_indices 202 | self.secondary_batch_size = secondary_batch_size 203 | self.primary_batch_size = batch_size - secondary_batch_size 204 | 205 | assert len(self.primary_indices) >= self.primary_batch_size > 0 206 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 207 | 208 | def __iter__(self): 209 | primary_iter = iterate_once(self.primary_indices) 210 | secondary_iter = iterate_eternally(self.secondary_indices) 211 | return ( 212 | primary_batch + secondary_batch 213 | for (primary_batch, secondary_batch) 214 | in zip(grouper(primary_iter, self.primary_batch_size), 215 | grouper(secondary_iter, self.secondary_batch_size)) 216 | ) 217 | 218 | def __len__(self): 219 | return len(self.primary_indices) // self.primary_batch_size 220 | 221 | 222 | def iterate_once(iterable): 223 | return np.random.permutation(iterable) 224 | 225 | 226 | def iterate_eternally(indices): 227 | def infinite_shuffles(): 228 | while True: 229 | yield np.random.permutation(indices) 230 | return itertools.chain.from_iterable(infinite_shuffles()) 231 | 232 | 233 | def grouper(iterable, n): 234 | "Collect data into fixed-length chunks or blocks" 235 | # grouper('ABCDEFG', 3) --> ABC DEF" 236 | args = [iter(iterable)] * n 237 | return zip(*args) -------------------------------------------------------------------------------- /code/dataloaders/brats_proprecessing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | from skimage import measure 5 | import nibabel as nib 6 | import SimpleITK as sitk 7 | import glob 8 | 9 | 10 | def brain_bbox(data, gt): 11 | mask = (data != 0) 12 | brain_voxels = np.where(mask != 0) 13 | minZidx = int(np.min(brain_voxels[0])) 14 | maxZidx = int(np.max(brain_voxels[0])) 15 | minXidx = int(np.min(brain_voxels[1])) 16 | maxXidx = int(np.max(brain_voxels[1])) 17 | minYidx = int(np.min(brain_voxels[2])) 18 | maxYidx = int(np.max(brain_voxels[2])) 19 | data_bboxed = data[minZidx:maxZidx, minXidx:maxXidx, minYidx:maxYidx] 20 | gt_bboxed = gt[minZidx:maxZidx, minXidx:maxXidx, minYidx:maxYidx] 21 | return data_bboxed, gt_bboxed 22 | 23 | 24 | def volume_bounding_box(data, gt, expend=0, status="train"): 25 | data, gt = brain_bbox(data, gt) 26 | print(data.shape) 27 | mask = (gt != 0) 28 | brain_voxels = np.where(mask != 0) 29 | z, x, y = data.shape 30 | minZidx = int(np.min(brain_voxels[0])) 31 | maxZidx = int(np.max(brain_voxels[0])) 32 | minXidx = int(np.min(brain_voxels[1])) 33 | maxXidx = int(np.max(brain_voxels[1])) 34 | minYidx = int(np.min(brain_voxels[2])) 35 | maxYidx = int(np.max(brain_voxels[2])) 36 | 37 | minZidx_jitterd = max(minZidx - expend, 0) 38 | maxZidx_jitterd = min(maxZidx + expend, z) 39 | minXidx_jitterd = max(minXidx - expend, 0) 40 | maxXidx_jitterd = min(maxXidx + expend, x) 41 | minYidx_jitterd = max(minYidx - expend, 0) 42 | maxYidx_jitterd = min(maxYidx + expend, y) 43 | 44 | data_bboxed = data[minZidx_jitterd:maxZidx_jitterd, 45 | minXidx_jitterd:maxXidx_jitterd, minYidx_jitterd:maxYidx_jitterd] 46 | print([minZidx, maxZidx, minXidx, maxXidx, minYidx, maxYidx]) 47 | print([minZidx_jitterd, maxZidx_jitterd, 48 | minXidx_jitterd, maxXidx_jitterd, minYidx_jitterd, maxYidx_jitterd]) 49 | 50 | if status == "train": 51 | gt_bboxed = 
np.zeros_like(data_bboxed, dtype=np.uint8) 52 | gt_bboxed[expend:maxZidx_jitterd-expend, expend:maxXidx_jitterd - 53 | expend, expend:maxYidx_jitterd - expend] = 1 54 | return data_bboxed, gt_bboxed 55 | 56 | if status == "test": 57 | gt_bboxed = gt[minZidx_jitterd:maxZidx_jitterd, 58 | minXidx_jitterd:maxXidx_jitterd, minYidx_jitterd:maxYidx_jitterd] 59 | return data_bboxed, gt_bboxed 60 | 61 | 62 | def itensity_normalize_one_volume(volume): 63 | """ 64 | normalize the itensity of an nd volume based on the mean and std of nonzeor region 65 | inputs: 66 | volume: the input nd volume 67 | outputs: 68 | out: the normalized nd volume 69 | """ 70 | 71 | pixels = volume[volume > 0] 72 | mean = pixels.mean() 73 | std = pixels.std() 74 | out = (volume - mean)/std 75 | out_random = np.random.normal(0, 1, size=volume.shape) 76 | # out[volume == 0] = out_random[volume == 0] 77 | out = out.astype(np.float32) 78 | return out 79 | 80 | 81 | class MedicalImageDeal(object): 82 | def __init__(self, img, percent=1): 83 | self.img = img 84 | self.percent = percent 85 | 86 | @property 87 | def valid_img(self): 88 | from skimage import exposure 89 | cdf = exposure.cumulative_distribution(self.img) 90 | watershed = cdf[1][cdf[0] >= self.percent][0] 91 | return np.clip(self.img, self.img.min(), watershed) 92 | 93 | @property 94 | def norm_img(self): 95 | return (self.img - self.img.min()) / (self.img.max() - self.img.min()) 96 | 97 | 98 | all_flair = glob.glob("flair/*_flair.nii.gz") 99 | for p in all_flair: 100 | data = sitk.GetArrayFromImage(sitk.ReadImage(p)) 101 | lab = sitk.GetArrayFromImage(sitk.ReadImage(p.replace("flair", "seg"))) 102 | img, lab = brain_bbox(data, lab) 103 | img = MedicalImageDeal(img, percent=0.999).valid_img 104 | img = itensity_normalize_one_volume(img) 105 | lab[lab > 0] = 1 106 | uid = p.split("/")[-1] 107 | sitk.WriteImage(sitk.GetImageFromArray( 108 | img), "/media/xdluo/Data/brats19/data/flair/{}".format(uid)) 109 | sitk.WriteImage(sitk.GetImageFromArray( 110 | lab), "/media/xdluo/Data/brats19/data/label/{}".format(uid)) 111 | -------------------------------------------------------------------------------- /code/dataloaders/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import random 5 | import numpy as np 6 | from glob import glob 7 | from torch.utils.data import Dataset 8 | import h5py 9 | from scipy.ndimage.interpolation import zoom 10 | from torchvision import transforms 11 | import itertools 12 | from scipy import ndimage 13 | from torch.utils.data.sampler import Sampler 14 | import augmentations 15 | from augmentations.ctaugment import OPS 16 | import matplotlib.pyplot as plt 17 | from PIL import Image 18 | 19 | 20 | class BaseDataSets(Dataset): 21 | def __init__( 22 | self, 23 | base_dir=None, 24 | split="train", 25 | num=None, 26 | transform=None, 27 | ops_weak=None, 28 | ops_strong=None, 29 | ): 30 | self._base_dir = base_dir 31 | self.sample_list = [] 32 | self.split = split 33 | self.transform = transform 34 | self.ops_weak = ops_weak 35 | self.ops_strong = ops_strong 36 | 37 | assert bool(ops_weak) == bool( 38 | ops_strong 39 | ), "For using CTAugment learned policies, provide both weak and strong batch augmentation policy" 40 | 41 | if self.split == "train": 42 | with open(self._base_dir + "/train_slices.list", "r") as f1: 43 | self.sample_list = f1.readlines() 44 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 45 | 46 | elif self.split == 
"val": 47 | with open(self._base_dir + "/val.list", "r") as f: 48 | self.sample_list = f.readlines() 49 | self.sample_list = [item.replace("\n", "") for item in self.sample_list] 50 | if num is not None and self.split == "train": 51 | self.sample_list = self.sample_list[:num] 52 | print("total {} samples".format(len(self.sample_list))) 53 | 54 | def __len__(self): 55 | return len(self.sample_list) 56 | 57 | def __getitem__(self, idx): 58 | case = self.sample_list[idx] 59 | if self.split == "train": 60 | h5f = h5py.File(self._base_dir + "/data/slices/{}.h5".format(case), "r") 61 | else: 62 | h5f = h5py.File(self._base_dir + "/data/{}.h5".format(case), "r") 63 | image = h5f["image"][:] 64 | label = h5f["label"][:] 65 | sample = {"image": image, "label": label} 66 | if self.split == "train": 67 | if None not in (self.ops_weak, self.ops_strong): 68 | sample = self.transform(sample, self.ops_weak, self.ops_strong) 69 | else: 70 | sample = self.transform(sample) 71 | sample["idx"] = idx 72 | return sample 73 | 74 | 75 | def random_rot_flip(image, label=None): 76 | k = np.random.randint(0, 4) 77 | image = np.rot90(image, k) 78 | axis = np.random.randint(0, 2) 79 | image = np.flip(image, axis=axis).copy() 80 | if label is not None: 81 | label = np.rot90(label, k) 82 | label = np.flip(label, axis=axis).copy() 83 | return image, label 84 | else: 85 | return image 86 | 87 | 88 | def random_rotate(image, label): 89 | angle = np.random.randint(-20, 20) 90 | image = ndimage.rotate(image, angle, order=0, reshape=False) 91 | label = ndimage.rotate(label, angle, order=0, reshape=False) 92 | return image, label 93 | 94 | 95 | def color_jitter(image): 96 | if not torch.is_tensor(image): 97 | np_to_tensor = transforms.ToTensor() 98 | image = np_to_tensor(image) 99 | 100 | # s is the strength of color distortion. 
101 | s = 1.0 102 | jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 103 | return jitter(image) 104 | 105 | 106 | class CTATransform(object): 107 | def __init__(self, output_size, cta): 108 | self.output_size = output_size 109 | self.cta = cta 110 | 111 | def __call__(self, sample, ops_weak, ops_strong): 112 | image, label = sample["image"], sample["label"] 113 | image = self.resize(image) 114 | label = self.resize(label) 115 | to_tensor = transforms.ToTensor() 116 | 117 | # fix dimensions 118 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 119 | label = torch.from_numpy(label.astype(np.uint8)) 120 | 121 | # apply augmentations 122 | image_weak = augmentations.cta_apply(transforms.ToPILImage()(image), ops_weak) 123 | image_strong = augmentations.cta_apply(image_weak, ops_strong) 124 | label_aug = augmentations.cta_apply(transforms.ToPILImage()(label), ops_weak) 125 | label_aug = to_tensor(label_aug).squeeze(0) 126 | label_aug = torch.round(255 * label_aug).int() 127 | 128 | sample = { 129 | "image_weak": to_tensor(image_weak), 130 | "image_strong": to_tensor(image_strong), 131 | "label_aug": label_aug, 132 | } 133 | return sample 134 | 135 | def cta_apply(self, pil_img, ops): 136 | if ops is None: 137 | return pil_img 138 | for op, args in ops: 139 | pil_img = OPS[op].f(pil_img, *args) 140 | return pil_img 141 | 142 | def resize(self, image): 143 | x, y = image.shape 144 | return zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=0) 145 | 146 | 147 | class RandomGenerator(object): 148 | def __init__(self, output_size): 149 | self.output_size = output_size 150 | 151 | def __call__(self, sample): 152 | image, label = sample["image"], sample["label"] 153 | # ind = random.randrange(0, img.shape[0]) 154 | # image = img[ind, ...] 155 | # label = lab[ind, ...] 
156 | rand_value = random.random() 157 | if rand_value < 0.5: 158 | image, label = random_rot_flip(image, label) 159 | else: 160 | image, label = random_rotate(image, label) 161 | x, y = image.shape 162 | image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=0) 163 | label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0) 164 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 165 | label = torch.from_numpy(label.astype(np.uint8)) 166 | sample = {"image": image, "label": label} 167 | return sample 168 | 169 | 170 | class WeakStrongAugment(object): 171 | """returns weakly and strongly augmented images 172 | 173 | Args: 174 | object (tuple): output size of network 175 | """ 176 | 177 | def __init__(self, output_size): 178 | self.output_size = output_size 179 | 180 | def __call__(self, sample): 181 | image, label = sample["image"], sample["label"] 182 | image = self.resize(image) 183 | label = self.resize(label) 184 | # weak augmentation is rotation / flip 185 | image_weak, label = random_rot_flip(image, label) 186 | # strong augmentation is color jitter 187 | image_strong = color_jitter(image_weak).type("torch.FloatTensor") 188 | # fix dimensions 189 | image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0) 190 | image_weak = torch.from_numpy(image_weak.astype(np.float32)).unsqueeze(0) 191 | label = torch.from_numpy(label.astype(np.uint8)) 192 | 193 | sample = { 194 | "image": image, 195 | "image_weak": image_weak, 196 | "image_strong": image_strong, 197 | "label_aug": label, 198 | } 199 | return sample 200 | 201 | def resize(self, image): 202 | x, y = image.shape 203 | return zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=0) 204 | 205 | 206 | class TwoStreamBatchSampler(Sampler): 207 | """Iterate two sets of indices 208 | 209 | An 'epoch' is one iteration through the primary indices. 210 | During the epoch, the secondary indices are iterated through 211 | as many times as needed. 
212 | """ 213 | 214 | def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size): 215 | self.primary_indices = primary_indices 216 | self.secondary_indices = secondary_indices 217 | self.secondary_batch_size = secondary_batch_size 218 | self.primary_batch_size = batch_size - secondary_batch_size 219 | 220 | assert len(self.primary_indices) >= self.primary_batch_size > 0 221 | assert len(self.secondary_indices) >= self.secondary_batch_size > 0 222 | 223 | def __iter__(self): 224 | primary_iter = iterate_once(self.primary_indices) 225 | secondary_iter = iterate_eternally(self.secondary_indices) 226 | return ( 227 | primary_batch + secondary_batch 228 | for (primary_batch, secondary_batch) in zip( 229 | grouper(primary_iter, self.primary_batch_size), 230 | grouper(secondary_iter, self.secondary_batch_size), 231 | ) 232 | ) 233 | 234 | def __len__(self): 235 | return len(self.primary_indices) // self.primary_batch_size 236 | 237 | 238 | def iterate_once(iterable): 239 | return np.random.permutation(iterable) 240 | 241 | 242 | def iterate_eternally(indices): 243 | def infinite_shuffles(): 244 | while True: 245 | yield np.random.permutation(indices) 246 | 247 | return itertools.chain.from_iterable(infinite_shuffles()) 248 | 249 | 250 | def grouper(iterable, n): 251 | "Collect data into fixed-length chunks or blocks" 252 | # grouper('ABCDEFG', 3) --> ABC DEF" 253 | args = [iter(iterable)] * n 254 | return zip(*args) 255 | -------------------------------------------------------------------------------- /code/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | # import matplotlib.pyplot as plt 6 | from skimage import measure 7 | import scipy.ndimage as nd 8 | 9 | 10 | def recursive_glob(rootdir='.', suffix=''): 11 | """Performs recursive glob with given suffix and rootdir 12 | :param rootdir is the root directory 13 | :param suffix is the suffix to be searched 14 | """ 15 | return [os.path.join(looproot, filename) 16 | for looproot, _, filenames in os.walk(rootdir) 17 | for filename in filenames if filename.endswith(suffix)] 18 | 19 | def get_cityscapes_labels(): 20 | return np.array([ 21 | # [ 0, 0, 0], 22 | [128, 64, 128], 23 | [244, 35, 232], 24 | [70, 70, 70], 25 | [102, 102, 156], 26 | [190, 153, 153], 27 | [153, 153, 153], 28 | [250, 170, 30], 29 | [220, 220, 0], 30 | [107, 142, 35], 31 | [152, 251, 152], 32 | [0, 130, 180], 33 | [220, 20, 60], 34 | [255, 0, 0], 35 | [0, 0, 142], 36 | [0, 0, 70], 37 | [0, 60, 100], 38 | [0, 80, 100], 39 | [0, 0, 230], 40 | [119, 11, 32]]) 41 | 42 | def get_pascal_labels(): 43 | """Load the mapping that associates pascal classes with label colors 44 | Returns: 45 | np.ndarray with dimensions (21, 3) 46 | """ 47 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 48 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 49 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 50 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 51 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 52 | [0, 64, 128]]) 53 | 54 | 55 | def encode_segmap(mask): 56 | """Encode segmentation label images as pascal classes 57 | Args: 58 | mask (np.ndarray): raw segmentation label image of dimension 59 | (M, N, 3), in which the Pascal classes are encoded as colours. 
60 | Returns: 61 | (np.ndarray): class map with dimensions (M,N), where the value at 62 | a given location is the integer denoting the class index. 63 | """ 64 | mask = mask.astype(int) 65 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 66 | for ii, label in enumerate(get_pascal_labels()): 67 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 68 | label_mask = label_mask.astype(int) 69 | return label_mask 70 | 71 | 72 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 73 | rgb_masks = [] 74 | for label_mask in label_masks: 75 | rgb_mask = decode_segmap(label_mask, dataset) 76 | rgb_masks.append(rgb_mask) 77 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 78 | return rgb_masks 79 | 80 | def decode_segmap(label_mask, dataset, plot=False): 81 | """Decode segmentation class labels into a color image 82 | Args: 83 | label_mask (np.ndarray): an (M,N) array of integer values denoting 84 | the class label at each spatial location. 85 | plot (bool, optional): whether to show the resulting color image 86 | in a figure. 87 | Returns: 88 | (np.ndarray, optional): the resulting decoded color image. 89 | """ 90 | if dataset == 'pascal': 91 | n_classes = 21 92 | label_colours = get_pascal_labels() 93 | elif dataset == 'cityscapes': 94 | n_classes = 19 95 | label_colours = get_cityscapes_labels() 96 | else: 97 | raise NotImplementedError 98 | 99 | r = label_mask.copy() 100 | g = label_mask.copy() 101 | b = label_mask.copy() 102 | for ll in range(0, n_classes): 103 | r[label_mask == ll] = label_colours[ll, 0] 104 | g[label_mask == ll] = label_colours[ll, 1] 105 | b[label_mask == ll] = label_colours[ll, 2] 106 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 107 | rgb[:, :, 0] = r / 255.0 108 | rgb[:, :, 1] = g / 255.0 109 | rgb[:, :, 2] = b / 255.0 110 | if plot: 111 | plt.imshow(rgb) 112 | plt.show() 113 | else: 114 | return rgb 115 | 116 | def generate_param_report(logfile, param): 117 | log_file = open(logfile, 'w') 118 | # for key, val in param.items(): 119 | # log_file.write(key + ':' + str(val) + '\n') 120 | log_file.write(str(param)) 121 | log_file.close() 122 | 123 | def cross_entropy2d(logit, target, ignore_index=255, weight=None, size_average=True, batch_average=True): 124 | n, c, h, w = logit.size() 125 | # logit = logit.permute(0, 2, 3, 1) 126 | target = target.squeeze(1) 127 | if weight is None: 128 | criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, size_average=False) 129 | else: 130 | criterion = nn.CrossEntropyLoss(weight=torch.from_numpy(np.array(weight)).float().cuda(), ignore_index=ignore_index, size_average=False) 131 | loss = criterion(logit, target.long()) 132 | 133 | if size_average: 134 | loss /= (h * w) 135 | 136 | if batch_average: 137 | loss /= n 138 | 139 | return loss 140 | 141 | def lr_poly(base_lr, iter_, max_iter=100, power=0.9): 142 | return base_lr * ((1 - float(iter_) / max_iter) ** power) 143 | 144 | 145 | def get_iou(pred, gt, n_classes=21): 146 | total_iou = 0.0 147 | for i in range(len(pred)): 148 | pred_tmp = pred[i] 149 | gt_tmp = gt[i] 150 | 151 | intersect = [0] * n_classes 152 | union = [0] * n_classes 153 | for j in range(n_classes): 154 | match = (pred_tmp == j) + (gt_tmp == j) 155 | 156 | it = torch.sum(match == 2).item() 157 | un = torch.sum(match > 0).item() 158 | 159 | intersect[j] += it 160 | union[j] += un 161 | 162 | iou = [] 163 | for k in range(n_classes): 164 | if union[k] == 0: 165 | continue 166 | iou.append(intersect[k] / union[k]) 
167 | 168 | img_iou = (sum(iou) / len(iou)) 169 | total_iou += img_iou 170 | 171 | return total_iou 172 | 173 | def get_dice(pred, gt): 174 | total_dice = 0.0 175 | pred = pred.long() 176 | gt = gt.long() 177 | for i in range(len(pred)): 178 | pred_tmp = pred[i] 179 | gt_tmp = gt[i] 180 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 181 | print(dice) 182 | total_dice += dice 183 | 184 | return total_dice 185 | 186 | def get_mc_dice(pred, gt, num=2): 187 | # num is the total number of classes, include the background 188 | total_dice = np.zeros(num-1) 189 | pred = pred.long() 190 | gt = gt.long() 191 | for i in range(len(pred)): 192 | for j in range(1, num): 193 | pred_tmp = (pred[i]==j) 194 | gt_tmp = (gt[i]==j) 195 | dice = 2.0*torch.sum(pred_tmp*gt_tmp).item()/(1.0+torch.sum(pred_tmp**2)+torch.sum(gt_tmp**2)).item() 196 | total_dice[j-1] +=dice 197 | return total_dice 198 | 199 | def post_processing(prediction): 200 | prediction = nd.binary_fill_holes(prediction) 201 | label_cc, num_cc = measure.label(prediction,return_num=True) 202 | total_cc = np.sum(prediction) 203 | measure.regionprops(label_cc) 204 | for cc in range(1,num_cc+1): 205 | single_cc = (label_cc==cc) 206 | single_vol = np.sum(single_cc) 207 | if single_vol/total_cc<0.2: 208 | prediction[single_cc]=0 209 | 210 | return prediction 211 | 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /code/networks/VoxResNet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import print_function, division 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class SEBlock(nn.Module): 10 | def __init__(self, in_channels, r): 11 | super(SEBlock, self).__init__() 12 | 13 | redu_chns = int(in_channels / r) 14 | self.se_layers = nn.Sequential( 15 | nn.AdaptiveAvgPool3d(1), 16 | nn.Conv3d(in_channels, redu_chns, kernel_size=1, padding=0), 17 | nn.ReLU(), 18 | nn.Conv3d(redu_chns, in_channels, kernel_size=1, padding=0), 19 | nn.ReLU()) 20 | 21 | def forward(self, x): 22 | f = self.se_layers(x) 23 | return f * x + x 24 | 25 | 26 | class VoxRex(nn.Module): 27 | def __init__(self, in_channels): 28 | super(VoxRex, self).__init__() 29 | self.block = nn.Sequential( 30 | nn.InstanceNorm3d(in_channels), 31 | nn.ReLU(inplace=True), 32 | nn.Conv3d(in_channels, in_channels, 33 | kernel_size=3, padding=1, bias=False), 34 | nn.InstanceNorm3d(in_channels), 35 | nn.ReLU(inplace=True), 36 | nn.Conv3d(in_channels, in_channels, 37 | kernel_size=3, padding=1, bias=False) 38 | ) 39 | 40 | def forward(self, x): 41 | return self.block(x)+x 42 | 43 | 44 | class ConvBlock(nn.Module): 45 | """two convolution layers with batch norm and leaky relu""" 46 | 47 | def __init__(self, in_channels, out_channels): 48 | super(ConvBlock, self).__init__() 49 | self.conv_conv = nn.Sequential( 50 | nn.InstanceNorm3d(in_channels), 51 | nn.ReLU(inplace=True), 52 | nn.Conv3d(in_channels, out_channels, 53 | kernel_size=3, padding=1, bias=False), 54 | nn.InstanceNorm3d(out_channels), 55 | nn.ReLU(inplace=True), 56 | nn.Conv3d(out_channels, out_channels, 57 | kernel_size=3, padding=1, bias=False) 58 | ) 59 | 60 | def forward(self, x): 61 | return self.conv_conv(x) 62 | 63 | 64 | class UpBlock(nn.Module): 65 | """Upssampling followed by ConvBlock""" 66 | 67 | def __init__(self, in_channels, out_channels): 68 | super(UpBlock, self).__init__() 69 | self.up = nn.Upsample( 70 | scale_factor=2, 
mode='trilinear', align_corners=True) 71 | self.conv = ConvBlock(in_channels, out_channels) 72 | 73 | def forward(self, x1, x2): 74 | x1 = self.up(x1) 75 | x = torch.cat([x2, x1], dim=1) 76 | return self.conv(x) 77 | 78 | 79 | class VoxResNet(nn.Module): 80 | def __init__(self, in_chns=1, feature_chns=64, class_num=2): 81 | super(VoxResNet, self).__init__() 82 | self.in_chns = in_chns 83 | self.ft_chns = feature_chns 84 | self.n_class = class_num 85 | 86 | self.conv1 = nn.Conv3d(in_chns, feature_chns, kernel_size=3, padding=1) 87 | self.res1 = VoxRex(feature_chns) 88 | self.res2 = VoxRex(feature_chns) 89 | self.res3 = VoxRex(feature_chns) 90 | self.res4 = VoxRex(feature_chns) 91 | self.res5 = VoxRex(feature_chns) 92 | self.res6 = VoxRex(feature_chns) 93 | 94 | self.up1 = UpBlock(feature_chns * 2, feature_chns) 95 | self.up2 = UpBlock(feature_chns * 2, feature_chns) 96 | 97 | self.out = nn.Conv3d(feature_chns, self.n_class, kernel_size=1) 98 | 99 | self.maxpool = nn.MaxPool3d(2) 100 | self.upsample = nn.Upsample( 101 | scale_factor=2, mode='trilinear', align_corners=True) 102 | 103 | def forward(self, x): 104 | x = self.maxpool(self.conv1(x)) 105 | x1 = self.res1(x) 106 | x2 = self.res2(x1) 107 | x2_pool = self.maxpool(x2) 108 | x3 = self.res3(x2_pool) 109 | x4 = self.maxpool(self.res4(x3)) 110 | x5 = self.res5(x4) 111 | x6 = self.res6(x5) 112 | up1 = self.up1(x6, x2_pool) 113 | up2 = self.up2(up1, x) 114 | up = self.upsample(up2) 115 | out = self.out(up) 116 | return out 117 | -------------------------------------------------------------------------------- /code/networks/attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | try: 4 | from inplace_abn import InPlaceABN 5 | except ImportError: 6 | InPlaceABN = None 7 | 8 | 9 | class Conv2dReLU(nn.Sequential): 10 | def __init__( 11 | self, 12 | in_channels, 13 | out_channels, 14 | kernel_size, 15 | padding=0, 16 | stride=1, 17 | use_batchnorm=True, 18 | ): 19 | 20 | if use_batchnorm == "inplace" and InPlaceABN is None: 21 | raise RuntimeError( 22 | "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. 
" 23 | + "To install see: https://github.com/mapillary/inplace_abn" 24 | ) 25 | 26 | super().__init__() 27 | 28 | conv = nn.Conv2d( 29 | in_channels, 30 | out_channels, 31 | kernel_size, 32 | stride=stride, 33 | padding=padding, 34 | bias=not (use_batchnorm), 35 | ) 36 | relu = nn.ReLU(inplace=True) 37 | 38 | if use_batchnorm == "inplace": 39 | bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0) 40 | relu = nn.Identity() 41 | 42 | elif use_batchnorm and use_batchnorm != "inplace": 43 | bn = nn.BatchNorm2d(out_channels) 44 | 45 | else: 46 | bn = nn.Identity() 47 | 48 | super(Conv2dReLU, self).__init__(conv, bn, relu) 49 | 50 | 51 | class SCSEModule(nn.Module): 52 | def __init__(self, in_channels, reduction=16): 53 | super().__init__() 54 | self.cSE = nn.Sequential( 55 | nn.AdaptiveAvgPool2d(1), 56 | nn.Conv2d(in_channels, in_channels // reduction, 1), 57 | nn.ReLU(inplace=True), 58 | nn.Conv2d(in_channels // reduction, in_channels, 1), 59 | nn.Sigmoid(), 60 | ) 61 | self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()) 62 | 63 | def forward(self, x): 64 | return x * self.cSE(x) + x * self.sSE(x) 65 | 66 | 67 | class Activation(nn.Module): 68 | 69 | def __init__(self, name, **params): 70 | 71 | super().__init__() 72 | 73 | if name is None or name == 'identity': 74 | self.activation = nn.Identity(**params) 75 | elif name == 'sigmoid': 76 | self.activation = nn.Sigmoid() 77 | elif name == 'softmax2d': 78 | self.activation = nn.Softmax(dim=1, **params) 79 | elif name == 'softmax': 80 | self.activation = nn.Softmax(**params) 81 | elif name == 'logsoftmax': 82 | self.activation = nn.LogSoftmax(**params) 83 | elif callable(name): 84 | self.activation = name(**params) 85 | else: 86 | raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name)) 87 | 88 | def forward(self, x): 89 | return self.activation(x) 90 | 91 | 92 | class Attention(nn.Module): 93 | 94 | def __init__(self, name, **params): 95 | super().__init__() 96 | 97 | if name is None: 98 | self.attention = nn.Identity(**params) 99 | elif name == 'scse': 100 | self.attention = SCSEModule(**params) 101 | else: 102 | raise ValueError("Attention {} is not implemented".format(name)) 103 | 104 | def forward(self, x): 105 | return self.attention(x) 106 | 107 | 108 | class Flatten(nn.Module): 109 | def forward(self, x): 110 | return x.view(x.shape[0], -1) 111 | -------------------------------------------------------------------------------- /code/networks/attention_unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from networks.utils import UnetConv3, UnetUp3_CT, UnetGridGatingSignal3, UnetDsv3 4 | import torch.nn.functional as F 5 | from networks.networks_other import init_weights 6 | from networks.grid_attention_layer import GridAttentionBlock3D 7 | 8 | 9 | class Attention_UNet(nn.Module): 10 | 11 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, 12 | nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True): 13 | super(Attention_UNet, self).__init__() 14 | self.is_deconv = is_deconv 15 | self.in_channels = in_channels 16 | self.is_batchnorm = is_batchnorm 17 | self.feature_scale = feature_scale 18 | 19 | filters = [64, 128, 256, 512, 1024] 20 | filters = [int(x / self.feature_scale) for x in filters] 21 | 22 | # downsampling 23 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3,3,3), 
padding_size=(1,1,1)) 24 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 25 | 26 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 27 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 28 | 29 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 30 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 31 | 32 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 33 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 34 | 35 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1)) 36 | self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1, 1, 1), is_batchnorm=self.is_batchnorm) 37 | 38 | # attention blocks 39 | self.attentionblock2 = MultiAttentionBlock(in_size=filters[1], gate_size=filters[2], inter_size=filters[1], 40 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 41 | self.attentionblock3 = MultiAttentionBlock(in_size=filters[2], gate_size=filters[3], inter_size=filters[2], 42 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 43 | self.attentionblock4 = MultiAttentionBlock(in_size=filters[3], gate_size=filters[4], inter_size=filters[3], 44 | nonlocal_mode=nonlocal_mode, sub_sample_factor= attention_dsample) 45 | 46 | # upsampling 47 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 48 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 49 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 50 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 51 | 52 | # deep supervision 53 | self.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8) 54 | self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4) 55 | self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2) 56 | self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(n_classes*4, n_classes, 1) 60 | 61 | # initialise weights 62 | for m in self.modules(): 63 | if isinstance(m, nn.Conv3d): 64 | init_weights(m, init_type='kaiming') 65 | elif isinstance(m, nn.BatchNorm3d): 66 | init_weights(m, init_type='kaiming') 67 | 68 | def forward(self, inputs): 69 | # Feature Extraction 70 | conv1 = self.conv1(inputs) 71 | maxpool1 = self.maxpool1(conv1) 72 | 73 | conv2 = self.conv2(maxpool1) 74 | maxpool2 = self.maxpool2(conv2) 75 | 76 | conv3 = self.conv3(maxpool2) 77 | maxpool3 = self.maxpool3(conv3) 78 | 79 | conv4 = self.conv4(maxpool3) 80 | maxpool4 = self.maxpool4(conv4) 81 | 82 | # Gating Signal Generation 83 | center = self.center(maxpool4) 84 | gating = self.gating(center) 85 | 86 | # Attention Mechanism 87 | # Upscaling Part (Decoder) 88 | g_conv4, att4 = self.attentionblock4(conv4, gating) 89 | up4 = self.up_concat4(g_conv4, center) 90 | g_conv3, att3 = self.attentionblock3(conv3, up4) 91 | up3 = self.up_concat3(g_conv3, up4) 92 | g_conv2, att2 = self.attentionblock2(conv2, up3) 93 | up2 = self.up_concat2(g_conv2, up3) 94 | up1 = self.up_concat1(conv1, up2) 95 | 96 | # Deep Supervision 97 | dsv4 = self.dsv4(up4) 98 | dsv3 = self.dsv3(up3) 99 | dsv2 = self.dsv2(up2) 100 | dsv1 = self.dsv1(up1) 101 | final = self.final(torch.cat([dsv1,dsv2,dsv3,dsv4], dim=1)) 102 | 103 | return final 104 | 105 | 106 | @staticmethod 107 | 
def apply_argmax_softmax(pred): 108 | log_p = F.softmax(pred, dim=1) 109 | 110 | return log_p 111 | 112 | 113 | class MultiAttentionBlock(nn.Module): 114 | def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor): 115 | super(MultiAttentionBlock, self).__init__() 116 | self.gate_block_1 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, 117 | inter_channels=inter_size, mode=nonlocal_mode, 118 | sub_sample_factor= sub_sample_factor) 119 | self.gate_block_2 = GridAttentionBlock3D(in_channels=in_size, gating_channels=gate_size, 120 | inter_channels=inter_size, mode=nonlocal_mode, 121 | sub_sample_factor=sub_sample_factor) 122 | self.combine_gates = nn.Sequential(nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0), 123 | nn.BatchNorm3d(in_size), 124 | nn.ReLU(inplace=True) 125 | ) 126 | 127 | # initialise the blocks 128 | for m in self.children(): 129 | if m.__class__.__name__.find('GridAttentionBlock3D') != -1: continue 130 | init_weights(m, init_type='kaiming') 131 | 132 | def forward(self, input, gating_signal): 133 | gate_1, attention_1 = self.gate_block_1(input, gating_signal) 134 | gate_2, attention_2 = self.gate_block_2(input, gating_signal) 135 | 136 | return self.combine_gates(torch.cat([gate_1, gate_2], 1)), torch.cat([attention_1, attention_2], 1) -------------------------------------------------------------------------------- /code/networks/config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import yaml 10 | from yacs.config import CfgNode as CN 11 | 12 | _C = CN() 13 | 14 | # Base config files 15 | _C.BASE = [''] 16 | 17 | # ----------------------------------------------------------------------------- 18 | # Data settings 19 | # ----------------------------------------------------------------------------- 20 | _C.DATA = CN() 21 | # Batch size for a single GPU, could be overwritten by command line argument 22 | _C.DATA.BATCH_SIZE = 128 23 | # Path to dataset, could be overwritten by command line argument 24 | _C.DATA.DATA_PATH = '' 25 | # Dataset name 26 | _C.DATA.DATASET = 'imagenet' 27 | # Input image size 28 | _C.DATA.IMG_SIZE = 224 29 | # Interpolation to resize image (random, bilinear, bicubic) 30 | _C.DATA.INTERPOLATION = 'bicubic' 31 | # Use zipped dataset instead of folder dataset 32 | # could be overwritten by command line argument 33 | _C.DATA.ZIP_MODE = False 34 | # Cache Data in Memory, could be overwritten by command line argument 35 | _C.DATA.CACHE_MODE = 'part' 36 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
37 | _C.DATA.PIN_MEMORY = True 38 | # Number of data loading threads 39 | _C.DATA.NUM_WORKERS = 8 40 | 41 | # ----------------------------------------------------------------------------- 42 | # Model settings 43 | # ----------------------------------------------------------------------------- 44 | _C.MODEL = CN() 45 | # Model type 46 | _C.MODEL.TYPE = 'swin' 47 | # Model name 48 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 49 | # Checkpoint to resume, could be overwritten by command line argument 50 | _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' 51 | _C.MODEL.RESUME = '' 52 | # Number of classes, overwritten in data preparation 53 | _C.MODEL.NUM_CLASSES = 1000 54 | # Dropout rate 55 | _C.MODEL.DROP_RATE = 0.0 56 | # Drop path rate 57 | _C.MODEL.DROP_PATH_RATE = 0.1 58 | # Label Smoothing 59 | _C.MODEL.LABEL_SMOOTHING = 0.1 60 | 61 | # Swin Transformer parameters 62 | _C.MODEL.SWIN = CN() 63 | _C.MODEL.SWIN.PATCH_SIZE = 4 64 | _C.MODEL.SWIN.IN_CHANS = 3 65 | _C.MODEL.SWIN.EMBED_DIM = 96 66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 67 | _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] 68 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 69 | _C.MODEL.SWIN.WINDOW_SIZE = 7 70 | _C.MODEL.SWIN.MLP_RATIO = 4. 71 | _C.MODEL.SWIN.QKV_BIAS = True 72 | _C.MODEL.SWIN.QK_SCALE = False 73 | _C.MODEL.SWIN.APE = False 74 | _C.MODEL.SWIN.PATCH_NORM = True 75 | _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" 76 | 77 | # ----------------------------------------------------------------------------- 78 | # Training settings 79 | # ----------------------------------------------------------------------------- 80 | _C.TRAIN = CN() 81 | _C.TRAIN.START_EPOCH = 0 82 | _C.TRAIN.EPOCHS = 300 83 | _C.TRAIN.WARMUP_EPOCHS = 20 84 | _C.TRAIN.WEIGHT_DECAY = 0.05 85 | _C.TRAIN.BASE_LR = 5e-4 86 | _C.TRAIN.WARMUP_LR = 5e-7 87 | _C.TRAIN.MIN_LR = 5e-6 88 | # Clip gradient norm 89 | _C.TRAIN.CLIP_GRAD = 5.0 90 | # Auto resume from latest checkpoint 91 | _C.TRAIN.AUTO_RESUME = True 92 | # Gradient accumulation steps 93 | # could be overwritten by command line argument 94 | _C.TRAIN.ACCUMULATION_STEPS = 0 95 | # Whether to use gradient checkpointing to save memory 96 | # could be overwritten by command line argument 97 | _C.TRAIN.USE_CHECKPOINT = False 98 | 99 | # LR scheduler 100 | _C.TRAIN.LR_SCHEDULER = CN() 101 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 102 | # Epoch interval to decay LR, used in StepLRScheduler 103 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 104 | # LR decay rate, used in StepLRScheduler 105 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 106 | 107 | # Optimizer 108 | _C.TRAIN.OPTIMIZER = CN() 109 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 110 | # Optimizer Epsilon 111 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 112 | # Optimizer Betas 113 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 114 | # SGD momentum 115 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 116 | 117 | # ----------------------------------------------------------------------------- 118 | # Augmentation settings 119 | # ----------------------------------------------------------------------------- 120 | _C.AUG = CN() 121 | # Color jitter factor 122 | _C.AUG.COLOR_JITTER = 0.4 123 | # Use AutoAugment policy. 
"v0" or "original" 124 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 125 | # Random erase prob 126 | _C.AUG.REPROB = 0.25 127 | # Random erase mode 128 | _C.AUG.REMODE = 'pixel' 129 | # Random erase count 130 | _C.AUG.RECOUNT = 1 131 | # Mixup alpha, mixup enabled if > 0 132 | _C.AUG.MIXUP = 0.8 133 | # Cutmix alpha, cutmix enabled if > 0 134 | _C.AUG.CUTMIX = 1.0 135 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 136 | _C.AUG.CUTMIX_MINMAX = False 137 | # Probability of performing mixup or cutmix when either/both is enabled 138 | _C.AUG.MIXUP_PROB = 1.0 139 | # Probability of switching to cutmix when both mixup and cutmix enabled 140 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 141 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 142 | _C.AUG.MIXUP_MODE = 'batch' 143 | 144 | # ----------------------------------------------------------------------------- 145 | # Testing settings 146 | # ----------------------------------------------------------------------------- 147 | _C.TEST = CN() 148 | # Whether to use center crop when testing 149 | _C.TEST.CROP = True 150 | 151 | # ----------------------------------------------------------------------------- 152 | # Misc 153 | # ----------------------------------------------------------------------------- 154 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 155 | # overwritten by command line argument 156 | _C.AMP_OPT_LEVEL = '' 157 | # Path to output folder, overwritten by command line argument 158 | _C.OUTPUT = '' 159 | # Tag of experiment, overwritten by command line argument 160 | _C.TAG = 'default' 161 | # Frequency to save checkpoint 162 | _C.SAVE_FREQ = 1 163 | # Frequency to logging info 164 | _C.PRINT_FREQ = 10 165 | # Fixed random seed 166 | _C.SEED = 0 167 | # Perform evaluation only, overwritten by command line argument 168 | _C.EVAL_MODE = False 169 | # Test throughput only, overwritten by command line argument 170 | _C.THROUGHPUT_MODE = False 171 | # local rank for DistributedDataParallel, given by command line argument 172 | _C.LOCAL_RANK = 0 173 | 174 | 175 | def _update_config_from_file(config, cfg_file): 176 | config.defrost() 177 | with open(cfg_file, 'r') as f: 178 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 179 | 180 | for cfg in yaml_cfg.setdefault('BASE', ['']): 181 | if cfg: 182 | _update_config_from_file( 183 | config, os.path.join(os.path.dirname(cfg_file), cfg) 184 | ) 185 | print('=> merge config from {}'.format(cfg_file)) 186 | config.merge_from_file(cfg_file) 187 | config.freeze() 188 | 189 | 190 | def update_config(config, args): 191 | _update_config_from_file(config, args.cfg) 192 | 193 | config.defrost() 194 | if args.opts: 195 | config.merge_from_list(args.opts) 196 | 197 | # merge from specific arguments 198 | if args.batch_size: 199 | config.DATA.BATCH_SIZE = args.batch_size 200 | if args.zip: 201 | config.DATA.ZIP_MODE = True 202 | if args.cache_mode: 203 | config.DATA.CACHE_MODE = args.cache_mode 204 | if args.resume: 205 | config.MODEL.RESUME = args.resume 206 | if args.accumulation_steps: 207 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 208 | if args.use_checkpoint: 209 | config.TRAIN.USE_CHECKPOINT = True 210 | if args.amp_opt_level: 211 | config.AMP_OPT_LEVEL = args.amp_opt_level 212 | if args.tag: 213 | config.TAG = args.tag 214 | if args.eval: 215 | config.EVAL_MODE = True 216 | if args.throughput: 217 | config.THROUGHPUT_MODE = True 218 | 219 | config.freeze() 220 | 221 | 222 | def get_config(args): 223 | """Get a yacs CfgNode object 
with default values.""" 224 | # Return a clone so that the defaults will not be altered 225 | # This is for the "local variable" use pattern 226 | config = _C.clone() 227 | update_config(config, args) 228 | 229 | return config 230 | -------------------------------------------------------------------------------- /code/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FC3DDiscriminator(nn.Module): 7 | 8 | def __init__(self, num_classes, ndf=64, n_channel=1): 9 | super(FC3DDiscriminator, self).__init__() 10 | # downsample 16 11 | self.conv0 = nn.Conv3d( 12 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 13 | self.conv1 = nn.Conv3d( 14 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 15 | 16 | self.conv2 = nn.Conv3d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 17 | self.conv3 = nn.Conv3d( 18 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 19 | self.conv4 = nn.Conv3d( 20 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 21 | self.avgpool = nn.AvgPool3d((6, 6, 6)) # (D/16, W/16, H/16) 22 | self.classifier = nn.Linear(ndf*8, 2) 23 | 24 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 25 | self.dropout = nn.Dropout3d(0.5) 26 | self.Softmax = nn.Softmax() 27 | 28 | def forward(self, map, image): 29 | batch_size = map.shape[0] 30 | map_feature = self.conv0(map) 31 | image_feature = self.conv1(image) 32 | x = torch.add(map_feature, image_feature) 33 | x = self.leaky_relu(x) 34 | x = self.dropout(x) 35 | 36 | x = self.conv2(x) 37 | x = self.leaky_relu(x) 38 | x = self.dropout(x) 39 | 40 | x = self.conv3(x) 41 | x = self.leaky_relu(x) 42 | x = self.dropout(x) 43 | 44 | x = self.conv4(x) 45 | x = self.leaky_relu(x) 46 | 47 | x = self.avgpool(x) 48 | 49 | x = x.view(batch_size, -1) 50 | 51 | x = self.classifier(x) 52 | x = x.reshape((batch_size, 2)) 53 | # x = self.Softmax(x) 54 | 55 | return x 56 | 57 | 58 | class FCDiscriminator(nn.Module): 59 | 60 | def __init__(self, num_classes, ndf=64, n_channel=1): 61 | super(FCDiscriminator, self).__init__() 62 | self.conv0 = nn.Conv2d( 63 | num_classes, ndf, kernel_size=4, stride=2, padding=1) 64 | self.conv1 = nn.Conv2d( 65 | n_channel, ndf, kernel_size=4, stride=2, padding=1) 66 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 67 | self.conv3 = nn.Conv2d( 68 | ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 69 | self.conv4 = nn.Conv2d( 70 | ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 71 | self.classifier = nn.Linear(ndf*32, 2) 72 | self.avgpool = nn.AvgPool2d((7, 7)) 73 | 74 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 75 | self.dropout = nn.Dropout2d(0.5) 76 | # self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 77 | # self.sigmoid = nn.Sigmoid() 78 | 79 | def forward(self, map, feature): 80 | map_feature = self.conv0(map) 81 | image_feature = self.conv1(feature) 82 | x = torch.add(map_feature, image_feature) 83 | 84 | x = self.conv2(x) 85 | x = self.leaky_relu(x) 86 | x = self.dropout(x) 87 | 88 | x = self.conv3(x) 89 | x = self.leaky_relu(x) 90 | x = self.dropout(x) 91 | 92 | x = self.conv4(x) 93 | x = self.leaky_relu(x) 94 | x = self.avgpool(x) 95 | x = x.view(x.size(0), -1) 96 | x = self.classifier(x) 97 | # x = self.up_sample(x) 98 | # x = self.sigmoid(x) 99 | 100 | return x 101 | -------------------------------------------------------------------------------- /code/networks/efficientunet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from networks.attention import * 6 | from networks.efficient_encoder import get_encoder 7 | 8 | 9 | def initialize_decoder(module): 10 | for m in module.modules(): 11 | 12 | if isinstance(m, nn.Conv2d): 13 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 14 | if m.bias is not None: 15 | nn.init.constant_(m.bias, 0) 16 | 17 | elif isinstance(m, nn.BatchNorm2d): 18 | nn.init.constant_(m.weight, 1) 19 | nn.init.constant_(m.bias, 0) 20 | 21 | elif isinstance(m, nn.Linear): 22 | nn.init.xavier_uniform_(m.weight) 23 | if m.bias is not None: 24 | nn.init.constant_(m.bias, 0) 25 | 26 | 27 | class DecoderBlock(nn.Module): 28 | def __init__( 29 | self, 30 | in_channels, 31 | skip_channels, 32 | out_channels, 33 | use_batchnorm=True, 34 | attention_type=None, 35 | ): 36 | super().__init__() 37 | self.conv1 = Conv2dReLU( 38 | in_channels + skip_channels, 39 | out_channels, 40 | kernel_size=3, 41 | padding=1, 42 | use_batchnorm=use_batchnorm, 43 | ) 44 | self.attention1 = Attention(attention_type, in_channels=in_channels + skip_channels) 45 | self.conv2 = Conv2dReLU( 46 | out_channels, 47 | out_channels, 48 | kernel_size=3, 49 | padding=1, 50 | use_batchnorm=use_batchnorm, 51 | ) 52 | self.attention2 = Attention(attention_type, in_channels=out_channels) 53 | 54 | def forward(self, x, skip=None): 55 | x = F.interpolate(x, scale_factor=2, mode="nearest") 56 | if skip is not None: 57 | x = torch.cat([x, skip], dim=1) 58 | x = self.attention1(x) 59 | x = self.conv1(x) 60 | x = self.conv2(x) 61 | x = self.attention2(x) 62 | return x 63 | 64 | 65 | class CenterBlock(nn.Sequential): 66 | def __init__(self, in_channels, out_channels, use_batchnorm=True): 67 | conv1 = Conv2dReLU( 68 | in_channels, 69 | out_channels, 70 | kernel_size=3, 71 | padding=1, 72 | use_batchnorm=use_batchnorm, 73 | ) 74 | conv2 = Conv2dReLU( 75 | out_channels, 76 | out_channels, 77 | kernel_size=3, 78 | padding=1, 79 | use_batchnorm=use_batchnorm, 80 | ) 81 | super().__init__(conv1, conv2) 82 | 83 | 84 | class UnetDecoder(nn.Module): 85 | def __init__( 86 | self, 87 | encoder_channels, 88 | decoder_channels, 89 | n_blocks=5, 90 | use_batchnorm=True, 91 | attention_type=None, 92 | center=False, 93 | ): 94 | super().__init__() 95 | 96 | if n_blocks != len(decoder_channels): 97 | raise ValueError( 98 | "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format( 99 | n_blocks, len(decoder_channels) 100 | ) 101 | ) 102 | 103 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 104 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 105 | 106 | # computing blocks input and output channels 107 | head_channels = encoder_channels[0] 108 | in_channels = [head_channels] + list(decoder_channels[:-1]) 109 | skip_channels = list(encoder_channels[1:]) + [0] 110 | out_channels = decoder_channels 111 | 112 | if center: 113 | self.center = CenterBlock( 114 | head_channels, head_channels, use_batchnorm=use_batchnorm 115 | ) 116 | else: 117 | self.center = nn.Identity() 118 | 119 | # combine decoder keyword arguments 120 | kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type) 121 | blocks = [ 122 | DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) 123 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 124 | ] 125 | self.blocks = 
nn.ModuleList(blocks) 126 | 127 | def forward(self, *features): 128 | 129 | features = features[1:] # remove first skip with same spatial resolution 130 | features = features[::-1] # reverse channels to start from head of encoder 131 | 132 | head = features[0] 133 | skips = features[1:] 134 | 135 | x = self.center(head) 136 | for i, decoder_block in enumerate(self.blocks): 137 | skip = skips[i] if i < len(skips) else None 138 | x = decoder_block(x, skip) 139 | 140 | return x 141 | 142 | 143 | class Effi_UNet(nn.Module): 144 | """Unet_ is a fully convolution neural network for image semantic segmentation 145 | 146 | Args: 147 | encoder_name: name of classification model (without last dense layers) used as feature 148 | extractor to build segmentation model. 149 | encoder_depth (int): number of stages used in decoder, larger depth - more features are generated. 150 | e.g. for depth=3 encoder will generate list of features with following spatial shapes 151 | [(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature tensor will have 152 | spatial resolution (H/(2^depth), W/(2^depth)] 153 | encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet). 154 | decoder_channels: list of numbers of ``Conv2D`` layer filters in decoder blocks 155 | decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers 156 | is used. If 'inplace' InplaceABN will be used, allows to decrease memory consumption. 157 | One of [True, False, 'inplace'] 158 | decoder_attention_type: attention module used in decoder of the model 159 | One of [``None``, ``scse``] 160 | in_channels: number of input channels for model, default is 3. 161 | classes: a number of classes for output (output shape - ``(batch, classes, h, w)``). 162 | activation: activation function to apply after final convolution; 163 | One of [``sigmoid``, ``softmax``, ``logsoftmax``, ``identity``, callable, None] 164 | aux_params: if specified model will have additional classification auxiliary output 165 | build on top of encoder, supported params: 166 | - classes (int): number of classes 167 | - pooling (str): one of 'max', 'avg'. Default is 'avg'. 168 | - dropout (float): dropout factor in [0, 1) 169 | - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits) 170 | 171 | Returns: 172 | ``torch.nn.Module``: **Unet** 173 | 174 | .. 
_Unet: 175 | https://arxiv.org/pdf/1505.04597 176 | 177 | """ 178 | 179 | def __init__( 180 | self, 181 | encoder_name: str = "resnet34", 182 | encoder_depth: int = 5, 183 | encoder_weights: str = "imagenet", 184 | decoder_use_batchnorm=True, 185 | decoder_channels=(256, 128, 64, 32, 16), 186 | decoder_attention_type=None, 187 | in_channels: int = 3, 188 | classes: int = 1): 189 | super().__init__() 190 | 191 | self.encoder = get_encoder( 192 | encoder_name, 193 | in_channels=in_channels, 194 | depth=encoder_depth, 195 | weights=encoder_weights, 196 | ) 197 | 198 | self.decoder = UnetDecoder( 199 | encoder_channels=self.encoder.out_channels, 200 | decoder_channels=decoder_channels, 201 | n_blocks=encoder_depth, 202 | use_batchnorm=decoder_use_batchnorm, 203 | center=True if encoder_name.startswith("vgg") else False, 204 | attention_type=decoder_attention_type, 205 | ) 206 | initialize_decoder(self.decoder) 207 | self.classifier = nn.Conv2d(decoder_channels[-1], classes, 1) 208 | 209 | def forward(self, x): 210 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 211 | features = self.encoder(x) 212 | decoder_output = self.decoder(*features) 213 | output = self.classifier(decoder_output) 214 | 215 | return output 216 | 217 | 218 | # unet = UNet('efficientnet-b3', encoder_weights='imagenet', in_channels=1, classes=1, decoder_attention_type="scse") 219 | # t = torch.rand(2, 1, 224, 224) 220 | # print(unet) 221 | # print(unet(t).shape) 222 | -------------------------------------------------------------------------------- /code/networks/encoder_tool.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.model_zoo as model_zoo 6 | from efficientnet_pytorch import EfficientNet 7 | from efficientnet_pytorch.utils import get_model_params, url_map 8 | 9 | 10 | class EncoderMixin: 11 | """Add encoder functionality such as: 12 | - output channels specification of feature tensors (produced by encoder) 13 | - patching first convolution for arbitrary input channels 14 | """ 15 | 16 | @property 17 | def out_channels(self) -> List: 18 | """Return channels dimensions for each tensor of forward output of encoder""" 19 | return self._out_channels[: self._depth + 1] 20 | 21 | def set_in_channels(self, in_channels): 22 | """Change first convolution chennels""" 23 | if in_channels == 3: 24 | return 25 | 26 | self._in_channels = in_channels 27 | if self._out_channels[0] == 3: 28 | self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) 29 | 30 | patch_first_conv(model=self, in_channels=in_channels) 31 | 32 | 33 | def patch_first_conv(model, in_channels): 34 | """Change first convolution layer input channels. 
35 | In case: 36 | in_channels == 1 or in_channels == 2 -> reuse original weights 37 | in_channels > 3 -> make random kaiming normal initialization 38 | """ 39 | 40 | # get first conv 41 | for module in model.modules(): 42 | if isinstance(module, nn.Conv2d): 43 | break 44 | 45 | # change input channels for first conv 46 | module.in_channels = in_channels 47 | weight = module.weight.detach() 48 | reset = False 49 | 50 | if in_channels == 1: 51 | weight = weight.sum(1, keepdim=True) 52 | elif in_channels == 2: 53 | weight = weight[:, :2] * (3.0 / 2.0) 54 | else: 55 | reset = True 56 | weight = torch.Tensor( 57 | module.out_channels, 58 | module.in_channels // module.groups, 59 | *module.kernel_size 60 | ) 61 | 62 | module.weight = nn.parameter.Parameter(weight) 63 | if reset: 64 | module.reset_parameters() 65 | 66 | 67 | class EfficientNetEncoder(EfficientNet, EncoderMixin): 68 | def __init__(self, stage_idxs, out_channels, model_name, depth=5): 69 | 70 | blocks_args, global_params = get_model_params(model_name, override_params=None) 71 | super().__init__(blocks_args, global_params) 72 | 73 | self._stage_idxs = list(stage_idxs) + [len(self._blocks)] 74 | self._out_channels = out_channels 75 | self._depth = depth 76 | self._in_channels = 3 77 | 78 | del self._fc 79 | 80 | def forward(self, x): 81 | 82 | features = [x] 83 | 84 | if self._depth > 0: 85 | x = self._swish(self._bn0(self._conv_stem(x))) 86 | features.append(x) 87 | 88 | if self._depth > 1: 89 | skip_connection_idx = 0 90 | for idx, block in enumerate(self._blocks): 91 | drop_connect_rate = self._global_params.drop_connect_rate 92 | if drop_connect_rate: 93 | drop_connect_rate *= float(idx) / len(self._blocks) 94 | x = block(x, drop_connect_rate=drop_connect_rate) 95 | if idx == self._stage_idxs[skip_connection_idx] - 1: 96 | skip_connection_idx += 1 97 | features.append(x) 98 | if skip_connection_idx + 1 == self._depth: 99 | break 100 | return features 101 | 102 | def load_state_dict(self, state_dict, **kwargs): 103 | state_dict.pop("_fc.bias") 104 | state_dict.pop("_fc.weight") 105 | super().load_state_dict(state_dict, **kwargs) 106 | 107 | 108 | def _get_pretrained_settings(encoder): 109 | pretrained_settings = { 110 | "imagenet": { 111 | "mean": [0.485, 0.456, 0.406], 112 | "std": [0.229, 0.224, 0.225], 113 | "url": url_map[encoder], 114 | "input_space": "RGB", 115 | "input_range": [0, 1], 116 | } 117 | } 118 | return pretrained_settings 119 | 120 | 121 | efficient_net_encoders = { 122 | "efficientnet-b0": { 123 | "encoder": EfficientNetEncoder, 124 | "pretrained_settings": _get_pretrained_settings("efficientnet-b0"), 125 | "params": { 126 | "out_channels": (3, 32, 24, 40, 112, 320), 127 | "stage_idxs": (3, 5, 9), 128 | "model_name": "efficientnet-b0", 129 | }, 130 | }, 131 | "efficientnet-b1": { 132 | "encoder": EfficientNetEncoder, 133 | "pretrained_settings": _get_pretrained_settings("efficientnet-b1"), 134 | "params": { 135 | "out_channels": (3, 32, 24, 40, 112, 320), 136 | "stage_idxs": (5, 8, 16), 137 | "model_name": "efficientnet-b1", 138 | }, 139 | }, 140 | "efficientnet-b2": { 141 | "encoder": EfficientNetEncoder, 142 | "pretrained_settings": _get_pretrained_settings("efficientnet-b2"), 143 | "params": { 144 | "out_channels": (3, 32, 24, 48, 120, 352), 145 | "stage_idxs": (5, 8, 16), 146 | "model_name": "efficientnet-b2", 147 | }, 148 | }, 149 | "efficientnet-b3": { 150 | "encoder": EfficientNetEncoder, 151 | "pretrained_settings": _get_pretrained_settings("efficientnet-b3"), 152 | "params": { 153 | "out_channels": (3, 
40, 32, 48, 136, 384), 154 | "stage_idxs": (5, 8, 18), 155 | "model_name": "efficientnet-b3", 156 | }, 157 | }, 158 | "efficientnet-b4": { 159 | "encoder": EfficientNetEncoder, 160 | "pretrained_settings": _get_pretrained_settings("efficientnet-b4"), 161 | "params": { 162 | "out_channels": (3, 48, 32, 56, 160, 448), 163 | "stage_idxs": (6, 10, 22), 164 | "model_name": "efficientnet-b4", 165 | }, 166 | }, 167 | "efficientnet-b5": { 168 | "encoder": EfficientNetEncoder, 169 | "pretrained_settings": _get_pretrained_settings("efficientnet-b5"), 170 | "params": { 171 | "out_channels": (3, 48, 40, 64, 176, 512), 172 | "stage_idxs": (8, 13, 27), 173 | "model_name": "efficientnet-b5", 174 | }, 175 | }, 176 | "efficientnet-b6": { 177 | "encoder": EfficientNetEncoder, 178 | "pretrained_settings": _get_pretrained_settings("efficientnet-b6"), 179 | "params": { 180 | "out_channels": (3, 56, 40, 72, 200, 576), 181 | "stage_idxs": (9, 15, 31), 182 | "model_name": "efficientnet-b6", 183 | }, 184 | }, 185 | "efficientnet-b7": { 186 | "encoder": EfficientNetEncoder, 187 | "pretrained_settings": _get_pretrained_settings("efficientnet-b7"), 188 | "params": { 189 | "out_channels": (3, 64, 48, 80, 224, 640), 190 | "stage_idxs": (11, 18, 38), 191 | "model_name": "efficientnet-b7", 192 | }, 193 | }, 194 | } 195 | 196 | encoders = {} 197 | encoders.update(efficient_net_encoders) 198 | 199 | 200 | def get_encoder(name, in_channels=3, depth=5, weights=None): 201 | Encoder = encoders[name]["encoder"] 202 | params = encoders[name]["params"] 203 | params.update(depth=depth) 204 | encoder = Encoder(**params) 205 | 206 | if weights is not None: 207 | settings = encoders[name]["pretrained_settings"][weights] 208 | encoder.load_state_dict(model_zoo.load_url(settings["url"])) 209 | 210 | encoder.set_in_channels(in_channels) 211 | 212 | return encoder 213 | -------------------------------------------------------------------------------- /code/networks/net_factory.py: -------------------------------------------------------------------------------- 1 | from networks.efficientunet import Effi_UNet 2 | from networks.enet import ENet 3 | from networks.pnet import PNet2D 4 | from networks.unet import UNet, UNet_DS, UNet_URPC, UNet_CCT 5 | import argparse 6 | from networks.vision_transformer import SwinUnet as ViT_seg 7 | from networks.config import get_config 8 | from networks.nnunet import initialize_network 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--root_path', type=str, 13 | default='../data/ACDC', help='Name of Experiment') 14 | parser.add_argument('--exp', type=str, 15 | default='ACDC/Cross_Supervision_CNN_Trans2D', help='experiment_name') 16 | parser.add_argument('--model', type=str, 17 | default='unet', help='model_name') 18 | parser.add_argument('--max_iterations', type=int, 19 | default=30000, help='maximum epoch number to train') 20 | parser.add_argument('--batch_size', type=int, default=8, 21 | help='batch_size per gpu') 22 | parser.add_argument('--deterministic', type=int, default=1, 23 | help='whether use deterministic training') 24 | parser.add_argument('--base_lr', type=float, default=0.01, 25 | help='segmentation network learning rate') 26 | parser.add_argument('--patch_size', type=list, default=[224, 224], 27 | help='patch size of network input') 28 | parser.add_argument('--seed', type=int, default=1337, help='random seed') 29 | parser.add_argument('--num_classes', type=int, default=4, 30 | help='output channel of network') 31 | parser.add_argument( 32 | '--cfg', type=str, 
default="../code/configs/swin_tiny_patch4_window7_224_lite.yaml", help='path to config file', ) 33 | parser.add_argument( 34 | "--opts", 35 | help="Modify config options by adding 'KEY VALUE' pairs. ", 36 | default=None, 37 | nargs='+', 38 | ) 39 | parser.add_argument('--zip', action='store_true', 40 | help='use zipped dataset instead of folder dataset') 41 | parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], 42 | help='no: no cache, ' 43 | 'full: cache all data, ' 44 | 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') 45 | parser.add_argument('--resume', help='resume from checkpoint') 46 | parser.add_argument('--accumulation-steps', type=int, 47 | help="gradient accumulation steps") 48 | parser.add_argument('--use-checkpoint', action='store_true', 49 | help="whether to use gradient checkpointing to save memory") 50 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 51 | help='mixed precision opt level, if O0, no amp is used') 52 | parser.add_argument('--tag', help='tag of experiment') 53 | parser.add_argument('--eval', action='store_true', 54 | help='Perform evaluation only') 55 | parser.add_argument('--throughput', action='store_true', 56 | help='Test throughput only') 57 | 58 | # label and unlabel 59 | parser.add_argument('--labeled_bs', type=int, default=4, 60 | help='labeled_batch_size per gpu') 61 | parser.add_argument('--labeled_num', type=int, default=7, 62 | help='labeled data') 63 | # costs 64 | parser.add_argument('--ema_decay', type=float, default=0.99, help='ema_decay') 65 | parser.add_argument('--consistency_type', type=str, 66 | default="mse", help='consistency_type') 67 | parser.add_argument('--consistency', type=float, 68 | default=0.1, help='consistency') 69 | parser.add_argument('--consistency_rampup', type=float, 70 | default=200.0, help='consistency_rampup') 71 | args = parser.parse_args() 72 | config = get_config(args) 73 | 74 | 75 | def net_factory(net_type="unet", in_chns=1, class_num=3): 76 | if net_type == "unet": 77 | net = UNet(in_chns=in_chns, class_num=class_num).cuda() 78 | elif net_type == "enet": 79 | net = ENet(in_channels=in_chns, num_classes=class_num).cuda() 80 | elif net_type == "unet_ds": 81 | net = UNet_DS(in_chns=in_chns, class_num=class_num).cuda() 82 | elif net_type == "unet_cct": 83 | net = UNet_CCT(in_chns=in_chns, class_num=class_num).cuda() 84 | elif net_type == "unet_urpc": 85 | net = UNet_URPC(in_chns=in_chns, class_num=class_num).cuda() 86 | elif net_type == "efficient_unet": 87 | net = Effi_UNet('efficientnet-b3', encoder_weights='imagenet', 88 | in_channels=in_chns, classes=class_num).cuda() 89 | elif net_type == "ViT_Seg": 90 | net = ViT_seg(config, img_size=args.patch_size, 91 | num_classes=args.num_classes).cuda() 92 | elif net_type == "pnet": 93 | net = PNet2D(in_chns, class_num, 64, [1, 2, 4, 8, 16]).cuda() 94 | elif net_type == "nnUNet": 95 | net = initialize_network(num_classes=class_num).cuda() 96 | else: 97 | net = None 98 | return net 99 | -------------------------------------------------------------------------------- /code/networks/net_factory_3d.py: -------------------------------------------------------------------------------- 1 | from networks.unet_3D import unet_3D 2 | from networks.vnet import VNet 3 | from networks.VoxResNet import VoxResNet 4 | from networks.attention_unet import Attention_UNet 5 | from networks.nnunet import initialize_network 6 | 7 | 8 | def net_factory_3d(net_type="unet_3D", 
in_chns=1, class_num=2): 9 | if net_type == "unet_3D": 10 | net = unet_3D(n_classes=class_num, in_channels=in_chns).cuda() 11 | elif net_type == "attention_unet": 12 | net = Attention_UNet(n_classes=class_num, in_channels=in_chns).cuda() 13 | elif net_type == "voxresnet": 14 | net = VoxResNet(in_chns=in_chns, feature_chns=64, 15 | class_num=class_num).cuda() 16 | elif net_type == "vnet": 17 | net = VNet(n_channels=in_chns, n_classes=class_num, 18 | normalization='batchnorm', has_dropout=True).cuda() 19 | elif net_type == "nnUNet": 20 | net = initialize_network(num_classes=class_num).cuda() 21 | else: 22 | net = None 23 | return net 24 | -------------------------------------------------------------------------------- /code/networks/pnet.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- coding: utf-8 -*- 3 | """ 4 | An PyTorch implementation of the DeepIGeoS paper: 5 | Wang, Guotai and Zuluaga, Maria A and Li, Wenqi and Pratt, Rosalind and Patel, Premal A and Aertsen, Michael and Doel, Tom and David, Anna L and Deprest, Jan and Ourselin, S{\'e}bastien and others: 6 | DeepIGeoS: a deep interactive geodesic framework for medical image segmentation. 7 | TPAMI (7) 2018: 1559--1572 8 | Note that there are some modifications from the original paper, such as 9 | the use of leaky relu here. 10 | """ 11 | from __future__ import division, print_function 12 | 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | class PNetBlock(nn.Module): 18 | def __init__(self, in_channels, out_channels, dilation, padding): 19 | super(PNetBlock, self).__init__() 20 | 21 | self.in_chns = in_channels 22 | self.out_chns = out_channels 23 | self.dilation = dilation 24 | self.padding = padding 25 | 26 | self.conv1 = nn.Conv2d(self.in_chns, self.out_chns, kernel_size=3, 27 | padding=self.padding, dilation=self.dilation, groups=1, bias=True) 28 | self.conv2 = nn.Conv2d(self.out_chns, self.out_chns, kernel_size=3, 29 | padding=self.padding, dilation=self.dilation, groups=1, bias=True) 30 | self.in1 = nn.BatchNorm2d(self.out_chns) 31 | self.in2 = nn.BatchNorm2d(self.out_chns) 32 | self.ac1 = nn.LeakyReLU() 33 | self.ac2 = nn.LeakyReLU() 34 | 35 | def forward(self, x): 36 | x = self.conv1(x) 37 | x = self.in1(x) 38 | x = self.ac1(x) 39 | x = self.conv2(x) 40 | x = self.in2(x) 41 | x = self.ac2(x) 42 | return x 43 | 44 | 45 | class ConcatBlock(nn.Module): 46 | def __init__(self, in_channels, out_channels): 47 | super(ConcatBlock, self).__init__() 48 | self.in_chns = in_channels 49 | self.out_chns = out_channels 50 | self.conv1 = nn.Conv2d( 51 | self.in_chns, self.in_chns, kernel_size=1, padding=0) 52 | self.conv2 = nn.Conv2d( 53 | self.in_chns, self.out_chns, kernel_size=1, padding=0) 54 | self.ac1 = nn.LeakyReLU() 55 | self.ac2 = nn.LeakyReLU() 56 | 57 | def forward(self, x): 58 | x = self.conv1(x) 59 | x = self.ac1(x) 60 | x = self.conv2(x) 61 | x = self.ac2(x) 62 | return x 63 | 64 | 65 | class OutPutBlock(nn.Module): 66 | def __init__(self, in_channels, out_channels): 67 | super(OutPutBlock, self).__init__() 68 | self.in_chns = in_channels 69 | self.out_chns = out_channels 70 | self.conv1 = nn.Conv2d( 71 | self.in_chns, self.in_chns // 2, kernel_size=1, padding=0) 72 | self.conv2 = nn.Conv2d( 73 | self.in_chns // 2, self.out_chns, kernel_size=1, padding=0) 74 | self.drop1 = nn.Dropout2d(0.3) 75 | self.drop2 = nn.Dropout2d(0.3) 76 | self.ac1 = nn.LeakyReLU() 77 | 78 | def forward(self, x): 79 | x = self.drop1(x) 80 | x = self.conv1(x) 81 | x = self.ac1(x) 82 | x = 
self.drop2(x) 83 | x = self.conv2(x) 84 | return x 85 | 86 | 87 | class PNet2D(nn.Module): 88 | def __init__(self, in_chns, out_chns, num_filters, ratios): 89 | super(PNet2D, self).__init__() 90 | 91 | self.in_chns = in_chns 92 | self.out_chns = out_chns 93 | self.ratios = ratios 94 | self.num_filters = num_filters 95 | 96 | self.block1 = PNetBlock( 97 | self.in_chns, self.num_filters, self.ratios[0], padding=self.ratios[0]) 98 | 99 | self.block2 = PNetBlock( 100 | self.num_filters, self.num_filters, self.ratios[1], padding=self.ratios[1]) 101 | 102 | self.block3 = PNetBlock( 103 | self.num_filters, self.num_filters, self.ratios[2], padding=self.ratios[2]) 104 | 105 | self.block4 = PNetBlock( 106 | self.num_filters, self.num_filters, self.ratios[3], padding=self.ratios[3]) 107 | 108 | self.block5 = PNetBlock( 109 | self.num_filters, self.num_filters, self.ratios[4], padding=self.ratios[4]) 110 | self.catblock = ConcatBlock(self.num_filters * 5, self.num_filters * 2) 111 | self.out = OutPutBlock(self.num_filters * 2, self.out_chns) 112 | 113 | def forward(self, x): 114 | x1 = self.block1(x) 115 | x2 = self.block2(x1) 116 | x3 = self.block3(x2) 117 | x4 = self.block4(x3) 118 | x5 = self.block5(x4) 119 | conx = torch.cat([x1, x2, x3, x4, x5], dim=1) 120 | conx = self.catblock(conx) 121 | out = self.out(conx) 122 | return out 123 | -------------------------------------------------------------------------------- /code/networks/unet_3D.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | An implementation of the 3D U-Net paper: 4 | Özgün Çiçek, Ahmed Abdulkadir, Soeren S. Lienkamp, Thomas Brox, Olaf Ronneberger: 5 | 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. 6 | MICCAI (2) 2016: 424-432 7 | Note that there are some modifications from the original paper, such as 8 | the use of batch normalization, dropout, and leaky relu here. 
9 | The implementation is borrowed from: https://github.com/ozan-oktay/Attention-Gated-Networks 10 | """ 11 | import math 12 | 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from networks.networks_other import init_weights 17 | from networks.utils import UnetConv3, UnetUp3, UnetUp3_CT 18 | 19 | 20 | class unet_3D(nn.Module): 21 | 22 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 23 | super(unet_3D, self).__init__() 24 | self.is_deconv = is_deconv 25 | self.in_channels = in_channels 26 | self.is_batchnorm = is_batchnorm 27 | self.feature_scale = feature_scale 28 | 29 | filters = [64, 128, 256, 512, 1024] 30 | filters = [int(x / self.feature_scale) for x in filters] 31 | 32 | # downsampling 33 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 34 | 3, 3, 3), padding_size=(1, 1, 1)) 35 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 36 | 37 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 38 | 3, 3, 3), padding_size=(1, 1, 1)) 39 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 40 | 41 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 42 | 3, 3, 3), padding_size=(1, 1, 1)) 43 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 44 | 45 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 46 | 3, 3, 3), padding_size=(1, 1, 1)) 47 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 48 | 49 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 50 | 3, 3, 3), padding_size=(1, 1, 1)) 51 | 52 | # upsampling 53 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 54 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 55 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 56 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 57 | 58 | # final conv (without any concat) 59 | self.final = nn.Conv3d(filters[0], n_classes, 1) 60 | 61 | self.dropout1 = nn.Dropout(p=0.3) 62 | self.dropout2 = nn.Dropout(p=0.3) 63 | 64 | # initialise weights 65 | for m in self.modules(): 66 | if isinstance(m, nn.Conv3d): 67 | init_weights(m, init_type='kaiming') 68 | elif isinstance(m, nn.BatchNorm3d): 69 | init_weights(m, init_type='kaiming') 70 | 71 | def forward(self, inputs): 72 | conv1 = self.conv1(inputs) 73 | maxpool1 = self.maxpool1(conv1) 74 | 75 | conv2 = self.conv2(maxpool1) 76 | maxpool2 = self.maxpool2(conv2) 77 | 78 | conv3 = self.conv3(maxpool2) 79 | maxpool3 = self.maxpool3(conv3) 80 | 81 | conv4 = self.conv4(maxpool3) 82 | maxpool4 = self.maxpool4(conv4) 83 | 84 | center = self.center(maxpool4) 85 | center = self.dropout1(center) 86 | up4 = self.up_concat4(conv4, center) 87 | up3 = self.up_concat3(conv3, up4) 88 | up2 = self.up_concat2(conv2, up3) 89 | up1 = self.up_concat1(conv1, up2) 90 | up1 = self.dropout2(up1) 91 | 92 | final = self.final(up1) 93 | 94 | return final 95 | 96 | @staticmethod 97 | def apply_argmax_softmax(pred): 98 | log_p = F.softmax(pred, dim=1) 99 | 100 | return log_p 101 | -------------------------------------------------------------------------------- /code/networks/unet_3D_dv_semi.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is adapted from https://github.com/ozan-oktay/Attention-Gated-Networks 3 | """ 4 | 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | from networks.utils import UnetConv3, 
UnetUp3, UnetUp3_CT, UnetDsv3 9 | import torch.nn.functional as F 10 | from networks.networks_other import init_weights 11 | 12 | 13 | class unet_3D_dv_semi(nn.Module): 14 | 15 | def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): 16 | super(unet_3D_dv_semi, self).__init__() 17 | self.is_deconv = is_deconv 18 | self.in_channels = in_channels 19 | self.is_batchnorm = is_batchnorm 20 | self.feature_scale = feature_scale 21 | 22 | filters = [64, 128, 256, 512, 1024] 23 | filters = [int(x / self.feature_scale) for x in filters] 24 | 25 | # downsampling 26 | self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=( 27 | 3, 3, 3), padding_size=(1, 1, 1)) 28 | self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 29 | 30 | self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=( 31 | 3, 3, 3), padding_size=(1, 1, 1)) 32 | self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 33 | 34 | self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=( 35 | 3, 3, 3), padding_size=(1, 1, 1)) 36 | self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 37 | 38 | self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=( 39 | 3, 3, 3), padding_size=(1, 1, 1)) 40 | self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2)) 41 | 42 | self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=( 43 | 3, 3, 3), padding_size=(1, 1, 1)) 44 | 45 | # upsampling 46 | self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm) 47 | self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm) 48 | self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm) 49 | self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm) 50 | 51 | # deep supervision 52 | self.dsv4 = UnetDsv3( 53 | in_size=filters[3], out_size=n_classes, scale_factor=8) 54 | self.dsv3 = UnetDsv3( 55 | in_size=filters[2], out_size=n_classes, scale_factor=4) 56 | self.dsv2 = UnetDsv3( 57 | in_size=filters[1], out_size=n_classes, scale_factor=2) 58 | self.dsv1 = nn.Conv3d( 59 | in_channels=filters[0], out_channels=n_classes, kernel_size=1) 60 | 61 | self.dropout1 = nn.Dropout3d(p=0.5) 62 | self.dropout2 = nn.Dropout3d(p=0.3) 63 | self.dropout3 = nn.Dropout3d(p=0.2) 64 | self.dropout4 = nn.Dropout3d(p=0.1) 65 | 66 | # initialise weights 67 | for m in self.modules(): 68 | if isinstance(m, nn.Conv3d): 69 | init_weights(m, init_type='kaiming') 70 | elif isinstance(m, nn.BatchNorm3d): 71 | init_weights(m, init_type='kaiming') 72 | 73 | def forward(self, inputs): 74 | conv1 = self.conv1(inputs) 75 | maxpool1 = self.maxpool1(conv1) 76 | 77 | conv2 = self.conv2(maxpool1) 78 | maxpool2 = self.maxpool2(conv2) 79 | 80 | conv3 = self.conv3(maxpool2) 81 | maxpool3 = self.maxpool3(conv3) 82 | 83 | conv4 = self.conv4(maxpool3) 84 | maxpool4 = self.maxpool4(conv4) 85 | 86 | center = self.center(maxpool4) 87 | 88 | up4 = self.up_concat4(conv4, center) 89 | up4 = self.dropout1(up4) 90 | 91 | up3 = self.up_concat3(conv3, up4) 92 | up3 = self.dropout2(up3) 93 | 94 | up2 = self.up_concat2(conv2, up3) 95 | up2 = self.dropout3(up2) 96 | 97 | up1 = self.up_concat1(conv1, up2) 98 | up1 = self.dropout4(up1) 99 | 100 | # Deep Supervision 101 | dsv4 = self.dsv4(up4) 102 | dsv3 = self.dsv3(up3) 103 | dsv2 = self.dsv2(up2) 104 | dsv1 = self.dsv1(up1) 105 | 106 | return dsv1, dsv2, dsv3, dsv4 107 | 108 | @staticmethod 109 | def apply_argmax_softmax(pred): 110 | log_p = F.softmax(pred, dim=1) 111 | 112 | return log_p 
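# --- Editorial sketch (illustrative only; not part of the original unet_3D_dv_semi.py) ---
# unet_3D_dv_semi returns four class-score maps (dsv1..dsv4), all upsampled to the input
# resolution but decoded from different depths of the decoder. A minimal, hypothetical
# usage sketch of averaging these pyramid predictions into a single probability map,
# assuming a single-channel 96^3 input volume and 2 output classes:
#
#     net = unet_3D_dv_semi(n_classes=2, in_channels=1).eval()
#     volume = torch.rand(1, 1, 96, 96, 96)
#     with torch.no_grad():
#         dsv1, dsv2, dsv3, dsv4 = net(volume)             # each: (1, 2, 96, 96, 96)
#     probs = [F.softmax(d, dim=1) for d in (dsv1, dsv2, dsv3, dsv4)]
#     avg_prob = torch.stack(probs).mean(dim=0)            # pyramid-averaged prediction
#     segmentation = avg_prob.argmax(dim=1)                # (1, 96, 96, 96) label map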
113 | -------------------------------------------------------------------------------- /code/networks/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # This file borrowed from Swin-UNet: https://github.com/HuCaoFighting/Swin-Unet 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import copy 8 | import logging 9 | import math 10 | 11 | from os.path import join as pjoin 12 | 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | 17 | from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm 18 | from torch.nn.modules.utils import _pair 19 | from scipy import ndimage 20 | from networks.swin_transformer_unet_skip_expand_decoder_sys import SwinTransformerSys 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | class SwinUnet(nn.Module): 25 | def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False): 26 | super(SwinUnet, self).__init__() 27 | self.num_classes = num_classes 28 | self.zero_head = zero_head 29 | self.config = config 30 | 31 | self.swin_unet = SwinTransformerSys(img_size=config.DATA.IMG_SIZE, 32 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 33 | in_chans=config.MODEL.SWIN.IN_CHANS, 34 | num_classes=self.num_classes, 35 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 36 | depths=config.MODEL.SWIN.DEPTHS, 37 | num_heads=config.MODEL.SWIN.NUM_HEADS, 38 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 39 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 40 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 41 | qk_scale=config.MODEL.SWIN.QK_SCALE, 42 | drop_rate=config.MODEL.DROP_RATE, 43 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 44 | ape=config.MODEL.SWIN.APE, 45 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 46 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 47 | 48 | def forward(self, x): 49 | if x.size()[1] == 1: 50 | x = x.repeat(1,3,1,1) 51 | logits = self.swin_unet(x) 52 | return logits 53 | 54 | def load_from(self, config): 55 | pretrained_path = config.MODEL.PRETRAIN_CKPT 56 | if pretrained_path is not None: 57 | print("pretrained_path:{}".format(pretrained_path)) 58 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 59 | pretrained_dict = torch.load(pretrained_path, map_location=device) 60 | if "model" not in pretrained_dict: 61 | print("---start load pretrained modle by splitting---") 62 | pretrained_dict = {k[17:]:v for k,v in pretrained_dict.items()} 63 | for k in list(pretrained_dict.keys()): 64 | if "output" in k: 65 | print("delete key:{}".format(k)) 66 | del pretrained_dict[k] 67 | msg = self.swin_unet.load_state_dict(pretrained_dict,strict=False) 68 | # print(msg) 69 | return 70 | pretrained_dict = pretrained_dict['model'] 71 | print("---start load pretrained modle of swin encoder---") 72 | 73 | model_dict = self.swin_unet.state_dict() 74 | full_dict = copy.deepcopy(pretrained_dict) 75 | for k, v in pretrained_dict.items(): 76 | if "layers." in k: 77 | current_layer_num = 3-int(k[7:8]) 78 | current_k = "layers_up." 
+ str(current_layer_num) + k[8:] 79 | full_dict.update({current_k:v}) 80 | for k in list(full_dict.keys()): 81 | if k in model_dict: 82 | if full_dict[k].shape != model_dict[k].shape: 83 | print("delete:{};shape pretrain:{};shape model:{}".format(k,v.shape,model_dict[k].shape)) 84 | del full_dict[k] 85 | 86 | msg = self.swin_unet.load_state_dict(full_dict, strict=False) 87 | # print(msg) 88 | else: 89 | print("none pretrain") 90 | -------------------------------------------------------------------------------- /code/pretrained_ckpt/readme.txt: -------------------------------------------------------------------------------- 1 | download pre-trained model to this folder, link:https://drive.google.com/drive/folders/1UC3XOoezeum0uck4KBVGa8osahs6rKUY 2 | -------------------------------------------------------------------------------- /code/test_2D_fully.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import h5py 6 | import nibabel as nib 7 | import numpy as np 8 | import SimpleITK as sitk 9 | import torch 10 | from medpy import metric 11 | from scipy.ndimage import zoom 12 | from scipy.ndimage.interpolation import zoom 13 | from tqdm import tqdm 14 | 15 | # from networks.efficientunet import UNet 16 | from networks.net_factory import net_factory 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--root_path', type=str, 20 | default='../data/ACDC', help='Name of Experiment') 21 | parser.add_argument('--exp', type=str, 22 | default='ACDC/Fully_Supervised', help='experiment_name') 23 | parser.add_argument('--model', type=str, 24 | default='unet', help='model_name') 25 | parser.add_argument('--num_classes', type=int, default=4, 26 | help='output channel of network') 27 | parser.add_argument('--labeled_num', type=int, default=3, 28 | help='labeled data') 29 | 30 | 31 | def calculate_metric_percase(pred, gt): 32 | pred[pred > 0] = 1 33 | gt[gt > 0] = 1 34 | dice = metric.binary.dc(pred, gt) 35 | asd = metric.binary.asd(pred, gt) 36 | hd95 = metric.binary.hd95(pred, gt) 37 | return dice, hd95, asd 38 | 39 | 40 | def test_single_volume(case, net, test_save_path, FLAGS): 41 | h5f = h5py.File(FLAGS.root_path + "/data/{}.h5".format(case), 'r') 42 | image = h5f['image'][:] 43 | label = h5f['label'][:] 44 | prediction = np.zeros_like(label) 45 | for ind in range(image.shape[0]): 46 | slice = image[ind, :, :] 47 | x, y = slice.shape[0], slice.shape[1] 48 | slice = zoom(slice, (256 / x, 256 / y), order=0) 49 | input = torch.from_numpy(slice).unsqueeze( 50 | 0).unsqueeze(0).float().cuda() 51 | net.eval() 52 | with torch.no_grad(): 53 | if FLAGS.model == "unet_urds": 54 | out_main, _, _, _ = net(input) 55 | else: 56 | out_main = net(input) 57 | out = torch.argmax(torch.softmax( 58 | out_main, dim=1), dim=1).squeeze(0) 59 | out = out.cpu().detach().numpy() 60 | pred = zoom(out, (x / 256, y / 256), order=0) 61 | prediction[ind] = pred 62 | 63 | first_metric = calculate_metric_percase(prediction == 1, label == 1) 64 | second_metric = calculate_metric_percase(prediction == 2, label == 2) 65 | third_metric = calculate_metric_percase(prediction == 3, label == 3) 66 | 67 | img_itk = sitk.GetImageFromArray(image.astype(np.float32)) 68 | img_itk.SetSpacing((1, 1, 10)) 69 | prd_itk = sitk.GetImageFromArray(prediction.astype(np.float32)) 70 | prd_itk.SetSpacing((1, 1, 10)) 71 | lab_itk = sitk.GetImageFromArray(label.astype(np.float32)) 72 | lab_itk.SetSpacing((1, 1, 10)) 73 | sitk.WriteImage(prd_itk, 
test_save_path + case + "_pred.nii.gz") 74 | sitk.WriteImage(img_itk, test_save_path + case + "_img.nii.gz") 75 | sitk.WriteImage(lab_itk, test_save_path + case + "_gt.nii.gz") 76 | return first_metric, second_metric, third_metric 77 | 78 | 79 | def Inference(FLAGS): 80 | with open(FLAGS.root_path + '/test.list', 'r') as f: 81 | image_list = f.readlines() 82 | image_list = sorted([item.replace('\n', '').split(".")[0] 83 | for item in image_list]) 84 | snapshot_path = "../model/{}_{}_labeled/{}".format( 85 | FLAGS.exp, FLAGS.labeled_num, FLAGS.model) 86 | test_save_path = "../model/{}_{}_labeled/{}_predictions/".format( 87 | FLAGS.exp, FLAGS.labeled_num, FLAGS.model) 88 | if os.path.exists(test_save_path): 89 | shutil.rmtree(test_save_path) 90 | os.makedirs(test_save_path) 91 | net = net_factory(net_type=FLAGS.model, in_chns=1, 92 | class_num=FLAGS.num_classes) 93 | save_mode_path = os.path.join( 94 | snapshot_path, '{}_best_model.pth'.format(FLAGS.model)) 95 | net.load_state_dict(torch.load(save_mode_path)) 96 | print("init weight from {}".format(save_mode_path)) 97 | net.eval() 98 | 99 | first_total = 0.0 100 | second_total = 0.0 101 | third_total = 0.0 102 | for case in tqdm(image_list): 103 | first_metric, second_metric, third_metric = test_single_volume( 104 | case, net, test_save_path, FLAGS) 105 | first_total += np.asarray(first_metric) 106 | second_total += np.asarray(second_metric) 107 | third_total += np.asarray(third_metric) 108 | avg_metric = [first_total / len(image_list), second_total / 109 | len(image_list), third_total / len(image_list)] 110 | return avg_metric 111 | 112 | 113 | if __name__ == '__main__': 114 | FLAGS = parser.parse_args() 115 | metric = Inference(FLAGS) 116 | print(metric) 117 | print((metric[0]+metric[1]+metric[2])/3) 118 | -------------------------------------------------------------------------------- /code/test_3D.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from glob import glob 5 | 6 | import torch 7 | 8 | from networks.unet_3D import unet_3D 9 | from test_3D_util import test_all_case 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--root_path', type=str, 13 | default='../data/BraTS2019', help='Name of Experiment') 14 | parser.add_argument('--exp', type=str, 15 | default='BraTS2019/Interpolation_Consistency_Training_25', help='experiment_name') 16 | parser.add_argument('--model', type=str, 17 | default='unet_3D', help='model_name') 18 | 19 | 20 | def Inference(FLAGS): 21 | snapshot_path = "../model/{}/{}".format(FLAGS.exp, FLAGS.model) 22 | num_classes = 2 23 | test_save_path = "../model/{}/Prediction".format(FLAGS.exp) 24 | if os.path.exists(test_save_path): 25 | shutil.rmtree(test_save_path) 26 | os.makedirs(test_save_path) 27 | net = unet_3D(n_classes=num_classes, in_channels=1).cuda() 28 | save_mode_path = os.path.join( 29 | snapshot_path, '{}_best_model.pth'.format(FLAGS.model)) 30 | net.load_state_dict(torch.load(save_mode_path)) 31 | print("init weight from {}".format(save_mode_path)) 32 | net.eval() 33 | avg_metric = test_all_case(net, base_dir=FLAGS.root_path, method=FLAGS.model, test_list="test.txt", num_classes=num_classes, 34 | patch_size=(96, 96, 96), stride_xy=64, stride_z=64, test_save_path=test_save_path) 35 | return avg_metric 36 | 37 | 38 | if __name__ == '__main__': 39 | FLAGS = parser.parse_args() 40 | metric = Inference(FLAGS) 41 | print(metric) 42 | 
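# test_3D.py above restores the best checkpoint and delegates to test_all_case /
# test_single_case in test_3D_util.py (next file), which pad volumes smaller than
# the patch, tile each volume with overlapping 96^3 patches (stride 64), accumulate
# softmax scores together with a visit-count map, normalise by the count and take
# the argmax. Below is a condensed 1-D analogue of that tiling and
# overlap-normalisation, for illustration only; the helper name is hypothetical
# and not part of the repo.
import math
import numpy as np

def sliding_window_starts(length, patch, stride):
    # Same arithmetic as the sx/sy/sz loops in test_single_case: enough steps to
    # cover the axis, with the last window clamped to end exactly at the border.
    steps = math.ceil((length - patch) / stride) + 1
    return [min(stride * s, length - patch) for s in range(steps)]

length, patch, stride = 140, 96, 64
score = np.zeros(length)
count = np.zeros(length)
for start in sliding_window_starts(length, patch, stride):
    score[start:start + patch] += 1.0  # stand-in for a patch's softmax scores
    count[start:start + patch] += 1.0
score /= count  # overlap-normalised prediction, as in the 3-D code

print(sliding_window_starts(length, patch, stride))  # [0, 44]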
-------------------------------------------------------------------------------- /code/test_3D_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import h5py 4 | import nibabel as nib 5 | import numpy as np 6 | import SimpleITK as sitk 7 | import torch 8 | import torch.nn.functional as F 9 | from medpy import metric 10 | from skimage.measure import label 11 | from tqdm import tqdm 12 | 13 | 14 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 15 | w, h, d = image.shape 16 | 17 | # if the size of image is less than patch_size, then padding it 18 | add_pad = False 19 | if w < patch_size[0]: 20 | w_pad = patch_size[0]-w 21 | add_pad = True 22 | else: 23 | w_pad = 0 24 | if h < patch_size[1]: 25 | h_pad = patch_size[1]-h 26 | add_pad = True 27 | else: 28 | h_pad = 0 29 | if d < patch_size[2]: 30 | d_pad = patch_size[2]-d 31 | add_pad = True 32 | else: 33 | d_pad = 0 34 | wl_pad, wr_pad = w_pad//2, w_pad-w_pad//2 35 | hl_pad, hr_pad = h_pad//2, h_pad-h_pad//2 36 | dl_pad, dr_pad = d_pad//2, d_pad-d_pad//2 37 | if add_pad: 38 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), 39 | (dl_pad, dr_pad)], mode='constant', constant_values=0) 40 | ww, hh, dd = image.shape 41 | 42 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 43 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 44 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 45 | # print("{}, {}, {}".format(sx, sy, sz)) 46 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 47 | cnt = np.zeros(image.shape).astype(np.float32) 48 | 49 | for x in range(0, sx): 50 | xs = min(stride_xy*x, ww-patch_size[0]) 51 | for y in range(0, sy): 52 | ys = min(stride_xy * y, hh-patch_size[1]) 53 | for z in range(0, sz): 54 | zs = min(stride_z * z, dd-patch_size[2]) 55 | test_patch = image[xs:xs+patch_size[0], 56 | ys:ys+patch_size[1], zs:zs+patch_size[2]] 57 | test_patch = np.expand_dims(np.expand_dims( 58 | test_patch, axis=0), axis=0).astype(np.float32) 59 | test_patch = torch.from_numpy(test_patch).cuda() 60 | 61 | with torch.no_grad(): 62 | y1 = net(test_patch) 63 | # ensemble 64 | y = torch.softmax(y1, dim=1) 65 | y = y.cpu().data.numpy() 66 | y = y[0, :, :, :, :] 67 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 68 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 69 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 70 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 71 | score_map = score_map/np.expand_dims(cnt, axis=0) 72 | label_map = np.argmax(score_map, axis=0) 73 | 74 | if add_pad: 75 | label_map = label_map[wl_pad:wl_pad+w, 76 | hl_pad:hl_pad+h, dl_pad:dl_pad+d] 77 | score_map = score_map[:, wl_pad:wl_pad + 78 | w, hl_pad:hl_pad+h, dl_pad:dl_pad+d] 79 | return label_map 80 | 81 | 82 | def cal_metric(gt, pred): 83 | if pred.sum() > 0 and gt.sum() > 0: 84 | dice = metric.binary.dc(pred, gt) 85 | hd95 = metric.binary.hd95(pred, gt) 86 | return np.array([dice, hd95]) 87 | else: 88 | return np.zeros(2) 89 | 90 | 91 | def test_all_case(net, base_dir, method="unet_3D", test_list="full_test.list", num_classes=4, patch_size=(48, 160, 160), stride_xy=32, stride_z=24, test_save_path=None): 92 | with open(base_dir + '/{}'.format(test_list), 'r') as f: 93 | image_list = f.readlines() 94 | image_list = [base_dir + "/data/{}.h5".format( 95 | item.replace('\n', '').split(",")[0]) for item in image_list] 96 | 
total_metric = np.zeros((num_classes-1, 4)) 97 | print("Testing begin") 98 | with open(test_save_path + "/{}.txt".format(method), "a") as f: 99 | for image_path in tqdm(image_list): 100 | ids = image_path.split("/")[-1].replace(".h5", "") 101 | h5f = h5py.File(image_path, 'r') 102 | image = h5f['image'][:] 103 | label = h5f['label'][:] 104 | prediction = test_single_case( 105 | net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 106 | metric = calculate_metric_percase(prediction == 1, label == 1) 107 | total_metric[0, :] += metric 108 | f.writelines("{},{},{},{},{}\n".format( 109 | ids, metric[0], metric[1], metric[2], metric[3])) 110 | 111 | pred_itk = sitk.GetImageFromArray(prediction.astype(np.uint8)) 112 | pred_itk.SetSpacing((1.0, 1.0, 1.0)) 113 | sitk.WriteImage(pred_itk, test_save_path + 114 | "/{}_pred.nii.gz".format(ids)) 115 | 116 | img_itk = sitk.GetImageFromArray(image) 117 | img_itk.SetSpacing((1.0, 1.0, 1.0)) 118 | sitk.WriteImage(img_itk, test_save_path + 119 | "/{}_img.nii.gz".format(ids)) 120 | 121 | lab_itk = sitk.GetImageFromArray(label.astype(np.uint8)) 122 | lab_itk.SetSpacing((1.0, 1.0, 1.0)) 123 | sitk.WriteImage(lab_itk, test_save_path + 124 | "/{}_lab.nii.gz".format(ids)) 125 | f.writelines("Mean metrics,{},{},{},{}".format(total_metric[0, 0] / len(image_list), total_metric[0, 1] / len( 126 | image_list), total_metric[0, 2] / len(image_list), total_metric[0, 3] / len(image_list))) 127 | f.close() 128 | print("Testing end") 129 | return total_metric / len(image_list) 130 | 131 | 132 | def cal_dice(prediction, label, num=2): 133 | total_dice = np.zeros(num-1) 134 | for i in range(1, num): 135 | prediction_tmp = (prediction == i) 136 | label_tmp = (label == i) 137 | prediction_tmp = prediction_tmp.astype(np.float) 138 | label_tmp = label_tmp.astype(np.float) 139 | 140 | dice = 2 * np.sum(prediction_tmp * label_tmp) / \ 141 | (np.sum(prediction_tmp) + np.sum(label_tmp)) 142 | total_dice[i - 1] += dice 143 | 144 | return total_dice 145 | 146 | 147 | def calculate_metric_percase(pred, gt): 148 | dice = metric.binary.dc(pred, gt) 149 | ravd = abs(metric.binary.ravd(pred, gt)) 150 | hd = metric.binary.hd95(pred, gt) 151 | asd = metric.binary.asd(pred, gt) 152 | return np.array([dice, ravd, hd, asd]) 153 | -------------------------------------------------------------------------------- /code/test_acdc_unet_semi_seg.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Fully_Supervised --num_classes 4 --labeled_num 7 && \ 2 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Entropy_Minimization --num_classes 4 --labeled_num 7 && \ 3 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Interpolation_Consistency_Training --num_classes 4 --labeled_num 7 && \ 4 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Mean_Teacher --num_classes 4 --labeled_num 7 && \ 5 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Uncertainty_Aware_Mean_Teacher --num_classes 4 --labeled_num 7 && \ 6 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Adversarial_Network --num_classes 4 --labeled_num 7 && \ 7 | CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Uncertainty_Rectified_Pyramid_Consistency --model unet_urpc --num_classes 4 --labeled_num 7 && \ 8 | 
CUDA_VISIBLE_DEVICES=0 python test_2D_fully.py --root_path ../data/ACDC --exp ACDC/Fully_Supervised --num_classes 4 --labeled_num 140 -------------------------------------------------------------------------------- /code/test_brats2019_semi_seg.sh: -------------------------------------------------------------------------------- 1 | # & means run these methods at the same time, and && means run these methods one by one 2 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Fully_supervised_25 --model unet_3D && 3 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Fully_supervised_250 --model unet_3D && 4 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Mean_Teacher_25 --model unet_3D && 5 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Uncertainty_Aware_Mean_Teacher_25 --model unet_3D && 6 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Interpolation_Consistency_Training_25 --model unet_3D && 7 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Entropy_Minimization_25 --model unet_3D && 8 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Cross_Pseudo_Supervision_25 --model unet_3D && 9 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Adversarial_Network_25 --model unet_3D && 10 | python -u test_3D.py --root_path ../data/BraTS2019 --exp BraTS2019/Uncertainty_Rectified_Pyramid_Consistency_25 --model unet_3D_dv_semi -------------------------------------------------------------------------------- /code/test_urpc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from glob import glob 5 | import numpy 6 | 7 | import torch 8 | 9 | from networks.unet_3D_dv_semi import unet_3D_dv_semi 10 | from networks.unet_3D import unet_3D 11 | from test_urpc_util import test_all_case 12 | 13 | 14 | def net_factory(net_type="unet_3D", num_classes=3, in_channels=1): 15 | if net_type == "unet_3D": 16 | net = unet_3D(n_classes=num_classes, in_channels=in_channels).cuda() 17 | elif net_type == "unet_3D_dv_semi": 18 | net = unet_3D_dv_semi(n_classes=num_classes, 19 | in_channels=in_channels).cuda() 20 | else: 21 | net = None 22 | return net 23 | 24 | 25 | def Inference(FLAGS): 26 | snapshot_path = "../model/{}/{}".format(FLAGS.exp, FLAGS.model) 27 | num_classes = 2 28 | test_save_path = "../model/{}/Prediction".format(FLAGS.exp) 29 | if os.path.exists(test_save_path): 30 | shutil.rmtree(test_save_path) 31 | os.makedirs(test_save_path) 32 | net = net_factory(FLAGS.model, num_classes, in_channels=1) 33 | save_mode_path = os.path.join( 34 | snapshot_path, '{}_best_model.pth'.format(FLAGS.model)) 35 | net.load_state_dict(torch.load(save_mode_path)) 36 | print("init weight from {}".format(save_mode_path)) 37 | net.eval() 38 | avg_metric = test_all_case(net, base_dir=FLAGS.root_path, method=FLAGS.model, test_list="test.txt", num_classes=num_classes, 39 | patch_size=(96, 96, 96), stride_xy=64, stride_z=64, test_save_path=test_save_path) 40 | return avg_metric 41 | 42 | 43 | if __name__ == '__main__': 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--root_path', type=str, 47 | default='../data/BraTS2019', help='Name of Experiment') 48 | parser.add_argument('--exp', type=str, 49 | default="BraTS2019/Uncertainty_Rectified_Pyramid_Consistency_25_labeled", help='experiment_name') 50 | parser.add_argument('--model', type=str, 51 | default="unet_3D_dv_semi", 
help='model_name') 52 | FLAGS = parser.parse_args() 53 | 54 | metric = Inference(FLAGS) 55 | print(metric) 56 | -------------------------------------------------------------------------------- /code/test_urpc_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import h5py 4 | import nibabel as nib 5 | import numpy as np 6 | import SimpleITK as sitk 7 | import torch 8 | import torch.nn.functional as F 9 | from medpy import metric 10 | from skimage.measure import label 11 | from tqdm import tqdm 12 | 13 | 14 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 15 | w, h, d = image.shape 16 | 17 | # if the size of image is less than patch_size, then padding it 18 | add_pad = False 19 | if w < patch_size[0]: 20 | w_pad = patch_size[0]-w 21 | add_pad = True 22 | else: 23 | w_pad = 0 24 | if h < patch_size[1]: 25 | h_pad = patch_size[1]-h 26 | add_pad = True 27 | else: 28 | h_pad = 0 29 | if d < patch_size[2]: 30 | d_pad = patch_size[2]-d 31 | add_pad = True 32 | else: 33 | d_pad = 0 34 | wl_pad, wr_pad = w_pad//2, w_pad-w_pad//2 35 | hl_pad, hr_pad = h_pad//2, h_pad-h_pad//2 36 | dl_pad, dr_pad = d_pad//2, d_pad-d_pad//2 37 | if add_pad: 38 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), 39 | (dl_pad, dr_pad)], mode='constant', constant_values=0) 40 | ww, hh, dd = image.shape 41 | 42 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 43 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 44 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 45 | # print("{}, {}, {}".format(sx, sy, sz)) 46 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 47 | cnt = np.zeros(image.shape).astype(np.float32) 48 | 49 | for x in range(0, sx): 50 | xs = min(stride_xy*x, ww-patch_size[0]) 51 | for y in range(0, sy): 52 | ys = min(stride_xy * y, hh-patch_size[1]) 53 | for z in range(0, sz): 54 | zs = min(stride_z * z, dd-patch_size[2]) 55 | test_patch = image[xs:xs+patch_size[0], 56 | ys:ys+patch_size[1], zs:zs+patch_size[2]] 57 | test_patch = np.expand_dims(np.expand_dims( 58 | test_patch, axis=0), axis=0).astype(np.float32) 59 | test_patch = torch.from_numpy(test_patch).cuda() 60 | 61 | with torch.no_grad(): 62 | y_main, y_aux1, y_aux2, y_aux3 = net(test_patch) 63 | # ensemble 64 | y_main = torch.softmax(y_main, dim=1) 65 | y_aux1 = torch.softmax(y_aux1, dim=1) 66 | y_aux2 = torch.softmax(y_aux2, dim=1) 67 | y_aux3 = torch.softmax(y_aux3, dim=1) 68 | y = y_main 69 | # y = (y_main+y_aux1+y_aux2+y_aux3) 70 | y = y.cpu().data.numpy() 71 | y = y[0, :, :, :, :] 72 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 73 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 74 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 75 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 76 | score_map = score_map/np.expand_dims(cnt, axis=0) 77 | label_map = np.argmax(score_map, axis=0) 78 | 79 | if add_pad: 80 | label_map = label_map[wl_pad:wl_pad+w, 81 | hl_pad:hl_pad+h, dl_pad:dl_pad+d] 82 | score_map = score_map[:, wl_pad:wl_pad + 83 | w, hl_pad:hl_pad+h, dl_pad:dl_pad+d] 84 | return label_map 85 | 86 | 87 | def cal_metric(gt, pred): 88 | if pred.sum() > 0 and gt.sum() > 0: 89 | dice = metric.binary.dc(pred, gt) 90 | hd95 = metric.binary.hd95(pred, gt) 91 | return np.array([dice, hd95]) 92 | else: 93 | return np.zeros(2) 94 | 95 | 96 | def test_all_case(net, base_dir, method="unet_3D", 
test_list="full_test.list", num_classes=4, patch_size=(48, 160, 160), stride_xy=32, stride_z=24, test_save_path=None): 97 | with open(base_dir + '/{}'.format(test_list), 'r') as f: 98 | image_list = f.readlines() 99 | image_list = [base_dir + "/data/{}.h5".format( 100 | item.replace('\n', '').split(",")[0]) for item in image_list] 101 | total_metric = np.zeros((num_classes - 1, 4)) 102 | print("Testing begin") 103 | with open(test_save_path + "/{}.txt".format(method), "a") as f: 104 | for image_path in tqdm(image_list): 105 | ids = image_path.split("/")[-1].replace(".h5", "") 106 | h5f = h5py.File(image_path, 'r') 107 | image = h5f['image'][:] 108 | label = h5f['label'][:] 109 | prediction = test_single_case( 110 | net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 111 | 112 | metric = calculate_metric_percase(prediction == 1, label == 1) 113 | total_metric[0, :] += metric 114 | f.writelines("{},{},{},{},{}\n".format( 115 | ids, metric[0], metric[1], metric[2], metric[3])) 116 | 117 | pred_itk = sitk.GetImageFromArray(prediction.astype(np.uint8)) 118 | pred_itk.SetSpacing((1.0, 1.0, 1.0)) 119 | sitk.WriteImage(pred_itk, test_save_path + 120 | "/{}_pred.nii.gz".format(ids)) 121 | 122 | img_itk = sitk.GetImageFromArray(image) 123 | img_itk.SetSpacing((1.0, 1.0, 1.0)) 124 | sitk.WriteImage(img_itk, test_save_path + 125 | "/{}_img.nii.gz".format(ids)) 126 | 127 | lab_itk = sitk.GetImageFromArray(label.astype(np.uint8)) 128 | lab_itk.SetSpacing((1.0, 1.0, 1.0)) 129 | sitk.WriteImage(lab_itk, test_save_path + 130 | "/{}_lab.nii.gz".format(ids)) 131 | f.writelines("Mean metrics,{},{},{},{}".format(total_metric[0, 0] / len(image_list), total_metric[0, 1] / len( 132 | image_list), total_metric[0, 2] / len(image_list), total_metric[0, 3] / len(image_list))) 133 | f.close() 134 | print("Testing end") 135 | return total_metric / len(image_list) 136 | 137 | 138 | def cal_dice(prediction, label, num=2): 139 | total_dice = np.zeros(num-1) 140 | for i in range(1, num): 141 | prediction_tmp = (prediction == i) 142 | label_tmp = (label == i) 143 | prediction_tmp = prediction_tmp.astype(np.float) 144 | label_tmp = label_tmp.astype(np.float) 145 | 146 | dice = 2 * np.sum(prediction_tmp * label_tmp) / \ 147 | (np.sum(prediction_tmp) + np.sum(label_tmp)) 148 | total_dice[i - 1] += dice 149 | 150 | return total_dice 151 | 152 | 153 | def calculate_metric_percase(pred, gt): 154 | if pred.sum() > 0 and gt.sum() > 0: 155 | dice = metric.binary.dc(pred, gt) 156 | ravd = abs(metric.binary.ravd(pred, gt)) 157 | hd = metric.binary.hd95(pred, gt) 158 | asd = metric.binary.asd(pred, gt) 159 | return np.array([dice, ravd, hd, asd]) 160 | else: 161 | return np.zeros(4) 162 | -------------------------------------------------------------------------------- /code/train_acdc_unet_semi_seg.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train_fully_supervised_2D.py --root_path ../data/ACDC --exp ACDC/Fully_Supervised --num_classes 4 --labeled_num 7 && \ 2 | CUDA_VISIBLE_DEVICES=0 python train_entropy_minimization_2D.py --root_path ../data/ACDC --exp ACDC/Entropy_Minimization --num_classes 4 --labeled_num 7 && \ 3 | CUDA_VISIBLE_DEVICES=0 python train_interpolation_consistency_training_2D.py --root_path ../data/ACDC --exp ACDC/Interpolation_Consistency_Training --num_classes 4 --labeled_num 7 && \ 4 | CUDA_VISIBLE_DEVICES=0 python train_mean_teacher_2D.py --root_path ../data/ACDC --exp ACDC/Mean_Teacher --num_classes 4 --labeled_num 7 && \ 
5 | CUDA_VISIBLE_DEVICES=0 python train_uncertainty_aware_mean_teacher_2D.py --root_path ../data/ACDC --exp ACDC/Uncertainty_Aware_Mean_Teacher --num_classes 4 --labeled_num 7 && \ 6 | CUDA_VISIBLE_DEVICES=0 python train_adversarial_network_2D.py --root_path ../data/ACDC --exp ACDC/Adversarial_Network --num_classes 4 --labeled_num 7 && \ 7 | CUDA_VISIBLE_DEVICES=0 python train_uncertainty_rectified_pyramid_consistency_2D.py --root_path ../data/ACDC --exp ACDC/Uncertainty_Rectified_Pyramid_Consistency --num_classes 4 --labeled_num 7 && \ 8 | CUDA_VISIBLE_DEVICES=0 python train_fully_supervised_2D.py --root_path ../data/ACDC --exp ACDC/Fully_Supervised --num_classes 4 --labeled_num 140 -------------------------------------------------------------------------------- /code/train_brats2019_semi_seg.sh: -------------------------------------------------------------------------------- 1 | # & means run these methods at the same time, and && means run these methods one by one 2 | python -u train_fully_supervised_3D.py --labeled_num 25 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Fully_supervised --base_lr 0.1 && 3 | python -u train_fully_supervised_3D.py --labeled_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Fully_supervised --base_lr 0.1 && 4 | python -u train_adversarial_network_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Adversarial_Network --base_lr 0.1 && 5 | python -u train_entropy_minimization_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Entropy_Minimization --base_lr 0.1 && 6 | python -u train_interpolation_consistency_training_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --base_lr 0.1 --exp BraTS2019/Interpolation_Consistency_Training && 7 | python -u train_mean_teacher_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --exp BraTS2019/Mean_Teacher --base_lr 0.1 && 8 | python -u train_uncertainty_aware_mean_teacher_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --base_lr 0.1 --exp BraTS2019/Uncertainty_Aware_Mean_Teacher && 9 | python -u train_uncertainty_rectified_pyramid_consistency_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --base_lr 0.1 --exp BraTS2019/Uncertainty_Rectified_Pyramid_Consistency && 10 | python -u train_cross_pseudo_supervision_3D.py --labeled_num 25 --total_num 250 --root_path ../data/BraTS2019 --max_iterations 30000 --base_lr 0.1 --exp BraTS2019/Cross_Pseudo_Supervision -------------------------------------------------------------------------------- /code/train_fully_supervised_2D.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import shutil 6 | import sys 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from tensorboardX import SummaryWriter 16 | from torch.nn import BCEWithLogitsLoss 17 | from torch.nn.modules.loss import CrossEntropyLoss 18 | from torch.utils.data import DataLoader 19 | from torchvision import transforms 20 | from torchvision.utils import make_grid 21 | from tqdm import tqdm 22 | 23 | from dataloaders import utils 24 | from 
dataloaders.dataset import BaseDataSets, RandomGenerator 25 | from networks.net_factory import net_factory 26 | from utils import losses, metrics, ramps 27 | from val_2D import test_single_volume, test_single_volume_ds 28 | 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--root_path', type=str, 31 | default='../data/ACDC', help='Name of Experiment') 32 | parser.add_argument('--exp', type=str, 33 | default='ACDC/Fully_Supervised', help='experiment_name') 34 | parser.add_argument('--model', type=str, 35 | default='unet', help='model_name') 36 | parser.add_argument('--num_classes', type=int, default=4, 37 | help='output channel of network') 38 | parser.add_argument('--max_iterations', type=int, 39 | default=30000, help='maximum epoch number to train') 40 | parser.add_argument('--batch_size', type=int, default=24, 41 | help='batch_size per gpu') 42 | parser.add_argument('--deterministic', type=int, default=1, 43 | help='whether use deterministic training') 44 | parser.add_argument('--base_lr', type=float, default=0.01, 45 | help='segmentation network learning rate') 46 | parser.add_argument('--patch_size', type=list, default=[256, 256], 47 | help='patch size of network input') 48 | parser.add_argument('--seed', type=int, default=1337, help='random seed') 49 | parser.add_argument('--labeled_num', type=int, default=50, 50 | help='labeled data') 51 | args = parser.parse_args() 52 | 53 | 54 | def patients_to_slices(dataset, patiens_num): 55 | ref_dict = None 56 | if "ACDC" in dataset: 57 | ref_dict = {"3": 68, "7": 136, 58 | "14": 256, "21": 396, "28": 512, "35": 664, "140": 1312} 59 | elif "Prostate": 60 | ref_dict = {"2": 27, "4": 53, "8": 120, 61 | "12": 179, "16": 256, "21": 312, "42": 623} 62 | else: 63 | print("Error") 64 | return ref_dict[str(patiens_num)] 65 | 66 | 67 | def train(args, snapshot_path): 68 | base_lr = args.base_lr 69 | num_classes = args.num_classes 70 | batch_size = args.batch_size 71 | max_iterations = args.max_iterations 72 | 73 | labeled_slice = patients_to_slices(args.root_path, args.labeled_num) 74 | 75 | model = net_factory(net_type=args.model, in_chns=1, class_num=num_classes) 76 | db_train = BaseDataSets(base_dir=args.root_path, split="train", num=labeled_slice, transform=transforms.Compose([ 77 | RandomGenerator(args.patch_size) 78 | ])) 79 | db_val = BaseDataSets(base_dir=args.root_path, split="val") 80 | 81 | def worker_init_fn(worker_id): 82 | random.seed(args.seed + worker_id) 83 | 84 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, 85 | num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn) 86 | valloader = DataLoader(db_val, batch_size=1, shuffle=False, 87 | num_workers=1) 88 | 89 | model.train() 90 | 91 | optimizer = optim.SGD(model.parameters(), lr=base_lr, 92 | momentum=0.9, weight_decay=0.0001) 93 | ce_loss = CrossEntropyLoss() 94 | dice_loss = losses.DiceLoss(num_classes) 95 | 96 | writer = SummaryWriter(snapshot_path + '/log') 97 | logging.info("{} iterations per epoch".format(len(trainloader))) 98 | 99 | iter_num = 0 100 | max_epoch = max_iterations // len(trainloader) + 1 101 | best_performance = 0.0 102 | iterator = tqdm(range(max_epoch), ncols=70) 103 | for epoch_num in iterator: 104 | for i_batch, sampled_batch in enumerate(trainloader): 105 | 106 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 107 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 108 | 109 | outputs = model(volume_batch) 110 | outputs_soft = torch.softmax(outputs, dim=1) 111 | 112 
| loss_ce = ce_loss(outputs, label_batch[:].long()) 113 | loss_dice = dice_loss(outputs_soft, label_batch.unsqueeze(1)) 114 | loss = 0.5 * (loss_dice + loss_ce) 115 | optimizer.zero_grad() 116 | loss.backward() 117 | optimizer.step() 118 | 119 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 120 | for param_group in optimizer.param_groups: 121 | param_group['lr'] = lr_ 122 | 123 | iter_num = iter_num + 1 124 | writer.add_scalar('info/lr', lr_, iter_num) 125 | writer.add_scalar('info/total_loss', loss, iter_num) 126 | writer.add_scalar('info/loss_ce', loss_ce, iter_num) 127 | writer.add_scalar('info/loss_dice', loss_dice, iter_num) 128 | 129 | logging.info( 130 | 'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' % 131 | (iter_num, loss.item(), loss_ce.item(), loss_dice.item())) 132 | 133 | if iter_num % 20 == 0: 134 | image = volume_batch[1, 0:1, :, :] 135 | writer.add_image('train/Image', image, iter_num) 136 | outputs = torch.argmax(torch.softmax( 137 | outputs, dim=1), dim=1, keepdim=True) 138 | writer.add_image('train/Prediction', 139 | outputs[1, ...] * 50, iter_num) 140 | labs = label_batch[1, ...].unsqueeze(0) * 50 141 | writer.add_image('train/GroundTruth', labs, iter_num) 142 | 143 | if iter_num > 0 and iter_num % 200 == 0: 144 | model.eval() 145 | metric_list = 0.0 146 | for i_batch, sampled_batch in enumerate(valloader): 147 | metric_i = test_single_volume( 148 | sampled_batch["image"], sampled_batch["label"], model, classes=num_classes) 149 | metric_list += np.array(metric_i) 150 | metric_list = metric_list / len(db_val) 151 | for class_i in range(num_classes-1): 152 | writer.add_scalar('info/val_{}_dice'.format(class_i+1), 153 | metric_list[class_i, 0], iter_num) 154 | writer.add_scalar('info/val_{}_hd95'.format(class_i+1), 155 | metric_list[class_i, 1], iter_num) 156 | 157 | performance = np.mean(metric_list, axis=0)[0] 158 | 159 | mean_hd95 = np.mean(metric_list, axis=0)[1] 160 | writer.add_scalar('info/val_mean_dice', performance, iter_num) 161 | writer.add_scalar('info/val_mean_hd95', mean_hd95, iter_num) 162 | 163 | if performance > best_performance: 164 | best_performance = performance 165 | save_mode_path = os.path.join(snapshot_path, 166 | 'iter_{}_dice_{}.pth'.format( 167 | iter_num, round(best_performance, 4))) 168 | save_best = os.path.join(snapshot_path, 169 | '{}_best_model.pth'.format(args.model)) 170 | torch.save(model.state_dict(), save_mode_path) 171 | torch.save(model.state_dict(), save_best) 172 | 173 | logging.info( 174 | 'iteration %d : mean_dice : %f mean_hd95 : %f' % (iter_num, performance, mean_hd95)) 175 | model.train() 176 | 177 | if iter_num % 3000 == 0: 178 | save_mode_path = os.path.join( 179 | snapshot_path, 'iter_' + str(iter_num) + '.pth') 180 | torch.save(model.state_dict(), save_mode_path) 181 | logging.info("save model to {}".format(save_mode_path)) 182 | 183 | if iter_num >= max_iterations: 184 | break 185 | if iter_num >= max_iterations: 186 | iterator.close() 187 | break 188 | writer.close() 189 | return "Training Finished!" 
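# The training loop above optimises loss = 0.5 * (loss_ce + loss_dice), i.e. an
# equal mix of cross-entropy and Dice loss, and anneals the learning rate with a
# "poly" schedule:
#     lr = base_lr * (1 - iter_num / max_iterations) ** 0.9
# With the defaults here (base_lr=0.01, max_iterations=30000) that gives, for example,
#     iter_num = 0      ->  lr = 0.0100
#     iter_num = 15000  ->  lr = 0.01 * 0.5 ** 0.9  ~= 0.0054
#     iter_num = 29999  ->  lr ~= 1e-6
# so the rate decays smoothly towards zero by max_iterations. Every 200 iterations
# the model is evaluated on the validation set and the checkpoint with the best
# mean Dice is kept as '{model}_best_model.pth'.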
190 | 191 | 192 | if __name__ == "__main__": 193 | if not args.deterministic: 194 | cudnn.benchmark = True 195 | cudnn.deterministic = False 196 | else: 197 | cudnn.benchmark = False 198 | cudnn.deterministic = True 199 | 200 | random.seed(args.seed) 201 | np.random.seed(args.seed) 202 | torch.manual_seed(args.seed) 203 | torch.cuda.manual_seed(args.seed) 204 | 205 | snapshot_path = "../model/{}_{}_labeled/{}".format( 206 | args.exp, args.labeled_num, args.model) 207 | if not os.path.exists(snapshot_path): 208 | os.makedirs(snapshot_path) 209 | if os.path.exists(snapshot_path + '/code'): 210 | shutil.rmtree(snapshot_path + '/code') 211 | shutil.copytree('.', snapshot_path + '/code', 212 | shutil.ignore_patterns(['.git', '__pycache__'])) 213 | 214 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 215 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 216 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 217 | logging.info(str(args)) 218 | train(args, snapshot_path) 219 | -------------------------------------------------------------------------------- /code/train_fully_supervised_3D.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import random 5 | import shutil 6 | import sys 7 | import time 8 | 9 | import numpy as np 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import torch.optim as optim 15 | from tensorboardX import SummaryWriter 16 | from torch.nn import BCEWithLogitsLoss 17 | from torch.nn.modules.loss import CrossEntropyLoss 18 | from torch.utils.data import DataLoader 19 | from torchvision import transforms 20 | from torchvision.utils import make_grid 21 | from tqdm import tqdm 22 | 23 | from dataloaders import utils 24 | from dataloaders.brats2019 import (BraTS2019, CenterCrop, RandomCrop, 25 | RandomRotFlip, ToTensor, 26 | TwoStreamBatchSampler) 27 | from networks.net_factory_3d import net_factory_3d 28 | from utils import losses, metrics, ramps 29 | from val_3D import test_all_case 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--root_path', type=str, 33 | default='../data/BraTS2019', help='Name of Experiment') 34 | parser.add_argument('--exp', type=str, 35 | default='BraTs2019_Fully_Supervised', help='experiment_name') 36 | parser.add_argument('--model', type=str, 37 | default='unet_3D', help='model_name') 38 | parser.add_argument('--max_iterations', type=int, 39 | default=30000, help='maximum epoch number to train') 40 | parser.add_argument('--batch_size', type=int, default=2, 41 | help='batch_size per gpu') 42 | parser.add_argument('--deterministic', type=int, default=1, 43 | help='whether use deterministic training') 44 | parser.add_argument('--base_lr', type=float, default=0.01, 45 | help='segmentation network learning rate') 46 | parser.add_argument('--patch_size', type=list, default=[96, 96, 96], 47 | help='patch size of network input') 48 | parser.add_argument('--seed', type=int, default=1337, help='random seed') 49 | parser.add_argument('--labeled_num', type=int, default=25, 50 | help='labeled data') 51 | 52 | args = parser.parse_args() 53 | 54 | 55 | def train(args, snapshot_path): 56 | base_lr = args.base_lr 57 | train_data_path = args.root_path 58 | batch_size = args.batch_size 59 | max_iterations = args.max_iterations 60 | num_classes = 2 61 | model = net_factory_3d(net_type=args.model, in_chns=1, 
class_num=num_classes) 62 | db_train = BraTS2019(base_dir=train_data_path, 63 | split='train', 64 | num=args.labeled_num, 65 | transform=transforms.Compose([ 66 | RandomRotFlip(), 67 | RandomCrop(args.patch_size), 68 | ToTensor(), 69 | ])) 70 | 71 | def worker_init_fn(worker_id): 72 | random.seed(args.seed + worker_id) 73 | 74 | trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, 75 | num_workers=16, pin_memory=True, worker_init_fn=worker_init_fn) 76 | 77 | model.train() 78 | 79 | optimizer = optim.SGD(model.parameters(), lr=base_lr, 80 | momentum=0.9, weight_decay=0.0001) 81 | ce_loss = CrossEntropyLoss() 82 | dice_loss = losses.DiceLoss(2) 83 | 84 | writer = SummaryWriter(snapshot_path + '/log') 85 | logging.info("{} iterations per epoch".format(len(trainloader))) 86 | 87 | iter_num = 0 88 | max_epoch = max_iterations // len(trainloader) + 1 89 | best_performance = 0.0 90 | iterator = tqdm(range(max_epoch), ncols=70) 91 | for epoch_num in iterator: 92 | for i_batch, sampled_batch in enumerate(trainloader): 93 | 94 | volume_batch, label_batch = sampled_batch['image'], sampled_batch['label'] 95 | volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda() 96 | 97 | outputs = model(volume_batch) 98 | outputs_soft = torch.softmax(outputs, dim=1) 99 | 100 | loss_ce = ce_loss(outputs, label_batch) 101 | loss_dice = dice_loss(outputs_soft, label_batch.unsqueeze(1)) 102 | loss = 0.5 * (loss_dice + loss_ce) 103 | optimizer.zero_grad() 104 | loss.backward() 105 | optimizer.step() 106 | 107 | lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9 108 | for param_group in optimizer.param_groups: 109 | param_group['lr'] = lr_ 110 | 111 | iter_num = iter_num + 1 112 | writer.add_scalar('info/lr', lr_, iter_num) 113 | writer.add_scalar('info/total_loss', loss, iter_num) 114 | writer.add_scalar('info/loss_ce', loss_ce, iter_num) 115 | writer.add_scalar('info/loss_dice', loss_dice, iter_num) 116 | 117 | logging.info( 118 | 'iteration %d : loss : %f, loss_ce: %f, loss_dice: %f' % 119 | (iter_num, loss.item(), loss_ce.item(), loss_dice.item())) 120 | writer.add_scalar('loss/loss', loss, iter_num) 121 | 122 | if iter_num % 20 == 0: 123 | image = volume_batch[0, 0:1, :, :, 20:61:10].permute( 124 | 3, 0, 1, 2).repeat(1, 3, 1, 1) 125 | grid_image = make_grid(image, 5, normalize=True) 126 | writer.add_image('train/Image', grid_image, iter_num) 127 | 128 | image = outputs_soft[0, 1:2, :, :, 20:61:10].permute( 129 | 3, 0, 1, 2).repeat(1, 3, 1, 1) 130 | grid_image = make_grid(image, 5, normalize=False) 131 | writer.add_image('train/Predicted_label', 132 | grid_image, iter_num) 133 | 134 | image = label_batch[0, :, :, 20:61:10].unsqueeze( 135 | 0).permute(3, 0, 1, 2).repeat(1, 3, 1, 1) 136 | grid_image = make_grid(image, 5, normalize=False) 137 | writer.add_image('train/Groundtruth_label', 138 | grid_image, iter_num) 139 | 140 | if iter_num > 0 and iter_num % 200 == 0: 141 | model.eval() 142 | avg_metric = test_all_case( 143 | model, args.root_path, test_list="val.txt", num_classes=2, patch_size=args.patch_size, 144 | stride_xy=64, stride_z=64) 145 | if avg_metric[:, 0].mean() > best_performance: 146 | best_performance = avg_metric[:, 0].mean() 147 | save_mode_path = os.path.join(snapshot_path, 148 | 'iter_{}_dice_{}.pth'.format( 149 | iter_num, round(best_performance, 4))) 150 | save_best = os.path.join(snapshot_path, 151 | '{}_best_model.pth'.format(args.model)) 152 | torch.save(model.state_dict(), save_mode_path) 153 | torch.save(model.state_dict(), save_best) 154 | 155 | 
writer.add_scalar('info/val_dice_score', 156 | avg_metric[0, 0], iter_num) 157 | writer.add_scalar('info/val_hd95', 158 | avg_metric[0, 1], iter_num) 159 | logging.info( 160 | 'iteration %d : dice_score : %f hd95 : %f' % (iter_num, avg_metric[0, 0].mean(), avg_metric[0, 1].mean())) 161 | model.train() 162 | 163 | if iter_num % 3000 == 0: 164 | save_mode_path = os.path.join( 165 | snapshot_path, 'iter_' + str(iter_num) + '.pth') 166 | torch.save(model.state_dict(), save_mode_path) 167 | logging.info("save model to {}".format(save_mode_path)) 168 | 169 | if iter_num >= max_iterations: 170 | break 171 | if iter_num >= max_iterations: 172 | iterator.close() 173 | break 174 | writer.close() 175 | return "Training Finished!" 176 | 177 | 178 | if __name__ == "__main__": 179 | if not args.deterministic: 180 | cudnn.benchmark = True 181 | cudnn.deterministic = False 182 | else: 183 | cudnn.benchmark = False 184 | cudnn.deterministic = True 185 | 186 | random.seed(args.seed) 187 | np.random.seed(args.seed) 188 | torch.manual_seed(args.seed) 189 | torch.cuda.manual_seed(args.seed) 190 | 191 | snapshot_path = "../model/{}/{}".format(args.exp, args.model) 192 | if not os.path.exists(snapshot_path): 193 | os.makedirs(snapshot_path) 194 | if os.path.exists(snapshot_path + '/code'): 195 | shutil.rmtree(snapshot_path + '/code') 196 | shutil.copytree('.', snapshot_path + '/code', 197 | shutil.ignore_patterns(['.git', '__pycache__'])) 198 | 199 | logging.basicConfig(filename=snapshot_path+"/log.txt", level=logging.INFO, 200 | format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S') 201 | logging.getLogger().addHandler(logging.StreamHandler(sys.stdout)) 202 | logging.info(str(args)) 203 | train(args, snapshot_path) 204 | -------------------------------------------------------------------------------- /code/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | 7 | 8 | def dice_loss(score, target): 9 | target = target.float() 10 | smooth = 1e-5 11 | intersect = torch.sum(score * target) 12 | y_sum = torch.sum(target * target) 13 | z_sum = torch.sum(score * score) 14 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 15 | loss = 1 - loss 16 | return loss 17 | 18 | 19 | def dice_loss1(score, target): 20 | target = target.float() 21 | smooth = 1e-5 22 | intersect = torch.sum(score * target) 23 | y_sum = torch.sum(target) 24 | z_sum = torch.sum(score) 25 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 26 | loss = 1 - loss 27 | return loss 28 | 29 | 30 | def entropy_loss(p, C=2): 31 | # p N*C*W*H*D 32 | y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) / \ 33 | torch.tensor(np.log(C)).cuda() 34 | ent = torch.mean(y1) 35 | 36 | return ent 37 | 38 | 39 | def softmax_dice_loss(input_logits, target_logits): 40 | """Takes softmax on both sides and returns MSE loss 41 | 42 | Note: 43 | - Returns the sum over all examples. Divide by the batch size afterwards 44 | if you want the mean. 45 | - Sends gradients to inputs but not the targets. 
46 | """ 47 | assert input_logits.size() == target_logits.size() 48 | input_softmax = F.softmax(input_logits, dim=1) 49 | target_softmax = F.softmax(target_logits, dim=1) 50 | n = input_logits.shape[1] 51 | dice = 0 52 | for i in range(0, n): 53 | dice += dice_loss1(input_softmax[:, i], target_softmax[:, i]) 54 | mean_dice = dice / n 55 | 56 | return mean_dice 57 | 58 | 59 | def entropy_loss_map(p, C=2): 60 | ent = -1*torch.sum(p * torch.log(p + 1e-6), dim=1, 61 | keepdim=True)/torch.tensor(np.log(C)).cuda() 62 | return ent 63 | 64 | 65 | def softmax_mse_loss(input_logits, target_logits, sigmoid=False): 66 | """Takes softmax on both sides and returns MSE loss 67 | 68 | Note: 69 | - Returns the sum over all examples. Divide by the batch size afterwards 70 | if you want the mean. 71 | - Sends gradients to inputs but not the targets. 72 | """ 73 | assert input_logits.size() == target_logits.size() 74 | if sigmoid: 75 | input_softmax = torch.sigmoid(input_logits) 76 | target_softmax = torch.sigmoid(target_logits) 77 | else: 78 | input_softmax = F.softmax(input_logits, dim=1) 79 | target_softmax = F.softmax(target_logits, dim=1) 80 | 81 | mse_loss = (input_softmax-target_softmax)**2 82 | return mse_loss 83 | 84 | 85 | def softmax_kl_loss(input_logits, target_logits, sigmoid=False): 86 | """Takes softmax on both sides and returns KL divergence 87 | 88 | Note: 89 | - Returns the sum over all examples. Divide by the batch size afterwards 90 | if you want the mean. 91 | - Sends gradients to inputs but not the targets. 92 | """ 93 | assert input_logits.size() == target_logits.size() 94 | if sigmoid: 95 | input_log_softmax = torch.log(torch.sigmoid(input_logits)) 96 | target_softmax = torch.sigmoid(target_logits) 97 | else: 98 | input_log_softmax = F.log_softmax(input_logits, dim=1) 99 | target_softmax = F.softmax(target_logits, dim=1) 100 | 101 | # return F.kl_div(input_log_softmax, target_softmax) 102 | kl_div = F.kl_div(input_log_softmax, target_softmax, reduction='mean') 103 | # mean_kl_div = torch.mean(0.2*kl_div[:,0,...]+0.8*kl_div[:,1,...]) 104 | return kl_div 105 | 106 | 107 | def symmetric_mse_loss(input1, input2): 108 | """Like F.mse_loss but sends gradients to both directions 109 | 110 | Note: 111 | - Returns the sum over all examples. Divide by the batch size afterwards 112 | if you want the mean. 113 | - Sends gradients to both input1 and input2. 
114 | """ 115 | assert input1.size() == input2.size() 116 | return torch.mean((input1 - input2)**2) 117 | 118 | 119 | class FocalLoss(nn.Module): 120 | def __init__(self, gamma=2, alpha=None, size_average=True): 121 | super(FocalLoss, self).__init__() 122 | self.gamma = gamma 123 | self.alpha = alpha 124 | if isinstance(alpha, (float, int)): 125 | self.alpha = torch.Tensor([alpha, 1-alpha]) 126 | if isinstance(alpha, list): 127 | self.alpha = torch.Tensor(alpha) 128 | self.size_average = size_average 129 | 130 | def forward(self, input, target): 131 | if input.dim() > 2: 132 | # N,C,H,W => N,C,H*W 133 | input = input.view(input.size(0), input.size(1), -1) 134 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 135 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 136 | target = target.view(-1, 1) 137 | 138 | logpt = F.log_softmax(input, dim=1) 139 | logpt = logpt.gather(1, target) 140 | logpt = logpt.view(-1) 141 | pt = Variable(logpt.data.exp()) 142 | 143 | if self.alpha is not None: 144 | if self.alpha.type() != input.data.type(): 145 | self.alpha = self.alpha.type_as(input.data) 146 | at = self.alpha.gather(0, target.data.view(-1)) 147 | logpt = logpt * Variable(at) 148 | 149 | loss = -1 * (1-pt)**self.gamma * logpt 150 | if self.size_average: 151 | return loss.mean() 152 | else: 153 | return loss.sum() 154 | 155 | 156 | class DiceLoss(nn.Module): 157 | def __init__(self, n_classes): 158 | super(DiceLoss, self).__init__() 159 | self.n_classes = n_classes 160 | 161 | def _one_hot_encoder(self, input_tensor): 162 | tensor_list = [] 163 | for i in range(self.n_classes): 164 | temp_prob = input_tensor == i * torch.ones_like(input_tensor) 165 | tensor_list.append(temp_prob) 166 | output_tensor = torch.cat(tensor_list, dim=1) 167 | return output_tensor.float() 168 | 169 | def _dice_loss(self, score, target): 170 | target = target.float() 171 | smooth = 1e-5 172 | intersect = torch.sum(score * target) 173 | y_sum = torch.sum(target * target) 174 | z_sum = torch.sum(score * score) 175 | loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth) 176 | loss = 1 - loss 177 | return loss 178 | 179 | def forward(self, inputs, target, weight=None, softmax=False): 180 | if softmax: 181 | inputs = torch.softmax(inputs, dim=1) 182 | target = self._one_hot_encoder(target) 183 | if weight is None: 184 | weight = [1] * self.n_classes 185 | assert inputs.size() == target.size(), 'predict & target shape do not match' 186 | class_wise_dice = [] 187 | loss = 0.0 188 | for i in range(0, self.n_classes): 189 | dice = self._dice_loss(inputs[:, i], target[:, i]) 190 | class_wise_dice.append(1.0 - dice.item()) 191 | loss += dice * weight[i] 192 | return loss / self.n_classes 193 | 194 | 195 | def entropy_minmization(p): 196 | y1 = -1*torch.sum(p*torch.log(p+1e-6), dim=1) 197 | ent = torch.mean(y1) 198 | 199 | return ent 200 | 201 | 202 | def entropy_map(p): 203 | ent_map = -1*torch.sum(p * torch.log(p + 1e-6), dim=1, 204 | keepdim=True) 205 | return ent_map 206 | 207 | 208 | def compute_kl_loss(p, q): 209 | p_loss = F.kl_div(F.log_softmax(p, dim=-1), 210 | F.softmax(q, dim=-1), reduction='none') 211 | q_loss = F.kl_div(F.log_softmax(q, dim=-1), 212 | F.softmax(p, dim=-1), reduction='none') 213 | 214 | # Using function "sum" and "mean" are depending on your task 215 | p_loss = p_loss.mean() 216 | q_loss = q_loss.mean() 217 | 218 | loss = (p_loss + q_loss) / 2 219 | return loss 220 | -------------------------------------------------------------------------------- /code/utils/metrics.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/12/14 下午4:41 4 | # @Author : chuyu zhang 5 | # @File : metrics.py 6 | # @Software: PyCharm 7 | 8 | 9 | import numpy as np 10 | from medpy import metric 11 | 12 | 13 | def cal_dice(prediction, label, num=2): 14 | total_dice = np.zeros(num-1) 15 | for i in range(1, num): 16 | prediction_tmp = (prediction == i) 17 | label_tmp = (label == i) 18 | prediction_tmp = prediction_tmp.astype(np.float) 19 | label_tmp = label_tmp.astype(np.float) 20 | 21 | dice = 2 * np.sum(prediction_tmp * label_tmp) / (np.sum(prediction_tmp) + np.sum(label_tmp)) 22 | total_dice[i - 1] += dice 23 | 24 | return total_dice 25 | 26 | 27 | def calculate_metric_percase(pred, gt): 28 | dc = metric.binary.dc(pred, gt) 29 | jc = metric.binary.jc(pred, gt) 30 | hd = metric.binary.hd95(pred, gt) 31 | asd = metric.binary.asd(pred, gt) 32 | 33 | return dc, jc, hd, asd 34 | 35 | 36 | def dice(input, target, ignore_index=None): 37 | smooth = 1. 38 | # using clone, so that it can do change to original target. 39 | iflat = input.clone().view(-1) 40 | tflat = target.clone().view(-1) 41 | if ignore_index is not None: 42 | mask = tflat == ignore_index 43 | tflat[mask] = 0 44 | iflat[mask] = 0 45 | intersection = (iflat * tflat).sum() 46 | 47 | return (2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth) -------------------------------------------------------------------------------- /code/utils/ramps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, Curious AI Ltd. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial 4 | # 4.0 International License. To view a copy of this license, visit 5 | # http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to 6 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 7 | 8 | """Functions for ramping hyperparameters up or down 9 | 10 | Each function takes the current training step or epoch, and the 11 | ramp length in the same format, and returns a multiplier between 12 | 0 and 1. 13 | """ 14 | 15 | 16 | import numpy as np 17 | 18 | 19 | def sigmoid_rampup(current, rampup_length): 20 | """Exponential rampup from https://arxiv.org/abs/1610.02242""" 21 | if rampup_length == 0: 22 | return 1.0 23 | else: 24 | current = np.clip(current, 0.0, rampup_length) 25 | phase = 1.0 - current / rampup_length 26 | return float(np.exp(-5.0 * phase * phase)) 27 | 28 | 29 | def linear_rampup(current, rampup_length): 30 | """Linear rampup""" 31 | assert current >= 0 and rampup_length >= 0 32 | if current >= rampup_length: 33 | return 1.0 34 | else: 35 | return current / rampup_length 36 | 37 | 38 | def cosine_rampdown(current, rampdown_length): 39 | """Cosine rampdown from https://arxiv.org/abs/1608.03983""" 40 | assert 0 <= current <= rampdown_length 41 | return float(.5 * (np.cos(np.pi * current / rampdown_length) + 1)) 42 | -------------------------------------------------------------------------------- /code/utils/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | #
7 | import os
8 | import pickle
9 | import numpy as np
10 | import re
11 | from scipy.ndimage import distance_transform_edt as distance
12 | from skimage import segmentation as skimage_seg
13 | import torch
14 | from torch.utils.data.sampler import Sampler
15 | import torch.distributed as dist
16 | 
17 | import networks
18 | 
19 | # many issues with this function
20 | def load_model(path):
21 |     """Loads a model and returns it without the DataParallel table."""
22 |     if os.path.isfile(path):
23 |         print("=> loading checkpoint '{}'".format(path))
24 |         checkpoint = torch.load(path)
25 | 
26 |         for key in checkpoint["state_dict"]:
27 |             print(key)
28 | 
29 |         # size of the top layer
30 |         N = checkpoint["state_dict"]["decoder.out_conv.bias"].size()
31 | 
32 |         # build skeleton of the model
33 |         sob = "sobel.0.weight" in checkpoint["state_dict"].keys()
34 |         model = networks.__dict__[checkpoint["arch"]](sobel=sob, out=int(N[0]))
35 | 
36 |         # deal with a dataparallel table
37 |         def rename_key(key):
38 |             if "module" not in key:
39 |                 return key
40 |             return "".join(key.split(".module"))
41 | 
42 |         checkpoint["state_dict"] = {
43 |             rename_key(key): val for key, val in checkpoint["state_dict"].items()
44 |         }
45 | 
46 |         # load weights
47 |         model.load_state_dict(checkpoint["state_dict"])
48 |         print("Loaded")
49 |     else:
50 |         model = None
51 |         print("=> no checkpoint found at '{}'".format(path))
52 |     return model
53 | 
54 | 
55 | def load_checkpoint(path, model, optimizer, from_ddp=False):
56 |     """loads previous checkpoint
57 | 
58 |     Args:
59 |         path (str): path to checkpoint
60 |         model (model): model to restore checkpoint to
61 |         optimizer (optimizer): torch optimizer to load optimizer state_dict to
62 |         from_ddp (bool, optional): load DistributedDataParallel checkpoint to regular model. Defaults to False.
63 | 
64 |     Returns:
65 |         model, optimizer, epoch_num, loss
66 |     """
67 |     # load checkpoint
68 |     checkpoint = torch.load(path)
69 |     # transfer state_dict from checkpoint to model
70 |     model.load_state_dict(checkpoint["state_dict"])
71 |     # transfer optimizer state_dict from checkpoint to model
72 |     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
73 |     # track loss
74 |     loss = checkpoint["loss"]
75 |     return model, optimizer, checkpoint["epoch"], loss.item()
76 | 
77 | 
78 | def restore_model(logger, snapshot_path, model, optimizer, model_num=None):
79 |     """wrapper function that scans the snapshot directory and restores the most recent checkpoint
80 | 
81 |     Args:
82 |         logger (Logger): logger object (for info output to console)
83 |         snapshot_path (str): path to checkpoint directory
84 |         model, optimizer: model and optimizer to load the checkpoint into
85 |     Returns:
86 |         model, optimizer, start_epoch, performance
87 |     """
88 |     try:
89 |         # check if there is previous progress to be restored:
90 |         logger.info(f"Snapshot path: {snapshot_path}")
91 |         iter_num = []
92 |         name = "model_iter"
93 |         if model_num:
94 |             name = model_num
95 |         for filename in os.listdir(snapshot_path):
96 |             if name in filename:
97 |                 basename, extension = os.path.splitext(filename)
98 |                 iter_num.append(int(basename.split("_")[2]))
99 |         iter_num = max(iter_num)
100 |         for filename in os.listdir(snapshot_path):
101 |             if name in filename and str(iter_num) in filename:
102 |                 model_checkpoint = filename
103 |     except Exception as e:
104 |         logger.warning(f"Error finding previous checkpoints: {e}")
105 | 
106 |     try:
107 |         logger.info(f"Restoring model checkpoint: {model_checkpoint}")
108 |         model, optimizer, start_epoch, performance = load_checkpoint(
109 |             snapshot_path + "/" + model_checkpoint, model, optimizer
110 |         )
111 |         logger.info(f"Models restored from iteration {iter_num}")
112 |         return model, optimizer, start_epoch, performance
113 |     except Exception as e:
114 |         logger.warning(f"Unable to restore model checkpoint: {e}, using new model")
115 | 
116 | 
117 | def save_checkpoint(epoch, model, optimizer, loss, path):
118 |     """Saves model as checkpoint"""
119 |     torch.save(
120 |         {
121 |             "epoch": epoch,
122 |             "state_dict": model.state_dict(),
123 |             "optimizer_state_dict": optimizer.state_dict(),
124 |             "loss": loss,
125 |         },
126 |         path,
127 |     )
128 | 
129 | 
130 | class UnifLabelSampler(Sampler):
131 |     """Samples elements uniformly across pseudolabels.
132 |     Args:
133 |         N (int): size of returned iterator.
134 |         images_lists: dict of key (target), value (list of data with this target)
135 |     """
136 | 
137 |     def __init__(self, N, images_lists):
138 |         self.N = N
139 |         self.images_lists = images_lists
140 |         self.indexes = self.generate_indexes_epoch()
141 | 
142 |     def generate_indexes_epoch(self):
143 |         size_per_pseudolabel = int(self.N / len(self.images_lists)) + 1
144 |         res = np.zeros(size_per_pseudolabel * len(self.images_lists))
145 | 
146 |         for i in range(len(self.images_lists)):
147 |             indexes = np.random.choice(
148 |                 self.images_lists[i],
149 |                 size_per_pseudolabel,
150 |                 replace=(len(self.images_lists[i]) <= size_per_pseudolabel),
151 |             )
152 |             res[i * size_per_pseudolabel : (i + 1) * size_per_pseudolabel] = indexes
153 | 
154 |         np.random.shuffle(res)
155 |         return res[: self.N].astype("int")
156 | 
157 |     def __iter__(self):
158 |         return iter(self.indexes)
159 | 
160 |     def __len__(self):
161 |         return self.N
162 | 
163 | 
164 | class AverageMeter(object):
165 |     """Computes and stores the average and current value"""
166 | 
167 |     def __init__(self):
168 |         self.reset()
169 | 
170 |     def reset(self):
171 |         self.val = 0
172 |         self.avg = 0
173 |         self.sum = 0
174 |         self.count = 0
175 | 
176 |     def update(self, val, n=1):
177 |         self.val = val
178 |         self.sum += val * n
179 |         self.count += n
180 |         self.avg = self.sum / self.count
181 | 
182 | 
183 | def learning_rate_decay(optimizer, t, lr_0):
184 |     for param_group in optimizer.param_groups:
185 |         lr = lr_0 / np.sqrt(1 + lr_0 * param_group["weight_decay"] * t)
186 |         param_group["lr"] = lr
187 | 
188 | 
189 | class Logger:
190 |     """Class updated every epoch to keep track of the results
191 |     Methods:
192 |         - log() log and save
193 |     """
194 | 
195 |     def __init__(self, path):
196 |         self.path = path
197 |         self.data = []
198 | 
199 |     def log(self, train_point):
200 |         self.data.append(train_point)
201 |         with open(os.path.join(self.path), "wb") as fp:
202 |             pickle.dump(self.data, fp, -1)
203 | 
204 | 
205 | def compute_sdf(img_gt, out_shape):
206 |     """
207 |     compute the signed distance map of binary mask
208 |     input: segmentation, shape = (batch_size, x, y, z)
209 |     output: the Signed Distance Map (SDM)
210 |     sdf(x) = 0; x in segmentation boundary
211 |              -inf|x-y|; x in segmentation
212 |              +inf|x-y|; x out of segmentation
213 |     normalize sdf to [-1,1]
214 |     """
215 | 
216 |     img_gt = img_gt.astype(np.uint8)
217 |     normalized_sdf = np.zeros(out_shape)
218 | 
219 |     for b in range(out_shape[0]):  # batch size
220 |         posmask = img_gt[b].astype(bool)
221 |         if posmask.any():
222 |             negmask = ~posmask
223 |             posdis = distance(posmask)
224 |             negdis = distance(negmask)
225 |             boundary = skimage_seg.find_boundaries(posmask, mode="inner").astype(
226 |                 np.uint8
227 |             )
228 |             sdf = (negdis - np.min(negdis)) / (np.max(negdis) - np.min(negdis)) - (
229 |                 posdis - np.min(posdis)
230 |             ) / (np.max(posdis) - np.min(posdis))
231 |             sdf[boundary == 1] = 0
232 |             normalized_sdf[b] = sdf
233 |             # assert np.min(sdf) == -1.0, print(np.min(posdis), np.max(posdis), np.min(negdis), np.max(negdis))
234 |             # assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
235 | 
236 |     return normalized_sdf
237 | 
238 | 
239 | # set up process group for distributed computing
240 | def distributed_setup(rank, world_size):
241 |     os.environ["MASTER_ADDR"] = "localhost"
242 |     os.environ["MASTER_PORT"] = "12355"
243 |     print("setting up dist process group now")
244 |     dist.init_process_group("nccl", rank=rank, world_size=world_size)
245 | 
246 | 
247 | def load_ddp_to_nddp(state_dict):
248 | 
pattern = re.compile("module") 249 | for k, v in state_dict.items(): 250 | if re.search("module", k): 251 | model_dict[re.sub(pattern, "", k)] = v 252 | else: 253 | model_dict = state_dict 254 | return model_dict 255 | -------------------------------------------------------------------------------- /code/val_2D.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from medpy import metric 4 | from scipy.ndimage import zoom 5 | 6 | 7 | def calculate_metric_percase(pred, gt): 8 | pred[pred > 0] = 1 9 | gt[gt > 0] = 1 10 | if pred.sum() > 0: 11 | dice = metric.binary.dc(pred, gt) 12 | hd95 = metric.binary.hd95(pred, gt) 13 | return dice, hd95 14 | else: 15 | return 0, 0 16 | 17 | 18 | def test_single_volume(image, label, net, classes, patch_size=[256, 256]): 19 | image, label = image.squeeze(0).cpu().detach( 20 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 21 | prediction = np.zeros_like(label) 22 | for ind in range(image.shape[0]): 23 | slice = image[ind, :, :] 24 | x, y = slice.shape[0], slice.shape[1] 25 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0) 26 | input = torch.from_numpy(slice).unsqueeze( 27 | 0).unsqueeze(0).float().cuda() 28 | net.eval() 29 | with torch.no_grad(): 30 | out = torch.argmax(torch.softmax( 31 | net(input), dim=1), dim=1).squeeze(0) 32 | out = out.cpu().detach().numpy() 33 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 34 | prediction[ind] = pred 35 | metric_list = [] 36 | for i in range(1, classes): 37 | metric_list.append(calculate_metric_percase( 38 | prediction == i, label == i)) 39 | return metric_list 40 | 41 | 42 | def test_single_volume_ds(image, label, net, classes, patch_size=[256, 256]): 43 | image, label = image.squeeze(0).cpu().detach( 44 | ).numpy(), label.squeeze(0).cpu().detach().numpy() 45 | prediction = np.zeros_like(label) 46 | for ind in range(image.shape[0]): 47 | slice = image[ind, :, :] 48 | x, y = slice.shape[0], slice.shape[1] 49 | slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=0) 50 | input = torch.from_numpy(slice).unsqueeze( 51 | 0).unsqueeze(0).float().cuda() 52 | net.eval() 53 | with torch.no_grad(): 54 | output_main, _, _, _ = net(input) 55 | out = torch.argmax(torch.softmax( 56 | output_main, dim=1), dim=1).squeeze(0) 57 | out = out.cpu().detach().numpy() 58 | pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0) 59 | prediction[ind] = pred 60 | metric_list = [] 61 | for i in range(1, classes): 62 | metric_list.append(calculate_metric_percase( 63 | prediction == i, label == i)) 64 | return metric_list 65 | -------------------------------------------------------------------------------- /code/val_3D.py: -------------------------------------------------------------------------------- 1 | import math 2 | from glob import glob 3 | 4 | import h5py 5 | import nibabel as nib 6 | import numpy as np 7 | import SimpleITK as sitk 8 | import torch 9 | import torch.nn.functional as F 10 | from medpy import metric 11 | from tqdm import tqdm 12 | 13 | 14 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 15 | w, h, d = image.shape 16 | 17 | # if the size of image is less than patch_size, then padding it 18 | add_pad = False 19 | if w < patch_size[0]: 20 | w_pad = patch_size[0]-w 21 | add_pad = True 22 | else: 23 | w_pad = 0 24 | if h < patch_size[1]: 25 | h_pad = patch_size[1]-h 26 | add_pad = True 27 | else: 28 | h_pad = 0 29 | if d < patch_size[2]: 30 | d_pad = 
patch_size[2]-d 31 | add_pad = True 32 | else: 33 | d_pad = 0 34 | wl_pad, wr_pad = w_pad//2, w_pad-w_pad//2 35 | hl_pad, hr_pad = h_pad//2, h_pad-h_pad//2 36 | dl_pad, dr_pad = d_pad//2, d_pad-d_pad//2 37 | if add_pad: 38 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), 39 | (dl_pad, dr_pad)], mode='constant', constant_values=0) 40 | ww, hh, dd = image.shape 41 | 42 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 43 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 44 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 45 | # print("{}, {}, {}".format(sx, sy, sz)) 46 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 47 | cnt = np.zeros(image.shape).astype(np.float32) 48 | 49 | for x in range(0, sx): 50 | xs = min(stride_xy*x, ww-patch_size[0]) 51 | for y in range(0, sy): 52 | ys = min(stride_xy * y, hh-patch_size[1]) 53 | for z in range(0, sz): 54 | zs = min(stride_z * z, dd-patch_size[2]) 55 | test_patch = image[xs:xs+patch_size[0], 56 | ys:ys+patch_size[1], zs:zs+patch_size[2]] 57 | test_patch = np.expand_dims(np.expand_dims( 58 | test_patch, axis=0), axis=0).astype(np.float32) 59 | test_patch = torch.from_numpy(test_patch).cuda() 60 | 61 | with torch.no_grad(): 62 | y1 = net(test_patch) 63 | # ensemble 64 | y = torch.softmax(y1, dim=1) 65 | y = y.cpu().data.numpy() 66 | y = y[0, :, :, :, :] 67 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 68 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 69 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 70 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 71 | score_map = score_map/np.expand_dims(cnt, axis=0) 72 | label_map = np.argmax(score_map, axis=0) 73 | 74 | if add_pad: 75 | label_map = label_map[wl_pad:wl_pad+w, 76 | hl_pad:hl_pad+h, dl_pad:dl_pad+d] 77 | score_map = score_map[:, wl_pad:wl_pad + 78 | w, hl_pad:hl_pad+h, dl_pad:dl_pad+d] 79 | return label_map 80 | 81 | 82 | def cal_metric(gt, pred): 83 | if pred.sum() > 0 and gt.sum() > 0: 84 | dice = metric.binary.dc(pred, gt) 85 | hd95 = metric.binary.hd95(pred, gt) 86 | return np.array([dice, hd95]) 87 | else: 88 | return np.zeros(2) 89 | 90 | 91 | def test_all_case(net, base_dir, test_list="full_test.list", num_classes=4, patch_size=(48, 160, 160), stride_xy=32, stride_z=24): 92 | with open(base_dir + '/{}'.format(test_list), 'r') as f: 93 | image_list = f.readlines() 94 | image_list = [base_dir + "/data/{}.h5".format( 95 | item.replace('\n', '').split(",")[0]) for item in image_list] 96 | total_metric = np.zeros((num_classes-1, 2)) 97 | print("Validation begin") 98 | for image_path in tqdm(image_list): 99 | h5f = h5py.File(image_path, 'r') 100 | image = h5f['image'][:] 101 | label = h5f['label'][:] 102 | prediction = test_single_case( 103 | net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 104 | for i in range(1, num_classes): 105 | total_metric[i-1, :] += cal_metric(label == i, prediction == i) 106 | print("Validation end") 107 | return total_metric / len(image_list) 108 | -------------------------------------------------------------------------------- /code/val_urpc_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | from glob import glob 3 | 4 | import h5py 5 | import nibabel as nib 6 | import numpy as np 7 | import SimpleITK as sitk 8 | import torch 9 | import torch.nn.functional as F 10 | from medpy import metric 11 | from tqdm import tqdm 12 | 
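A hedged usage sketch for the sliding-window validation above (val_3D.py), shown here before the deep-supervision variant that follows: test_all_case pads each volume up to the patch size, slides a window with the given strides, averages overlapping softmax scores, and returns the mean (Dice, HD95) per foreground class. The network factory call and the BraTS-style settings are assumptions for illustration, not fixed choices of the repository.

from networks.net_factory_3d import net_factory_3d  # assumed factory, as used by the 3D training scripts
from val_3D import test_all_case                    # assumes code/ is on PYTHONPATH

net = net_factory_3d(net_type="unet_3D", in_chns=1, class_num=2).cuda()
net.eval()
avg_metric = test_all_case(net, base_dir="../data/BraTS2019", test_list="val.txt",
                           num_classes=2, patch_size=(96, 96, 96),
                           stride_xy=64, stride_z=64)
print("mean (Dice, HD95) per foreground class:", avg_metric)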
13 | 14 | def test_single_case(net, image, stride_xy, stride_z, patch_size, num_classes=1): 15 | w, h, d = image.shape 16 | 17 | # if the size of image is less than patch_size, then padding it 18 | add_pad = False 19 | if w < patch_size[0]: 20 | w_pad = patch_size[0]-w 21 | add_pad = True 22 | else: 23 | w_pad = 0 24 | if h < patch_size[1]: 25 | h_pad = patch_size[1]-h 26 | add_pad = True 27 | else: 28 | h_pad = 0 29 | if d < patch_size[2]: 30 | d_pad = patch_size[2]-d 31 | add_pad = True 32 | else: 33 | d_pad = 0 34 | wl_pad, wr_pad = w_pad//2, w_pad-w_pad//2 35 | hl_pad, hr_pad = h_pad//2, h_pad-h_pad//2 36 | dl_pad, dr_pad = d_pad//2, d_pad-d_pad//2 37 | if add_pad: 38 | image = np.pad(image, [(wl_pad, wr_pad), (hl_pad, hr_pad), 39 | (dl_pad, dr_pad)], mode='constant', constant_values=0) 40 | ww, hh, dd = image.shape 41 | 42 | sx = math.ceil((ww - patch_size[0]) / stride_xy) + 1 43 | sy = math.ceil((hh - patch_size[1]) / stride_xy) + 1 44 | sz = math.ceil((dd - patch_size[2]) / stride_z) + 1 45 | # print("{}, {}, {}".format(sx, sy, sz)) 46 | score_map = np.zeros((num_classes, ) + image.shape).astype(np.float32) 47 | cnt = np.zeros(image.shape).astype(np.float32) 48 | 49 | for x in range(0, sx): 50 | xs = min(stride_xy*x, ww-patch_size[0]) 51 | for y in range(0, sy): 52 | ys = min(stride_xy * y, hh-patch_size[1]) 53 | for z in range(0, sz): 54 | zs = min(stride_z * z, dd-patch_size[2]) 55 | test_patch = image[xs:xs+patch_size[0], 56 | ys:ys+patch_size[1], zs:zs+patch_size[2]] 57 | test_patch = np.expand_dims(np.expand_dims( 58 | test_patch, axis=0), axis=0).astype(np.float32) 59 | test_patch = torch.from_numpy(test_patch).cuda() 60 | 61 | with torch.no_grad(): 62 | y1, _, _, _ = net(test_patch) 63 | # ensemble 64 | y = torch.softmax(y1, dim=1) 65 | y = y.cpu().data.numpy() 66 | y = y[0, :, :, :, :] 67 | score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 68 | = score_map[:, xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + y 69 | cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] \ 70 | = cnt[xs:xs+patch_size[0], ys:ys+patch_size[1], zs:zs+patch_size[2]] + 1 71 | score_map = score_map/np.expand_dims(cnt, axis=0) 72 | label_map = np.argmax(score_map, axis=0) 73 | 74 | if add_pad: 75 | label_map = label_map[wl_pad:wl_pad+w, 76 | hl_pad:hl_pad+h, dl_pad:dl_pad+d] 77 | score_map = score_map[:, wl_pad:wl_pad + 78 | w, hl_pad:hl_pad+h, dl_pad:dl_pad+d] 79 | return label_map 80 | 81 | 82 | def cal_metric(gt, pred): 83 | if pred.sum() > 0 and gt.sum() > 0: 84 | dice = metric.binary.dc(pred, gt) 85 | hd95 = metric.binary.hd95(pred, gt) 86 | return np.array([dice, hd95]) 87 | else: 88 | return np.zeros(2) 89 | 90 | 91 | def test_all_case(net, base_dir, test_list="val.list", num_classes=4, patch_size=(48, 160, 160), stride_xy=32, stride_z=24): 92 | with open(base_dir + '/{}'.format(test_list), 'r') as f: 93 | image_list = f.readlines() 94 | image_list = [base_dir + "/data/{}.h5".format( 95 | item.replace('\n', '').split(",")[0]) for item in image_list] 96 | total_metric = np.zeros((num_classes-1, 2)) 97 | print("Validation begin") 98 | for image_path in tqdm(image_list): 99 | h5f = h5py.File(image_path, 'r') 100 | image = h5f['image'][:] 101 | label = h5f['label'][:] 102 | prediction = test_single_case( 103 | net, image, stride_xy, stride_z, patch_size, num_classes=num_classes) 104 | for i in range(1, num_classes): 105 | total_metric[i-1, :] += cal_metric(label == i, prediction == i) 106 | print("Validation end") 107 | return total_metric / 
len(image_list) 108 | -------------------------------------------------------------------------------- /data/ACDC/README.md: -------------------------------------------------------------------------------- 1 | - Download the processed ACDC data from [BaiduDisk](https://pan.baidu.com/s/1d0cFhj3LU029oHajNni8KQ), the password is *code*, and decompress the zip file to [data/ACDC](https://github.com/Luoxd1996/SSL4MIS/edit/master/data/ACDC). More details of this dataset can be found at: https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html. 2 | - If you want to use the [ACDC dataset](https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html) in your paper, please cite the original paper [TMI2018](https://ieeexplore.ieee.org/document/8360453). 3 | -------------------------------------------------------------------------------- /data/ACDC/test.list: -------------------------------------------------------------------------------- 1 | patient011_frame01 2 | patient011_frame02 3 | patient013_frame01 4 | patient013_frame02 5 | patient084_frame01 6 | patient084_frame02 7 | patient033_frame01 8 | patient033_frame02 9 | patient093_frame01 10 | patient093_frame02 11 | patient022_frame01 12 | patient022_frame02 13 | patient068_frame01 14 | patient068_frame02 15 | patient024_frame01 16 | patient024_frame02 17 | patient083_frame01 18 | patient083_frame02 19 | patient081_frame01 20 | patient081_frame02 21 | patient080_frame01 22 | patient080_frame02 23 | patient001_frame01 24 | patient001_frame02 25 | patient007_frame01 26 | patient007_frame02 27 | patient066_frame01 28 | patient066_frame02 29 | patient008_frame01 30 | patient008_frame02 31 | patient065_frame01 32 | patient065_frame02 33 | patient075_frame01 34 | patient075_frame02 35 | patient064_frame01 36 | patient064_frame02 37 | patient059_frame01 38 | patient059_frame02 39 | patient052_frame01 40 | patient052_frame02 41 | -------------------------------------------------------------------------------- /data/ACDC/train.list: -------------------------------------------------------------------------------- 1 | patient099_frame01 2 | patient099_frame02 3 | patient038_frame01 4 | patient038_frame02 5 | patient050_frame01 6 | patient050_frame02 7 | patient100_frame01 8 | patient100_frame02 9 | patient058_frame01 10 | patient058_frame02 11 | patient021_frame01 12 | patient021_frame02 13 | patient049_frame01 14 | patient049_frame02 15 | patient020_frame01 16 | patient020_frame02 17 | patient072_frame01 18 | patient072_frame02 19 | patient040_frame01 20 | patient040_frame02 21 | patient060_frame01 22 | patient060_frame02 23 | patient089_frame01 24 | patient089_frame02 25 | patient004_frame01 26 | patient004_frame02 27 | patient056_frame01 28 | patient056_frame02 29 | patient098_frame01 30 | patient098_frame02 31 | patient096_frame01 32 | patient096_frame02 33 | patient031_frame01 34 | patient031_frame02 35 | patient018_frame01 36 | patient018_frame02 37 | patient094_frame01 38 | patient094_frame02 39 | patient047_frame01 40 | patient047_frame02 41 | patient048_frame01 42 | patient048_frame02 43 | patient055_frame01 44 | patient055_frame02 45 | patient097_frame01 46 | patient097_frame02 47 | patient074_frame01 48 | patient074_frame02 49 | patient043_frame01 50 | patient043_frame02 51 | patient041_frame01 52 | patient041_frame02 53 | patient063_frame01 54 | patient063_frame02 55 | patient037_frame01 56 | patient037_frame02 57 | patient095_frame01 58 | patient095_frame02 59 | patient054_frame01 60 | patient054_frame02 61 | patient026_frame01 62 | 
patient026_frame02 63 | patient088_frame01 64 | patient088_frame02 65 | patient032_frame01 66 | patient032_frame02 67 | patient069_frame01 68 | patient069_frame02 69 | patient006_frame01 70 | patient006_frame02 71 | patient071_frame01 72 | patient071_frame02 73 | patient012_frame01 74 | patient012_frame02 75 | patient073_frame01 76 | patient073_frame02 77 | patient061_frame01 78 | patient061_frame02 79 | patient017_frame01 80 | patient017_frame02 81 | patient025_frame01 82 | patient025_frame02 83 | patient010_frame01 84 | patient010_frame02 85 | patient057_frame01 86 | patient057_frame02 87 | patient029_frame01 88 | patient029_frame02 89 | patient051_frame01 90 | patient051_frame02 91 | patient005_frame01 92 | patient005_frame02 93 | patient036_frame01 94 | patient036_frame02 95 | patient046_frame01 96 | patient046_frame02 97 | patient062_frame01 98 | patient062_frame02 99 | patient034_frame01 100 | patient034_frame02 101 | patient076_frame01 102 | patient076_frame02 103 | patient092_frame01 104 | patient092_frame02 105 | patient070_frame01 106 | patient070_frame02 107 | patient077_frame01 108 | patient077_frame02 109 | patient067_frame01 110 | patient067_frame02 111 | patient003_frame01 112 | patient003_frame02 113 | patient091_frame01 114 | patient091_frame02 115 | patient016_frame01 116 | patient016_frame02 117 | patient014_frame01 118 | patient014_frame02 119 | patient044_frame01 120 | patient044_frame02 121 | patient042_frame01 122 | patient042_frame02 123 | patient090_frame01 124 | patient090_frame02 125 | patient053_frame01 126 | patient053_frame02 127 | patient027_frame01 128 | patient027_frame02 129 | patient035_frame01 130 | patient035_frame02 131 | patient086_frame01 132 | patient086_frame02 133 | patient023_frame01 134 | patient023_frame02 135 | patient009_frame01 136 | patient009_frame02 137 | patient079_frame01 138 | patient079_frame02 139 | patient015_frame01 140 | patient015_frame02 141 | -------------------------------------------------------------------------------- /data/ACDC/val.list: -------------------------------------------------------------------------------- 1 | patient028_frame01 2 | patient028_frame02 3 | patient085_frame01 4 | patient085_frame02 5 | patient082_frame01 6 | patient082_frame02 7 | patient087_frame01 8 | patient087_frame02 9 | patient019_frame01 10 | patient019_frame02 11 | patient030_frame01 12 | patient030_frame02 13 | patient078_frame01 14 | patient078_frame02 15 | patient045_frame01 16 | patient045_frame02 17 | patient002_frame01 18 | patient002_frame02 19 | patient039_frame01 20 | patient039_frame02 21 | -------------------------------------------------------------------------------- /data/BraTS2019/README.md: -------------------------------------------------------------------------------- 1 | - Download the processed BraTS2019 data (we just used the Flair images for whole tumor segmentation) from [BaiduDisk](https://pan.baidu.com/s/1CrMNP8hUExGuQNrHPuGb7w), the password is *code*, and decompress the zip file to [data/BraTS2019](https://github.com/Luoxd1996/SSL4MIS/edit/master/data/BraTS2019). 
2 | -------------------------------------------------------------------------------- /data/BraTS2019/test.txt: -------------------------------------------------------------------------------- 1 | BraTS19_TCIA02_309_1 2 | BraTS19_CBICA_AYA_1 3 | BraTS19_CBICA_AYG_1 4 | BraTS19_CBICA_ANV_1 5 | BraTS19_CBICA_BAN_1 6 | BraTS19_TCIA10_408_1 7 | BraTS19_TCIA01_448_1 8 | BraTS19_TCIA10_261_1 9 | BraTS19_CBICA_ALU_1 10 | BraTS19_CBICA_AWH_1 11 | BraTS19_TCIA10_420_1 12 | BraTS19_2013_11_1 13 | BraTS19_TCIA01_460_1 14 | BraTS19_TCIA10_639_1 15 | BraTS19_TCIA10_130_1 16 | BraTS19_TCIA09_254_1 17 | BraTS19_CBICA_BLJ_1 18 | BraTS19_TCIA01_390_1 19 | BraTS19_2013_19_1 20 | BraTS19_CBICA_BDK_1 21 | BraTS19_TCIA13_630_1 22 | BraTS19_TCIA02_430_1 23 | BraTS19_TCIA06_165_1 24 | BraTS19_CBICA_AOS_1 25 | BraTS19_CBICA_AZH_1 26 | BraTS19_CBICA_BGX_1 27 | BraTS19_TCIA02_455_1 28 | BraTS19_TCIA10_330_1 29 | BraTS19_TMC_12866_1 30 | BraTS19_TCIA02_321_1 31 | BraTS19_TCIA03_257_1 32 | BraTS19_CBICA_AWI_1 33 | BraTS19_CBICA_AQZ_1 34 | BraTS19_TCIA10_152_1 35 | BraTS19_TMC_27374_1 36 | BraTS19_TCIA01_378_1 37 | BraTS19_TCIA08_218_1 38 | BraTS19_CBICA_ASY_1 39 | BraTS19_TCIA02_168_1 40 | BraTS19_CBICA_BJY_1 41 | BraTS19_2013_18_1 42 | BraTS19_TCIA03_121_1 43 | BraTS19_CBICA_BIC_1 44 | BraTS19_TCIA03_133_1 45 | BraTS19_TCIA02_171_1 46 | BraTS19_TMC_06290_1 47 | BraTS19_TCIA10_175_1 48 | BraTS19_CBICA_BGR_1 49 | BraTS19_TCIA10_276_1 50 | BraTS19_TCIA09_402_1 51 | BraTS19_TCIA10_442_1 52 | BraTS19_CBICA_AUN_1 53 | BraTS19_2013_20_1 54 | BraTS19_TCIA10_490_1 55 | BraTS19_TCIA06_409_1 56 | BraTS19_TMC_21360_1 57 | BraTS19_CBICA_AQQ_1 58 | BraTS19_TCIA10_307_1 59 | BraTS19_TCIA04_437_1 60 | BraTS19_CBICA_AQR_1 61 | -------------------------------------------------------------------------------- /data/BraTS2019/train.txt: -------------------------------------------------------------------------------- 1 | BraTS19_TCIA02_370_1 2 | BraTS19_CBICA_ASA_1 3 | BraTS19_TCIA12_470_1 4 | BraTS19_2013_8_1 5 | BraTS19_TCIA01_429_1 6 | BraTS19_TCIA08_234_1 7 | BraTS19_TCIA10_266_1 8 | BraTS19_TCIA13_633_1 9 | BraTS19_TCIA03_199_1 10 | BraTS19_TCIA10_629_1 11 | BraTS19_CBICA_ATB_1 12 | BraTS19_CBICA_BCL_1 13 | BraTS19_TCIA08_469_1 14 | BraTS19_TCIA04_343_1 15 | BraTS19_CBICA_AOO_1 16 | BraTS19_CBICA_ASF_1 17 | BraTS19_TCIA13_618_1 18 | BraTS19_CBICA_AVG_1 19 | BraTS19_TCIA09_255_1 20 | BraTS19_TCIA08_280_1 21 | BraTS19_2013_9_1 22 | BraTS19_TCIA10_202_1 23 | BraTS19_TCIA04_149_1 24 | BraTS19_CBICA_AZD_1 25 | BraTS19_TCIA01_131_1 26 | BraTS19_CBICA_BHZ_1 27 | BraTS19_TCIA02_179_1 28 | BraTS19_CBICA_BEM_1 29 | BraTS19_TCIA02_300_1 30 | BraTS19_CBICA_ARF_1 31 | BraTS19_CBICA_ABY_1 32 | BraTS19_TCIA02_608_1 33 | BraTS19_2013_17_1 34 | BraTS19_CBICA_BHV_1 35 | BraTS19_TCIA02_117_1 36 | BraTS19_TCIA12_249_1 37 | BraTS19_TCIA08_162_1 38 | BraTS19_TCIA03_498_1 39 | BraTS19_TCIA01_235_1 40 | BraTS19_2013_15_1 41 | BraTS19_CBICA_AOP_1 42 | BraTS19_CBICA_AUA_1 43 | BraTS19_CBICA_AAB_1 44 | BraTS19_CBICA_BFB_1 45 | BraTS19_TCIA09_451_1 46 | BraTS19_TCIA02_322_1 47 | BraTS19_CBICA_ATV_1 48 | BraTS19_CBICA_BCF_1 49 | BraTS19_CBICA_AQJ_1 50 | BraTS19_CBICA_AVV_1 51 | BraTS19_CBICA_ASU_1 52 | BraTS19_CBICA_AYW_1 53 | BraTS19_CBICA_AUR_1 54 | BraTS19_CBICA_AYC_1 55 | BraTS19_TCIA02_607_1 56 | BraTS19_TCIA08_436_1 57 | BraTS19_TCIA02_471_1 58 | BraTS19_CBICA_AQG_1 59 | BraTS19_TCIA10_387_1 60 | BraTS19_TCIA02_606_1 61 | BraTS19_CBICA_ASR_1 62 | BraTS19_CBICA_ASO_1 63 | BraTS19_TCIA08_167_1 64 | BraTS19_CBICA_AXN_1 65 | BraTS19_CBICA_ABB_1 66 | 
BraTS19_CBICA_AWG_1 67 | BraTS19_TCIA13_653_1 68 | BraTS19_2013_23_1 69 | BraTS19_2013_14_1 70 | BraTS19_CBICA_APY_1 71 | BraTS19_CBICA_ATX_1 72 | BraTS19_CBICA_ATF_1 73 | BraTS19_CBICA_AQV_1 74 | BraTS19_CBICA_ASW_1 75 | BraTS19_TCIA10_449_1 76 | BraTS19_CBICA_AQO_1 77 | BraTS19_TCIA06_247_1 78 | BraTS19_CBICA_AXQ_1 79 | BraTS19_TCIA02_290_1 80 | BraTS19_CBICA_APZ_1 81 | BraTS19_TCIA10_640_1 82 | BraTS19_TCIA02_118_1 83 | BraTS19_TCIA02_151_1 84 | BraTS19_CBICA_BGE_1 85 | BraTS19_CBICA_AOD_1 86 | BraTS19_TCIA01_401_1 87 | BraTS19_2013_27_1 88 | BraTS19_TCIA01_180_1 89 | BraTS19_TCIA01_231_1 90 | BraTS19_TCIA02_491_1 91 | BraTS19_TCIA05_444_1 92 | BraTS19_CBICA_AQT_1 93 | BraTS19_CBICA_AQU_1 94 | BraTS19_CBICA_AME_1 95 | BraTS19_TCIA12_298_1 96 | BraTS19_CBICA_AMH_1 97 | BraTS19_CBICA_ANI_1 98 | BraTS19_2013_28_1 99 | BraTS19_CBICA_BGN_1 100 | BraTS19_TCIA13_650_1 101 | BraTS19_TCIA13_634_1 102 | BraTS19_CBICA_APK_1 103 | BraTS19_TCIA13_624_1 104 | BraTS19_TCIA10_628_1 105 | BraTS19_TCIA01_221_1 106 | BraTS19_TCIA06_332_1 107 | BraTS19_CBICA_BHK_1 108 | BraTS19_TCIA09_428_1 109 | BraTS19_2013_3_1 110 | BraTS19_CBICA_BAX_1 111 | BraTS19_TCIA10_299_1 112 | BraTS19_TCIA10_310_1 113 | BraTS19_TCIA02_331_1 114 | BraTS19_TCIA03_419_1 115 | BraTS19_CBICA_AAG_1 116 | BraTS19_TCIA12_480_1 117 | BraTS19_CBICA_BGO_1 118 | BraTS19_TCIA01_425_1 119 | BraTS19_TCIA10_637_1 120 | BraTS19_2013_2_1 121 | BraTS19_CBICA_AWX_1 122 | BraTS19_TCIA13_642_1 123 | BraTS19_TCIA04_328_1 124 | BraTS19_CBICA_AWV_1 125 | BraTS19_CBICA_AAL_1 126 | BraTS19_CBICA_AVF_1 127 | BraTS19_TCIA08_406_1 128 | BraTS19_TCIA03_296_1 129 | BraTS19_CBICA_ASK_1 130 | BraTS19_TCIA05_396_1 131 | BraTS19_TCIA02_368_1 132 | BraTS19_CBICA_AUQ_1 133 | BraTS19_2013_21_1 134 | BraTS19_CBICA_BGG_1 135 | BraTS19_TCIA13_654_1 136 | BraTS19_CBICA_AUW_1 137 | BraTS19_CBICA_BHM_1 138 | BraTS19_2013_29_1 139 | BraTS19_CBICA_AOZ_1 140 | BraTS19_2013_0_1 141 | BraTS19_CBICA_ASH_1 142 | BraTS19_CBICA_ANZ_1 143 | BraTS19_2013_26_1 144 | BraTS19_TCIA02_283_1 145 | BraTS19_TCIA02_473_1 146 | BraTS19_TCIA09_141_1 147 | BraTS19_TCIA08_278_1 148 | BraTS19_TCIA03_375_1 149 | BraTS19_2013_10_1 150 | BraTS19_2013_7_1 151 | BraTS19_TCIA12_101_1 152 | BraTS19_TCIA04_192_1 153 | BraTS19_CBICA_ASN_1 154 | BraTS19_TCIA08_242_1 155 | BraTS19_TCIA02_394_1 156 | BraTS19_CBICA_AXW_1 157 | BraTS19_TCIA04_361_1 158 | BraTS19_CBICA_AQY_1 159 | BraTS19_TCIA01_412_1 160 | BraTS19_CBICA_AVT_1 161 | BraTS19_TCIA01_186_1 162 | BraTS19_TCIA02_377_1 163 | BraTS19_TCIA01_411_1 164 | BraTS19_TCIA02_198_1 165 | BraTS19_CBICA_ASE_1 166 | BraTS19_TCIA10_413_1 167 | BraTS19_CBICA_APR_1 168 | BraTS19_CBICA_ALN_1 169 | BraTS19_TCIA10_393_1 170 | BraTS19_TMC_09043_1 171 | BraTS19_2013_5_1 172 | BraTS19_TCIA01_201_1 173 | BraTS19_CBICA_ABM_1 174 | BraTS19_2013_4_1 175 | BraTS19_TCIA13_621_1 176 | BraTS19_TCIA09_620_1 177 | BraTS19_TCIA10_325_1 178 | BraTS19_CBICA_ANP_1 179 | BraTS19_TCIA01_335_1 180 | BraTS19_CBICA_AXM_1 181 | BraTS19_TCIA08_105_1 182 | BraTS19_TCIA02_226_1 183 | BraTS19_CBICA_ANG_1 184 | BraTS19_TCIA03_338_1 185 | BraTS19_TCIA08_113_1 186 | BraTS19_TCIA02_222_1 187 | BraTS19_TCIA10_625_1 188 | BraTS19_TCIA09_493_1 189 | BraTS19_CBICA_AOH_1 190 | BraTS19_CBICA_BAP_1 191 | BraTS19_TCIA02_605_1 192 | BraTS19_TCIA01_190_1 193 | BraTS19_TCIA10_109_1 194 | BraTS19_CBICA_AUX_1 195 | BraTS19_2013_13_1 196 | BraTS19_TMC_06643_1 197 | BraTS19_TCIA10_241_1 198 | BraTS19_CBICA_AXL_1 199 | BraTS19_CBICA_AXO_1 200 | BraTS19_TCIA10_632_1 201 | BraTS19_TCIA13_645_1 202 | 
BraTS19_TCIA02_314_1 203 | BraTS19_TCIA05_277_1 204 | BraTS19_TCIA03_474_1 205 | BraTS19_TCIA02_374_1 206 | BraTS19_CBICA_AXJ_1 207 | BraTS19_TCIA10_351_1 208 | BraTS19_CBICA_AQD_1 209 | BraTS19_TCIA09_462_1 210 | BraTS19_CBICA_AYI_1 211 | BraTS19_2013_25_1 212 | BraTS19_CBICA_AVJ_1 213 | BraTS19_2013_12_1 214 | BraTS19_TCIA06_372_1 215 | BraTS19_2013_24_1 216 | BraTS19_TCIA04_479_1 217 | BraTS19_CBICA_ASV_1 218 | BraTS19_CBICA_ABN_1 219 | BraTS19_CBICA_ATN_1 220 | BraTS19_CBICA_ABE_1 221 | BraTS19_TCIA08_319_1 222 | BraTS19_CBICA_AQP_1 223 | BraTS19_TCIA03_138_1 224 | BraTS19_TCIA10_410_1 225 | BraTS19_CBICA_BGW_1 226 | BraTS19_CBICA_BNR_1 227 | BraTS19_CBICA_BHB_1 228 | BraTS19_CBICA_BFP_1 229 | BraTS19_CBICA_BGT_1 230 | BraTS19_CBICA_AYU_1 231 | BraTS19_CBICA_ATD_1 232 | BraTS19_CBICA_ATP_1 233 | BraTS19_CBICA_BBG_1 234 | BraTS19_CBICA_ARZ_1 235 | BraTS19_TCIA06_211_1 236 | BraTS19_2013_6_1 237 | BraTS19_CBICA_ALX_1 238 | BraTS19_TCIA05_478_1 239 | BraTS19_TCIA06_603_1 240 | BraTS19_TCIA03_265_1 241 | BraTS19_TCIA13_623_1 242 | BraTS19_TCIA04_111_1 243 | BraTS19_CBICA_ASG_1 244 | BraTS19_TCIA01_499_1 245 | BraTS19_TCIA12_466_1 246 | BraTS19_TCIA01_203_1 247 | BraTS19_2013_22_1 248 | BraTS19_TCIA02_274_1 249 | BraTS19_TCIA01_150_1 250 | BraTS19_TCIA10_346_1 251 | -------------------------------------------------------------------------------- /data/BraTS2019/val.txt: -------------------------------------------------------------------------------- 1 | BraTS19_CBICA_AAP_1 2 | BraTS19_TCIA10_282_1 3 | BraTS19_TCIA01_147_1 4 | BraTS19_TCIA13_615_1 5 | BraTS19_CBICA_AQA_1 6 | BraTS19_TCIA08_205_1 7 | BraTS19_TCIA09_177_1 8 | BraTS19_TCIA02_208_1 9 | BraTS19_CBICA_AQN_1 10 | BraTS19_2013_16_1 11 | BraTS19_CBICA_BHQ_1 12 | BraTS19_TCIA02_135_1 13 | BraTS19_CBICA_AVB_1 14 | BraTS19_TCIA10_644_1 15 | BraTS19_CBICA_BKV_1 16 | BraTS19_TMC_15477_1 17 | BraTS19_2013_1_1 18 | BraTS19_CBICA_ARW_1 19 | BraTS19_TMC_11964_1 20 | BraTS19_TCIA09_312_1 21 | BraTS19_TCIA06_184_1 22 | BraTS19_TMC_30014_1 23 | BraTS19_TCIA10_103_1 24 | BraTS19_CBICA_AOC_1 25 | BraTS19_CBICA_ABO_1 26 | --------------------------------------------------------------------------------
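A small, hedged sketch of how these split files are consumed: each line names a case id, and the loaders and validators above build HDF5 paths from it (test_all_case reads base_dir + "/data/<case>.h5"). The data root below is an illustrative path, assuming the processed data has been downloaded as described in the READMEs.

base_dir = "../data/BraTS2019"  # illustrative data root
with open(base_dir + "/val.txt") as f:
    case_ids = [line.strip() for line in f if line.strip()]

h5_paths = ["{}/data/{}.h5".format(base_dir, case) for case in case_ids]
print(len(case_ids), "validation cases, first:", h5_paths[0])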