├── .dockerignore
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── assets
│   ├── TCGA_CS_4944.png
│   ├── TCGA_CS_4944_20010208.gif
│   ├── TCGA_DU_6404_19850629.gif
│   ├── TCGA_HT_7879_19981009.gif
│   ├── brain-mri-lgg.png
│   ├── dsc.png
│   └── unet.png
├── dataset.py
├── hubconf.py
├── inference.py
├── logger.py
├── loss.py
├── requirements.txt
├── train.py
├── transform.py
├── unet.py
├── utils.py
└── weights
    └── unet.pt

/.dockerignore:
--------------------------------------------------------------------------------
1 | kaggle_3m
2 | weights
3 | logs
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | kaggle_3m
2 | logs
3 | predictions
4 |
5 | # Created by https://www.gitignore.io/api/osx,linux,matlab,python,pycharm+all,intellij+all,jupyternotebook
6 |
7 | ### Intellij+all ###
8 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
9 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
10 |
11 | # User-specific stuff:
12 | .idea/**/workspace.xml
13 | .idea/**/tasks.xml
14 | .idea/dictionaries
15 |
16 | # Sensitive or high-churn files:
17 | .idea/**/dataSources/
18 | .idea/**/dataSources.ids
19 | .idea/**/dataSources.xml
20 | .idea/**/dataSources.local.xml
21 | .idea/**/sqlDataSources.xml
22 | .idea/**/dynamic.xml
23 | .idea/**/uiDesigner.xml
24 |
25 | # Gradle:
26 | .idea/**/gradle.xml
27 | .idea/**/libraries
28 |
29 | # CMake
30 | cmake-build-debug/
31 |
32 | # Mongo Explorer plugin:
33 | .idea/**/mongoSettings.xml
34 |
35 | ## File-based project format:
36 | *.iws
37 |
38 | ## Plugin-specific files:
39 |
40 | # IntelliJ
41 | /out/
42 |
43 | # mpeltonen/sbt-idea plugin
44 | .idea_modules/
45 |
46 | # JIRA plugin
47 | atlassian-ide-plugin.xml
48 |
49 | # Cursive Clojure plugin
50 | .idea/replstate.xml
51 |
52 | # Ruby plugin and RubyMine
53 | /.rakeTasks
54 |
55 | # Crashlytics plugin (for Android Studio and IntelliJ)
56 | com_crashlytics_export_strings.xml
57 | crashlytics.properties
58 | crashlytics-build.properties
59 | fabric.properties
60 |
61 | ### Intellij+all Patch ###
62 | # Ignores the whole idea folder
63 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360
64 |
65 | .idea/
66 |
67 | ### JupyterNotebook ###
68 | .ipynb_checkpoints
69 | */.ipynb_checkpoints/*
70 |
71 | # Remove previous ipynb_checkpoints
72 | # git rm -r .ipynb_checkpoints/
73 | #
74 | ### Linux ###
75 | *~
76 |
77 | # temporary files which can be created if a process still has a handle open of a deleted file
78 | .fuse_hidden*
79 |
80 | # KDE directory preferences
81 | .directory
82 |
83 | # Linux trash folder which might appear on any partition or disk
84 | .Trash-*
85 |
86 | # .nfs files are created when an open file is removed but is still being accessed
87 | .nfs*
88 |
89 | ### Matlab ###
90 | ##---------------------------------------------------
91 | ## Remove autosaves generated by the Matlab editor
92 | ## We have git for backups!
93 | ##--------------------------------------------------- 94 | 95 | # Windows default autosave extension 96 | *.asv 97 | 98 | # OSX / *nix default autosave extension 99 | *.m~ 100 | 101 | # Compiled MEX binaries (all platforms) 102 | *.mex* 103 | 104 | # Simulink Code Generation 105 | slprj/ 106 | 107 | # Session info 108 | octave-workspace 109 | 110 | # Simulink autosave extension 111 | *.autosave 112 | 113 | ### OSX ### 114 | *.DS_Store 115 | .AppleDouble 116 | .LSOverride 117 | 118 | # Icon must end with two \r 119 | Icon 120 | 121 | # Thumbnails 122 | ._* 123 | 124 | # Files that might appear in the root of a volume 125 | .DocumentRevisions-V100 126 | .fseventsd 127 | .Spotlight-V100 128 | .TemporaryItems 129 | .Trashes 130 | .VolumeIcon.icns 131 | .com.apple.timemachine.donotpresent 132 | 133 | # Directories potentially created on remote AFP share 134 | .AppleDB 135 | .AppleDesktop 136 | Network Trash Folder 137 | Temporary Items 138 | .apdisk 139 | 140 | ### PyCharm+all ### 141 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 142 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 143 | 144 | # User-specific stuff: 145 | 146 | # Sensitive or high-churn files: 147 | 148 | # Gradle: 149 | 150 | # CMake 151 | 152 | # Mongo Explorer plugin: 153 | 154 | ## File-based project format: 155 | 156 | ## Plugin-specific files: 157 | 158 | # IntelliJ 159 | 160 | # mpeltonen/sbt-idea plugin 161 | 162 | # JIRA plugin 163 | 164 | # Cursive Clojure plugin 165 | 166 | # Ruby plugin and RubyMine 167 | 168 | # Crashlytics plugin (for Android Studio and IntelliJ) 169 | 170 | ### PyCharm+all Patch ### 171 | # Ignores the whole idea folder 172 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 173 | 174 | 175 | ### Python ### 176 | # Byte-compiled / optimized / DLL files 177 | __pycache__/ 178 | *.py[cod] 179 | *$py.class 180 | 181 | # C extensions 182 | *.so 183 | 184 | # Distribution / packaging 185 | .Python 186 | build/ 187 | develop-eggs/ 188 | dist/ 189 | downloads/ 190 | eggs/ 191 | .eggs/ 192 | lib/ 193 | lib64/ 194 | parts/ 195 | sdist/ 196 | var/ 197 | wheels/ 198 | *.egg-info/ 199 | .installed.cfg 200 | *.egg 201 | 202 | # PyInstaller 203 | # Usually these files are written by a python script from a template 204 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
205 | *.manifest 206 | *.spec 207 | 208 | # Installer logs 209 | pip-log.txt 210 | pip-delete-this-directory.txt 211 | 212 | # Unit test / coverage reports 213 | htmlcov/ 214 | .tox/ 215 | .coverage 216 | .coverage.* 217 | .cache 218 | .pytest_cache/ 219 | nosetests.xml 220 | coverage.xml 221 | *.cover 222 | .hypothesis/ 223 | 224 | # Translations 225 | *.mo 226 | *.pot 227 | 228 | # Flask stuff: 229 | instance/ 230 | .webassets-cache 231 | 232 | # Scrapy stuff: 233 | .scrapy 234 | 235 | # Sphinx documentation 236 | docs/_build/ 237 | 238 | # PyBuilder 239 | target/ 240 | 241 | # Jupyter Notebook 242 | 243 | # pyenv 244 | .python-version 245 | 246 | # celery beat schedule file 247 | celerybeat-schedule.* 248 | 249 | # SageMath parsed files 250 | *.sage.py 251 | 252 | # Environments 253 | .env 254 | .venv 255 | env/ 256 | venv/ 257 | ENV/ 258 | env.bak/ 259 | venv.bak/ 260 | 261 | # Spyder project settings 262 | .spyderproject 263 | .spyproject 264 | 265 | # Rope project settings 266 | .ropeproject 267 | 268 | # mkdocs documentation 269 | /site 270 | 271 | # mypy 272 | .mypy_cache/ 273 | 274 | 275 | # End of https://www.gitignore.io/api/osx,linux,matlab,python,pycharm+all,intellij+all,jupyternotebook 276 | 277 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 2 | 3 | RUN apt-get update && apt-get install -y --no-install-recommends \ 4 | ca-certificates \ 5 | curl \ 6 | netbase \ 7 | wget \ 8 | git \ 9 | openssh-client \ 10 | ssh \ 11 | vim \ 12 | && rm -rf /var/lib/apt/lists/* 13 | 14 | # http://bugs.python.org/issue19846 15 | ENV LANG C.UTF-8 16 | # https://github.com/docker-library/python/issues/147 17 | ENV PYTHONIOENCODING UTF-8 18 | 19 | RUN apt-get update && apt-get install -y --no-install-recommends \ 20 | python3.6 \ 21 | python3.6-dev \ 22 | python3-pip \ 23 | python3-setuptools \ 24 | && rm -rf /var/lib/apt/lists/* 25 | 26 | RUN pip3 install --upgrade pip 27 | 28 | RUN pip3 install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp36-cp36m-linux_x86_64.whl \ 29 | && pip3 install https://download.pytorch.org/whl/cu100/torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl 30 | 31 | WORKDIR /workspace 32 | 33 | COPY requirements.txt ./ 34 | 35 | RUN pip3 install --no-cache-dir -r requirements.txt 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 mateuszbuda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # U-Net for brain segmentation
2 |
3 | U-Net implementation in PyTorch for FLAIR abnormality segmentation in brain MRI, based on the deep learning segmentation algorithm used in [Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm](https://doi.org/10.1016/j.compbiomed.2019.05.002).
4 |
5 | This repository is an all-Python port of the official MATLAB/Keras implementation in [brain-segmentation](https://github.com/mateuszbuda/brain-segmentation).
6 | Weights for trained models are provided and can be used for inference or fine-tuning on a different dataset.
7 | If you use the code or weights shared in this repository, please consider citing:
8 |
9 | ```
10 | @article{buda2019association,
11 |   title={Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm},
12 |   author={Buda, Mateusz and Saha, Ashirbani and Mazurowski, Maciej A},
13 |   journal={Computers in Biology and Medicine},
14 |   volume={109},
15 |   year={2019},
16 |   publisher={Elsevier},
17 |   doi={10.1016/j.compbiomed.2019.05.002}
18 | }
19 | ```
20 |
21 | ## docker
22 |
23 | ```
24 | docker build -t brainseg .
25 | ```
26 |
27 | ```
28 | nvidia-docker run --rm --shm-size 8G -it -v `pwd`:/workspace brainseg
29 | ```
30 |
31 | ## PyTorch Hub
32 |
33 | Loading the model using PyTorch Hub: [pytorch.org/hub/mateuszbuda\_brain-segmentation-pytorch\_unet](https://pytorch.org/hub/mateuszbuda_brain-segmentation-pytorch_unet/)
34 |
35 | ```python
36 | import torch
37 | model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
38 |     in_channels=3, out_channels=1, init_features=32, pretrained=True)
39 | ```
40 |
41 | ## data
42 |
43 | ![dataset](./assets/brain-mri-lgg.png)
44 |
45 | The dataset used for development and evaluation was made publicly available on Kaggle: [kaggle.com/mateuszbuda/lgg-mri-segmentation](https://www.kaggle.com/mateuszbuda/lgg-mri-segmentation).
46 | It contains MR images from the [TCIA LGG collection](https://wiki.cancerimagingarchive.net/display/Public/TCGA-LGG) with segmentation masks approved by a board-certified radiologist at Duke University.
47 |
48 | ## model
49 |
50 | The segmentation model implemented in this repository is a U-Net, as described in [Association of genomic subtypes of lower-grade gliomas with shape features automatically extracted by a deep learning algorithm](https://doi.org/10.1016/j.compbiomed.2019.05.002), with added batch normalization.
51 |
52 | ![unet](./assets/unet.png)
53 |
54 | ## results
55 |
56 | |![TCGA_DU_6404_19850629](./assets/TCGA_DU_6404_19850629.gif)|![TCGA_HT_7879_19981009](./assets/TCGA_HT_7879_19981009.gif)|![TCGA_CS_4944_20010208](./assets/TCGA_CS_4944_20010208.gif)|
57 | |:-------:|:-------:|:-------:|
58 | | 94% DSC | 91% DSC | 89% DSC |
59 |
60 | Qualitative results for validation cases from three different institutions, with DSC of 94%, 91%, and 89%.
61 | Green outlines correspond to ground truth and red to model predictions.
62 | Images show the FLAIR modality after preprocessing.
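
For reference, the reported DSC is the Dice similarity coefficient computed per patient volume, 2·|P∩T| / (|P| + |T|) for a binary prediction P and ground truth T. Below is a minimal sketch consistent with the `dsc` function in `utils.py` (the repository's version additionally keeps only the largest connected component of the prediction before scoring; the function name here is illustrative):

```python
import numpy as np

def dice_coefficient(y_pred, y_true):
    # binarize soft predictions and ground truth, then measure overlap
    y_pred = np.round(y_pred).astype(int)
    y_true = np.round(y_true).astype(int)
    intersection = np.sum(y_pred[y_true == 1])
    return 2.0 * intersection / (np.sum(y_pred) + np.sum(y_true))
```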
63 |
64 | ![dsc](./assets/dsc.png)
65 |
66 | Distribution of DSC for 10 randomly selected validation cases.
67 | The red vertical line corresponds to the mean DSC (91%) and the green one to the median DSC (92%).
68 | Results may be biased since model selection was based on the mean DSC on these validation cases.
69 |
70 | ## inference
71 |
72 | 1. Download and extract the dataset from [Kaggle](https://www.kaggle.com/mateuszbuda/lgg-mri-segmentation).
73 | 2. Run the docker container.
74 | 3. Run the `inference.py` script, specifying the paths to the weights and images. Trained weights for input images of size 256x256 are provided in the `./weights/unet.pt` file. For more options and help run: `python3 inference.py --help`.
75 |
76 | ## train
77 |
78 | 1. Download and extract the dataset from [Kaggle](https://www.kaggle.com/mateuszbuda/lgg-mri-segmentation).
79 | 2. Run the docker container.
80 | 3. Run the `train.py` script. The default path to images is `./kaggle_3m`. For more options and help run: `python3 train.py --help`.
81 |
82 | Training can also be run using the Kaggle kernel shared together with the dataset: [kaggle.com/mateuszbuda/brain-segmentation-pytorch](https://www.kaggle.com/mateuszbuda/brain-segmentation-pytorch).
83 | Due to memory limitations for Kaggle kernels, input images are of size 224x224 instead of 256x256.
84 |
85 | Running this code on a custom dataset would likely require adjustments in `dataset.py`.
86 | Should you need help with this, just open an issue.
87 |
88 | ## TensorRT inference
89 |
90 | If you want to run model inference with the TensorRT runtime, here is a blog post from NVIDIA that covers it: [Speeding Up Deep Learning Inference Using TensorRT](https://developer.nvidia.com/blog/speeding-up-deep-learning-inference-using-tensorrt/).
91 |
--------------------------------------------------------------------------------
/assets/TCGA_CS_4944.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateuszbuda/brain-segmentation-pytorch/d45f8908ab2f0246ba204c702a6161c9eb25f902/assets/TCGA_CS_4944.png
--------------------------------------------------------------------------------
/assets/TCGA_CS_4944_20010208.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateuszbuda/brain-segmentation-pytorch/d45f8908ab2f0246ba204c702a6161c9eb25f902/assets/TCGA_CS_4944_20010208.gif
--------------------------------------------------------------------------------
/assets/TCGA_DU_6404_19850629.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateuszbuda/brain-segmentation-pytorch/d45f8908ab2f0246ba204c702a6161c9eb25f902/assets/TCGA_DU_6404_19850629.gif
--------------------------------------------------------------------------------
/assets/TCGA_HT_7879_19981009.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateuszbuda/brain-segmentation-pytorch/d45f8908ab2f0246ba204c702a6161c9eb25f902/assets/TCGA_HT_7879_19981009.gif
--------------------------------------------------------------------------------
/assets/brain-mri-lgg.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mateuszbuda/brain-segmentation-pytorch/d45f8908ab2f0246ba204c702a6161c9eb25f902/assets/brain-mri-lgg.png
--------------------------------------------------------------------------------
/assets/dsc.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateuszbuda/brain-segmentation-pytorch/d45f8908ab2f0246ba204c702a6161c9eb25f902/assets/dsc.png -------------------------------------------------------------------------------- /assets/unet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateuszbuda/brain-segmentation-pytorch/d45f8908ab2f0246ba204c702a6161c9eb25f902/assets/unet.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from skimage.io import imread 7 | from torch.utils.data import Dataset 8 | 9 | from utils import crop_sample, pad_sample, resize_sample, normalize_volume 10 | 11 | 12 | class BrainSegmentationDataset(Dataset): 13 | """Brain MRI dataset for FLAIR abnormality segmentation""" 14 | 15 | in_channels = 3 16 | out_channels = 1 17 | 18 | def __init__( 19 | self, 20 | images_dir, 21 | transform=None, 22 | image_size=256, 23 | subset="train", 24 | random_sampling=True, 25 | validation_cases=10, 26 | seed=42, 27 | ): 28 | assert subset in ["all", "train", "validation"] 29 | 30 | # read images 31 | volumes = {} 32 | masks = {} 33 | print("reading {} images...".format(subset)) 34 | for (dirpath, dirnames, filenames) in os.walk(images_dir): 35 | image_slices = [] 36 | mask_slices = [] 37 | for filename in sorted( 38 | filter(lambda f: ".tif" in f, filenames), 39 | key=lambda x: int(x.split(".")[-2].split("_")[4]), 40 | ): 41 | filepath = os.path.join(dirpath, filename) 42 | if "mask" in filename: 43 | mask_slices.append(imread(filepath, as_gray=True)) 44 | else: 45 | image_slices.append(imread(filepath)) 46 | if len(image_slices) > 0: 47 | patient_id = dirpath.split("/")[-1] 48 | volumes[patient_id] = np.array(image_slices[1:-1]) 49 | masks[patient_id] = np.array(mask_slices[1:-1]) 50 | 51 | self.patients = sorted(volumes) 52 | 53 | # select cases to subset 54 | if not subset == "all": 55 | random.seed(seed) 56 | validation_patients = random.sample(self.patients, k=validation_cases) 57 | if subset == "validation": 58 | self.patients = validation_patients 59 | else: 60 | self.patients = sorted( 61 | list(set(self.patients).difference(validation_patients)) 62 | ) 63 | 64 | print("preprocessing {} volumes...".format(subset)) 65 | # create list of tuples (volume, mask) 66 | self.volumes = [(volumes[k], masks[k]) for k in self.patients] 67 | 68 | print("cropping {} volumes...".format(subset)) 69 | # crop to smallest enclosing volume 70 | self.volumes = [crop_sample(v) for v in self.volumes] 71 | 72 | print("padding {} volumes...".format(subset)) 73 | # pad to square 74 | self.volumes = [pad_sample(v) for v in self.volumes] 75 | 76 | print("resizing {} volumes...".format(subset)) 77 | # resize 78 | self.volumes = [resize_sample(v, size=image_size) for v in self.volumes] 79 | 80 | print("normalizing {} volumes...".format(subset)) 81 | # normalize channel-wise 82 | self.volumes = [(normalize_volume(v), m) for v, m in self.volumes] 83 | 84 | # probabilities for sampling slices based on masks 85 | self.slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in self.volumes] 86 | self.slice_weights = [ 87 | (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in self.slice_weights 88 | ] 89 | 90 | # add channel dimension to masks 91 | self.volumes = [(v, 
m[..., np.newaxis]) for (v, m) in self.volumes] 92 | 93 | print("done creating {} dataset".format(subset)) 94 | 95 | # create global index for patient and slice (idx -> (p_idx, s_idx)) 96 | num_slices = [v.shape[0] for v, m in self.volumes] 97 | self.patient_slice_index = list( 98 | zip( 99 | sum([[i] * num_slices[i] for i in range(len(num_slices))], []), 100 | sum([list(range(x)) for x in num_slices], []), 101 | ) 102 | ) 103 | 104 | self.random_sampling = random_sampling 105 | 106 | self.transform = transform 107 | 108 | def __len__(self): 109 | return len(self.patient_slice_index) 110 | 111 | def __getitem__(self, idx): 112 | patient = self.patient_slice_index[idx][0] 113 | slice_n = self.patient_slice_index[idx][1] 114 | 115 | if self.random_sampling: 116 | patient = np.random.randint(len(self.volumes)) 117 | slice_n = np.random.choice( 118 | range(self.volumes[patient][0].shape[0]), p=self.slice_weights[patient] 119 | ) 120 | 121 | v, m = self.volumes[patient] 122 | image = v[slice_n] 123 | mask = m[slice_n] 124 | 125 | if self.transform is not None: 126 | image, mask = self.transform((image, mask)) 127 | 128 | # fix dimensions (C, H, W) 129 | image = image.transpose(2, 0, 1) 130 | mask = mask.transpose(2, 0, 1) 131 | 132 | image_tensor = torch.from_numpy(image.astype(np.float32)) 133 | mask_tensor = torch.from_numpy(mask.astype(np.float32)) 134 | 135 | # return tensors 136 | return image_tensor, mask_tensor 137 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ["torch"] 2 | 3 | import torch 4 | 5 | from unet import UNet 6 | 7 | 8 | def unet(pretrained=False, **kwargs): 9 | """ 10 | U-Net segmentation model with batch normalization for biomedical image segmentation 11 | pretrained (bool): load pretrained weights into the model 12 | in_channels (int): number of input channels 13 | out_channels (int): number of output channels 14 | init_features (int): number of feature-maps in the first encoder layer 15 | """ 16 | model = UNet(**kwargs) 17 | 18 | if pretrained: 19 | checkpoint = "https://github.com/mateuszbuda/brain-segmentation-pytorch/releases/download/v1.0/unet-e012d006.pt" 20 | state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=False, map_location='cpu') 21 | model.load_state_dict(state_dict) 22 | 23 | return model 24 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from matplotlib import pyplot as plt 7 | from matplotlib.backends.backend_agg import FigureCanvasAgg 8 | from medpy.filter.binary import largest_connected_component 9 | from skimage.io import imsave 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from dataset import BrainSegmentationDataset as Dataset 14 | from unet import UNet 15 | from utils import dsc, gray2rgb, outline 16 | 17 | 18 | def main(args): 19 | makedirs(args) 20 | device = torch.device("cpu" if not torch.cuda.is_available() else args.device) 21 | 22 | loader = data_loader(args) 23 | 24 | with torch.set_grad_enabled(False): 25 | unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels) 26 | state_dict = torch.load(args.weights, map_location=device) 27 | unet.load_state_dict(state_dict) 28 | unet.eval() 29 | unet.to(device) 30 | 31 | 
        input_list = []
32 |         pred_list = []
33 |         true_list = []
34 |
35 |         for i, data in tqdm(enumerate(loader), total=len(loader)):
36 |             x, y_true = data
37 |             x, y_true = x.to(device), y_true.to(device)
38 |
39 |             y_pred = unet(x)
40 |             y_pred_np = y_pred.detach().cpu().numpy()
41 |             pred_list.extend([y_pred_np[s] for s in range(y_pred_np.shape[0])])
42 |
43 |             y_true_np = y_true.detach().cpu().numpy()
44 |             true_list.extend([y_true_np[s] for s in range(y_true_np.shape[0])])
45 |
46 |             x_np = x.detach().cpu().numpy()
47 |             input_list.extend([x_np[s] for s in range(x_np.shape[0])])
48 |
49 |     volumes = postprocess_per_volume(
50 |         input_list,
51 |         pred_list,
52 |         true_list,
53 |         loader.dataset.patient_slice_index,
54 |         loader.dataset.patients,
55 |     )
56 |
57 |     dsc_dist = dsc_distribution(volumes)
58 |
59 |     dsc_dist_plot = plot_dsc(dsc_dist)
60 |     imsave(args.figure, dsc_dist_plot)
61 |
62 |     for p in volumes:
63 |         x = volumes[p][0]
64 |         y_pred = volumes[p][1]
65 |         y_true = volumes[p][2]
66 |         for s in range(x.shape[0]):
67 |             image = gray2rgb(x[s, 1])  # channel 1 is for FLAIR
68 |             image = outline(image, y_pred[s, 0], color=[255, 0, 0])
69 |             image = outline(image, y_true[s, 0], color=[0, 255, 0])
70 |             filename = "{}-{}.png".format(p, str(s).zfill(2))
71 |             filepath = os.path.join(args.predictions, filename)
72 |             imsave(filepath, image)
73 |
74 |
75 | def data_loader(args):
76 |     dataset = Dataset(
77 |         images_dir=args.images,
78 |         subset="validation",
79 |         image_size=args.image_size,
80 |         random_sampling=False,
81 |     )
82 |     loader = DataLoader(
83 |         dataset, batch_size=args.batch_size, drop_last=False, num_workers=1
84 |     )
85 |     return loader
86 |
87 |
88 | def postprocess_per_volume(
89 |     input_list, pred_list, true_list, patient_slice_index, patients
90 | ):
91 |     volumes = {}
92 |     num_slices = np.bincount([p[0] for p in patient_slice_index])
93 |     index = 0
94 |     for p in range(len(num_slices)):
95 |         volume_in = np.array(input_list[index : index + num_slices[p]])
96 |         volume_pred = np.round(
97 |             np.array(pred_list[index : index + num_slices[p]])
98 |         ).astype(int)  # binarize soft predictions
99 |         volume_pred = largest_connected_component(volume_pred)
100 |         volume_true = np.array(true_list[index : index + num_slices[p]])
101 |         volumes[patients[p]] = (volume_in, volume_pred, volume_true)
102 |         index += num_slices[p]
103 |     return volumes
104 |
105 |
106 | def dsc_distribution(volumes):
107 |     dsc_dict = {}
108 |     for p in volumes:
109 |         y_pred = volumes[p][1]
110 |         y_true = volumes[p][2]
111 |         dsc_dict[p] = dsc(y_pred, y_true, lcc=False)
112 |     return dsc_dict
113 |
114 |
115 | def plot_dsc(dsc_dist):
116 |     y_positions = np.arange(len(dsc_dist))
117 |     dsc_dist = sorted(dsc_dist.items(), key=lambda x: x[1])
118 |     values = [x[1] for x in dsc_dist]
119 |     labels = [x[0] for x in dsc_dist]
120 |     labels = ["_".join(l.split("_")[1:-1]) for l in labels]
121 |     fig = plt.figure(figsize=(12, 8))
122 |     canvas = FigureCanvasAgg(fig)
123 |     plt.barh(y_positions, values, align="center", color="skyblue")
124 |     plt.yticks(y_positions, labels)
125 |     plt.xticks(np.arange(0.0, 1.0, 0.1))
126 |     plt.xlim([0.0, 1.0])
127 |     plt.gca().axvline(np.mean(values), color="tomato", linewidth=2)
128 |     plt.gca().axvline(np.median(values), color="forestgreen", linewidth=2)
129 |     plt.xlabel("Dice coefficient", fontsize="x-large")
130 |     plt.gca().xaxis.grid(color="silver", alpha=0.5, linestyle="--", linewidth=1)
131 |     plt.tight_layout()
132 |     canvas.draw()
133 |     plt.close()
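    # render the figure off-screen with the Agg canvas and return it as an RGBA numpy array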
    s, (width, height) = canvas.print_to_buffer()
135 |     return np.frombuffer(s, np.uint8).reshape((height, width, 4))
136 |
137 |
138 | def makedirs(args):
139 |     os.makedirs(args.predictions, exist_ok=True)
140 |
141 |
142 | if __name__ == "__main__":
143 |     parser = argparse.ArgumentParser(
144 |         description="Inference for segmentation of brain MRI"
145 |     )
146 |     parser.add_argument(
147 |         "--device",
148 |         type=str,
149 |         default="cuda:0",
150 |         help="device for inference (default: cuda:0)",
151 |     )
152 |     parser.add_argument(
153 |         "--batch-size",
154 |         type=int,
155 |         default=32,
156 |         help="input batch size for inference (default: 32)",
157 |     )
158 |     parser.add_argument(
159 |         "--weights", type=str, required=True, help="path to weights file"
160 |     )
161 |     parser.add_argument(
162 |         "--images", type=str, default="./kaggle_3m", help="root folder with images"
163 |     )
164 |     parser.add_argument(
165 |         "--image-size",
166 |         type=int,
167 |         default=256,
168 |         help="target input image size (default: 256)",
169 |     )
170 |     parser.add_argument(
171 |         "--predictions",
172 |         type=str,
173 |         default="./predictions",
174 |         help="folder for saving images with prediction outlines",
175 |     )
176 |     parser.add_argument(
177 |         "--figure",
178 |         type=str,
179 |         default="./dsc.png",
180 |         help="filename for DSC distribution figure",
181 |     )
182 |
183 |     args = parser.parse_args()
184 |     main(args)
185 |
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import scipy.misc  # requires scipy<=1.2, where scipy.misc.toimage still exists
4 | import tensorflow as tf
5 |
6 |
7 | class Logger(object):
8 |
9 |     def __init__(self, log_dir):
10 |         self.writer = tf.summary.FileWriter(log_dir)
11 |
12 |     def scalar_summary(self, tag, value, step):
13 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
14 |         self.writer.add_summary(summary, step)
15 |         self.writer.flush()
16 |
17 |     def image_summary(self, tag, image, step):
18 |         s = BytesIO()
19 |         scipy.misc.toimage(image).save(s, format="png")
20 |
21 |         # Create an Image object
22 |         img_sum = tf.Summary.Image(
23 |             encoded_image_string=s.getvalue(),
24 |             height=image.shape[0],
25 |             width=image.shape[1],
26 |         )
27 |
28 |         # Create and write Summary
29 |         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, image=img_sum)])
30 |         self.writer.add_summary(summary, step)
31 |         self.writer.flush()
32 |
33 |     def image_list_summary(self, tag, images, step):
34 |         if len(images) == 0:
35 |             return
36 |         img_summaries = []
37 |         for i, img in enumerate(images):
38 |             s = BytesIO()
39 |             scipy.misc.toimage(img).save(s, format="png")
40 |
41 |             # Create an Image object
42 |             img_sum = tf.Summary.Image(
43 |                 encoded_image_string=s.getvalue(),
44 |                 height=img.shape[0],
45 |                 width=img.shape[1],
46 |             )
47 |
48 |             # Create a Summary value
49 |             img_summaries.append(
50 |                 tf.Summary.Value(tag="{}/{}".format(tag, i), image=img_sum)
51 |             )
52 |
53 |         # Create and write Summary
54 |         summary = tf.Summary(value=img_summaries)
55 |         self.writer.add_summary(summary, step)
56 |         self.writer.flush()
57 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class DiceLoss(nn.Module):
5 |
6 |     def __init__(self):
7 |         super(DiceLoss, self).__init__()
8 |         self.smooth = 1.0
9 |
10 |     def forward(self, y_pred, y_true):
11 |         assert y_pred.size() == y_true.size()
12 |         y_pred = y_pred[:, 0].contiguous().view(-1)
13 |         y_true = y_true[:, 0].contiguous().view(-1)
14 |         intersection = (y_pred
* y_true).sum() 15 | dsc = (2. * intersection + self.smooth) / ( 16 | y_pred.sum() + y_true.sum() + self.smooth 17 | ) 18 | return 1. - dsc 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.16.2 2 | tensorflow==1.12.2 3 | scikit-learn==0.20.3 4 | scikit-image==0.14.2 5 | imageio==2.5.0 6 | medpy==0.4.0 7 | Pillow==6.0.0 8 | scipy==1.2.1 9 | pandas==0.24.2 10 | tqdm==4.32.1 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from dataset import BrainSegmentationDataset as Dataset 12 | from logger import Logger 13 | from loss import DiceLoss 14 | from transform import transforms 15 | from unet import UNet 16 | from utils import log_images, dsc 17 | 18 | 19 | def main(args): 20 | makedirs(args) 21 | snapshotargs(args) 22 | device = torch.device("cpu" if not torch.cuda.is_available() else args.device) 23 | 24 | loader_train, loader_valid = data_loaders(args) 25 | loaders = {"train": loader_train, "valid": loader_valid} 26 | 27 | unet = UNet(in_channels=Dataset.in_channels, out_channels=Dataset.out_channels) 28 | unet.to(device) 29 | 30 | dsc_loss = DiceLoss() 31 | best_validation_dsc = 0.0 32 | 33 | optimizer = optim.Adam(unet.parameters(), lr=args.lr) 34 | 35 | logger = Logger(args.logs) 36 | loss_train = [] 37 | loss_valid = [] 38 | 39 | step = 0 40 | 41 | for epoch in tqdm(range(args.epochs), total=args.epochs): 42 | for phase in ["train", "valid"]: 43 | if phase == "train": 44 | unet.train() 45 | else: 46 | unet.eval() 47 | 48 | validation_pred = [] 49 | validation_true = [] 50 | 51 | for i, data in enumerate(loaders[phase]): 52 | if phase == "train": 53 | step += 1 54 | 55 | x, y_true = data 56 | x, y_true = x.to(device), y_true.to(device) 57 | 58 | optimizer.zero_grad() 59 | 60 | with torch.set_grad_enabled(phase == "train"): 61 | y_pred = unet(x) 62 | 63 | loss = dsc_loss(y_pred, y_true) 64 | 65 | if phase == "valid": 66 | loss_valid.append(loss.item()) 67 | y_pred_np = y_pred.detach().cpu().numpy() 68 | validation_pred.extend( 69 | [y_pred_np[s] for s in range(y_pred_np.shape[0])] 70 | ) 71 | y_true_np = y_true.detach().cpu().numpy() 72 | validation_true.extend( 73 | [y_true_np[s] for s in range(y_true_np.shape[0])] 74 | ) 75 | if (epoch % args.vis_freq == 0) or (epoch == args.epochs - 1): 76 | if i * args.batch_size < args.vis_images: 77 | tag = "image/{}".format(i) 78 | num_images = args.vis_images - i * args.batch_size 79 | logger.image_list_summary( 80 | tag, 81 | log_images(x, y_true, y_pred)[:num_images], 82 | step, 83 | ) 84 | 85 | if phase == "train": 86 | loss_train.append(loss.item()) 87 | loss.backward() 88 | optimizer.step() 89 | 90 | if phase == "train" and (step + 1) % 10 == 0: 91 | log_loss_summary(logger, loss_train, step) 92 | loss_train = [] 93 | 94 | if phase == "valid": 95 | log_loss_summary(logger, loss_valid, step, prefix="val_") 96 | mean_dsc = np.mean( 97 | dsc_per_volume( 98 | validation_pred, 99 | validation_true, 100 | loader_valid.dataset.patient_slice_index, 101 | ) 102 | ) 103 | logger.scalar_summary("val_dsc", mean_dsc, step) 104 | if mean_dsc > best_validation_dsc: 105 | 
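                    # save a new checkpoint only when the mean validation DSC improves;
                    # weights/unet.pt is overwritten each time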
                    best_validation_dsc = mean_dsc
106 |                     torch.save(unet.state_dict(), os.path.join(args.weights, "unet.pt"))
107 |                 loss_valid = []
108 |
109 |     print("Best validation mean DSC: {:.4f}".format(best_validation_dsc))
110 |
111 |
112 | def data_loaders(args):
113 |     dataset_train, dataset_valid = datasets(args)
114 |
115 |     def worker_init(worker_id):
116 |         np.random.seed(42 + worker_id)
117 |
118 |     loader_train = DataLoader(
119 |         dataset_train,
120 |         batch_size=args.batch_size,
121 |         shuffle=True,
122 |         drop_last=True,
123 |         num_workers=args.workers,
124 |         worker_init_fn=worker_init,
125 |     )
126 |     loader_valid = DataLoader(
127 |         dataset_valid,
128 |         batch_size=args.batch_size,
129 |         drop_last=False,
130 |         num_workers=args.workers,
131 |         worker_init_fn=worker_init,
132 |     )
133 |
134 |     return loader_train, loader_valid
135 |
136 |
137 | def datasets(args):
138 |     train = Dataset(
139 |         images_dir=args.images,
140 |         subset="train",
141 |         image_size=args.image_size,
142 |         transform=transforms(scale=args.aug_scale, angle=args.aug_angle, flip_prob=0.5),
143 |     )
144 |     valid = Dataset(
145 |         images_dir=args.images,
146 |         subset="validation",
147 |         image_size=args.image_size,
148 |         random_sampling=False,
149 |     )
150 |     return train, valid
151 |
152 |
153 | def dsc_per_volume(validation_pred, validation_true, patient_slice_index):
154 |     dsc_list = []
155 |     num_slices = np.bincount([p[0] for p in patient_slice_index])
156 |     index = 0
157 |     for p in range(len(num_slices)):
158 |         y_pred = np.array(validation_pred[index : index + num_slices[p]])
159 |         y_true = np.array(validation_true[index : index + num_slices[p]])
160 |         dsc_list.append(dsc(y_pred, y_true))
161 |         index += num_slices[p]
162 |     return dsc_list
163 |
164 |
165 | def log_loss_summary(logger, loss, step, prefix=""):
166 |     logger.scalar_summary(prefix + "loss", np.mean(loss), step)
167 |
168 |
169 | def makedirs(args):
170 |     os.makedirs(args.weights, exist_ok=True)
171 |     os.makedirs(args.logs, exist_ok=True)
172 |
173 |
174 | def snapshotargs(args):
175 |     args_file = os.path.join(args.logs, "args.json")
176 |     with open(args_file, "w") as fp:
177 |         json.dump(vars(args), fp)
178 |
179 |
180 | if __name__ == "__main__":
181 |     parser = argparse.ArgumentParser(
182 |         description="Training U-Net model for segmentation of brain MRI"
183 |     )
184 |     parser.add_argument(
185 |         "--batch-size",
186 |         type=int,
187 |         default=16,
188 |         help="input batch size for training (default: 16)",
189 |     )
190 |     parser.add_argument(
191 |         "--epochs",
192 |         type=int,
193 |         default=100,
194 |         help="number of epochs to train (default: 100)",
195 |     )
196 |     parser.add_argument(
197 |         "--lr",
198 |         type=float,
199 |         default=0.0001,
200 |         help="initial learning rate (default: 0.0001)",
201 |     )
202 |     parser.add_argument(
203 |         "--device",
204 |         type=str,
205 |         default="cuda:0",
206 |         help="device for training (default: cuda:0)",
207 |     )
208 |     parser.add_argument(
209 |         "--workers",
210 |         type=int,
211 |         default=4,
212 |         help="number of workers for data loading (default: 4)",
213 |     )
214 |     parser.add_argument(
215 |         "--vis-images",
216 |         type=int,
217 |         default=200,
218 |         help="number of visualization images to save in log file (default: 200)",
219 |     )
220 |     parser.add_argument(
221 |         "--vis-freq",
222 |         type=int,
223 |         default=10,
224 |         help="frequency of saving images to log file (default: 10)",
225 |     )
226 |     parser.add_argument(
227 |         "--weights", type=str, default="./weights", help="folder to save weights"
228 |     )
229 |     parser.add_argument(
230 |         "--logs", type=str, default="./logs",
help="folder to save logs" 231 | ) 232 | parser.add_argument( 233 | "--images", type=str, default="./kaggle_3m", help="root folder with images" 234 | ) 235 | parser.add_argument( 236 | "--image-size", 237 | type=int, 238 | default=256, 239 | help="target input image size (default: 256)", 240 | ) 241 | parser.add_argument( 242 | "--aug-scale", 243 | type=int, 244 | default=0.05, 245 | help="scale factor range for augmentation (default: 0.05)", 246 | ) 247 | parser.add_argument( 248 | "--aug-angle", 249 | type=int, 250 | default=15, 251 | help="rotation angle range in degrees for augmentation (default: 15)", 252 | ) 253 | args = parser.parse_args() 254 | main(args) 255 | -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.transform import rescale, rotate 3 | from torchvision.transforms import Compose 4 | 5 | 6 | def transforms(scale=None, angle=None, flip_prob=None): 7 | transform_list = [] 8 | 9 | if scale is not None: 10 | transform_list.append(Scale(scale)) 11 | if angle is not None: 12 | transform_list.append(Rotate(angle)) 13 | if flip_prob is not None: 14 | transform_list.append(HorizontalFlip(flip_prob)) 15 | 16 | return Compose(transform_list) 17 | 18 | 19 | class Scale(object): 20 | 21 | def __init__(self, scale): 22 | self.scale = scale 23 | 24 | def __call__(self, sample): 25 | image, mask = sample 26 | 27 | img_size = image.shape[0] 28 | 29 | scale = np.random.uniform(low=1.0 - self.scale, high=1.0 + self.scale) 30 | 31 | image = rescale( 32 | image, 33 | (scale, scale), 34 | multichannel=True, 35 | preserve_range=True, 36 | mode="constant", 37 | anti_aliasing=False, 38 | ) 39 | mask = rescale( 40 | mask, 41 | (scale, scale), 42 | order=0, 43 | multichannel=True, 44 | preserve_range=True, 45 | mode="constant", 46 | anti_aliasing=False, 47 | ) 48 | 49 | if scale < 1.0: 50 | diff = (img_size - image.shape[0]) / 2.0 51 | padding = ((int(np.floor(diff)), int(np.ceil(diff))),) * 2 + ((0, 0),) 52 | image = np.pad(image, padding, mode="constant", constant_values=0) 53 | mask = np.pad(mask, padding, mode="constant", constant_values=0) 54 | else: 55 | x_min = (image.shape[0] - img_size) // 2 56 | x_max = x_min + img_size 57 | image = image[x_min:x_max, x_min:x_max, ...] 58 | mask = mask[x_min:x_max, x_min:x_max, ...] 
59 | 60 | return image, mask 61 | 62 | 63 | class Rotate(object): 64 | 65 | def __init__(self, angle): 66 | self.angle = angle 67 | 68 | def __call__(self, sample): 69 | image, mask = sample 70 | 71 | angle = np.random.uniform(low=-self.angle, high=self.angle) 72 | image = rotate(image, angle, resize=False, preserve_range=True, mode="constant") 73 | mask = rotate( 74 | mask, angle, resize=False, order=0, preserve_range=True, mode="constant" 75 | ) 76 | return image, mask 77 | 78 | 79 | class HorizontalFlip(object): 80 | 81 | def __init__(self, flip_prob): 82 | self.flip_prob = flip_prob 83 | 84 | def __call__(self, sample): 85 | image, mask = sample 86 | 87 | if np.random.rand() > self.flip_prob: 88 | return image, mask 89 | 90 | image = np.fliplr(image).copy() 91 | mask = np.fliplr(mask).copy() 92 | 93 | return image, mask 94 | -------------------------------------------------------------------------------- /unet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class UNet(nn.Module): 8 | 9 | def __init__(self, in_channels=3, out_channels=1, init_features=32): 10 | super(UNet, self).__init__() 11 | 12 | features = init_features 13 | self.encoder1 = UNet._block(in_channels, features, name="enc1") 14 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 15 | self.encoder2 = UNet._block(features, features * 2, name="enc2") 16 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 17 | self.encoder3 = UNet._block(features * 2, features * 4, name="enc3") 18 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 19 | self.encoder4 = UNet._block(features * 4, features * 8, name="enc4") 20 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 21 | 22 | self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck") 23 | 24 | self.upconv4 = nn.ConvTranspose2d( 25 | features * 16, features * 8, kernel_size=2, stride=2 26 | ) 27 | self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4") 28 | self.upconv3 = nn.ConvTranspose2d( 29 | features * 8, features * 4, kernel_size=2, stride=2 30 | ) 31 | self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3") 32 | self.upconv2 = nn.ConvTranspose2d( 33 | features * 4, features * 2, kernel_size=2, stride=2 34 | ) 35 | self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2") 36 | self.upconv1 = nn.ConvTranspose2d( 37 | features * 2, features, kernel_size=2, stride=2 38 | ) 39 | self.decoder1 = UNet._block(features * 2, features, name="dec1") 40 | 41 | self.conv = nn.Conv2d( 42 | in_channels=features, out_channels=out_channels, kernel_size=1 43 | ) 44 | 45 | def forward(self, x): 46 | enc1 = self.encoder1(x) 47 | enc2 = self.encoder2(self.pool1(enc1)) 48 | enc3 = self.encoder3(self.pool2(enc2)) 49 | enc4 = self.encoder4(self.pool3(enc3)) 50 | 51 | bottleneck = self.bottleneck(self.pool4(enc4)) 52 | 53 | dec4 = self.upconv4(bottleneck) 54 | dec4 = torch.cat((dec4, enc4), dim=1) 55 | dec4 = self.decoder4(dec4) 56 | dec3 = self.upconv3(dec4) 57 | dec3 = torch.cat((dec3, enc3), dim=1) 58 | dec3 = self.decoder3(dec3) 59 | dec2 = self.upconv2(dec3) 60 | dec2 = torch.cat((dec2, enc2), dim=1) 61 | dec2 = self.decoder2(dec2) 62 | dec1 = self.upconv1(dec2) 63 | dec1 = torch.cat((dec1, enc1), dim=1) 64 | dec1 = self.decoder1(dec1) 65 | return torch.sigmoid(self.conv(dec1)) 66 | 67 | @staticmethod 68 | def _block(in_channels, features, name): 69 | return nn.Sequential( 70 | 
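            # two Conv3x3 -> BatchNorm -> ReLU units per block; entries are named
            # so state_dict keys stay readable (e.g. enc1conv1, enc1norm1, ...)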
OrderedDict( 71 | [ 72 | ( 73 | name + "conv1", 74 | nn.Conv2d( 75 | in_channels=in_channels, 76 | out_channels=features, 77 | kernel_size=3, 78 | padding=1, 79 | bias=False, 80 | ), 81 | ), 82 | (name + "norm1", nn.BatchNorm2d(num_features=features)), 83 | (name + "relu1", nn.ReLU(inplace=True)), 84 | ( 85 | name + "conv2", 86 | nn.Conv2d( 87 | in_channels=features, 88 | out_channels=features, 89 | kernel_size=3, 90 | padding=1, 91 | bias=False, 92 | ), 93 | ), 94 | (name + "norm2", nn.BatchNorm2d(num_features=features)), 95 | (name + "relu2", nn.ReLU(inplace=True)), 96 | ] 97 | ) 98 | ) 99 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from medpy.filter.binary import largest_connected_component 3 | from skimage.exposure import rescale_intensity 4 | from skimage.transform import resize 5 | 6 | 7 | def dsc(y_pred, y_true, lcc=True): 8 | if lcc and np.any(y_pred): 9 | y_pred = np.round(y_pred).astype(int) 10 | y_true = np.round(y_true).astype(int) 11 | y_pred = largest_connected_component(y_pred) 12 | return np.sum(y_pred[y_true == 1]) * 2.0 / (np.sum(y_pred) + np.sum(y_true)) 13 | 14 | 15 | def crop_sample(x): 16 | volume, mask = x 17 | volume[volume < np.max(volume) * 0.1] = 0 18 | z_projection = np.max(np.max(np.max(volume, axis=-1), axis=-1), axis=-1) 19 | z_nonzero = np.nonzero(z_projection) 20 | z_min = np.min(z_nonzero) 21 | z_max = np.max(z_nonzero) + 1 22 | y_projection = np.max(np.max(np.max(volume, axis=0), axis=-1), axis=-1) 23 | y_nonzero = np.nonzero(y_projection) 24 | y_min = np.min(y_nonzero) 25 | y_max = np.max(y_nonzero) + 1 26 | x_projection = np.max(np.max(np.max(volume, axis=0), axis=0), axis=-1) 27 | x_nonzero = np.nonzero(x_projection) 28 | x_min = np.min(x_nonzero) 29 | x_max = np.max(x_nonzero) + 1 30 | return ( 31 | volume[z_min:z_max, y_min:y_max, x_min:x_max], 32 | mask[z_min:z_max, y_min:y_max, x_min:x_max], 33 | ) 34 | 35 | 36 | def pad_sample(x): 37 | volume, mask = x 38 | a = volume.shape[1] 39 | b = volume.shape[2] 40 | if a == b: 41 | return volume, mask 42 | diff = (max(a, b) - min(a, b)) / 2.0 43 | if a > b: 44 | padding = ((0, 0), (0, 0), (int(np.floor(diff)), int(np.ceil(diff)))) 45 | else: 46 | padding = ((0, 0), (int(np.floor(diff)), int(np.ceil(diff))), (0, 0)) 47 | mask = np.pad(mask, padding, mode="constant", constant_values=0) 48 | padding = padding + ((0, 0),) 49 | volume = np.pad(volume, padding, mode="constant", constant_values=0) 50 | return volume, mask 51 | 52 | 53 | def resize_sample(x, size=256): 54 | volume, mask = x 55 | v_shape = volume.shape 56 | out_shape = (v_shape[0], size, size) 57 | mask = resize( 58 | mask, 59 | output_shape=out_shape, 60 | order=0, 61 | mode="constant", 62 | cval=0, 63 | anti_aliasing=False, 64 | ) 65 | out_shape = out_shape + (v_shape[3],) 66 | volume = resize( 67 | volume, 68 | output_shape=out_shape, 69 | order=2, 70 | mode="constant", 71 | cval=0, 72 | anti_aliasing=False, 73 | ) 74 | return volume, mask 75 | 76 | 77 | def normalize_volume(volume): 78 | p10 = np.percentile(volume, 10) 79 | p99 = np.percentile(volume, 99) 80 | volume = rescale_intensity(volume, in_range=(p10, p99)) 81 | m = np.mean(volume, axis=(0, 1, 2)) 82 | s = np.std(volume, axis=(0, 1, 2)) 83 | volume = (volume - m) / s 84 | return volume 85 | 86 | 87 | def log_images(x, y_true, y_pred, channel=1): 88 | images = [] 89 | x_np = x[:, channel].cpu().numpy() 90 | y_true_np = y_true[:, 
0].cpu().numpy() 91 | y_pred_np = y_pred[:, 0].cpu().numpy() 92 | for i in range(x_np.shape[0]): 93 | image = gray2rgb(np.squeeze(x_np[i])) 94 | image = outline(image, y_pred_np[i], color=[255, 0, 0]) 95 | image = outline(image, y_true_np[i], color=[0, 255, 0]) 96 | images.append(image) 97 | return images 98 | 99 | 100 | def gray2rgb(image): 101 | w, h = image.shape 102 | image += np.abs(np.min(image)) 103 | image_max = np.abs(np.max(image)) 104 | if image_max > 0: 105 | image /= image_max 106 | ret = np.empty((w, h, 3), dtype=np.uint8) 107 | ret[:, :, 2] = ret[:, :, 1] = ret[:, :, 0] = image * 255 108 | return ret 109 | 110 | 111 | def outline(image, mask, color): 112 | mask = np.round(mask) 113 | yy, xx = np.nonzero(mask) 114 | for y, x in zip(yy, xx): 115 | if 0.0 < np.mean(mask[max(0, y - 1) : y + 2, max(0, x - 1) : x + 2]) < 1.0: 116 | image[max(0, y) : y + 1, max(0, x) : x + 1] = color 117 | return image 118 | -------------------------------------------------------------------------------- /weights/unet.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mateuszbuda/brain-segmentation-pytorch/d45f8908ab2f0246ba204c702a6161c9eb25f902/weights/unet.pt --------------------------------------------------------------------------------
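
A minimal end-to-end usage sketch tying the files above together: it builds the `UNet` from `unet.py` with its default hyperparameters, loads the bundled `./weights/unet.pt`, and segments a single preprocessed slice. The random input tensor and the 0.5 threshold are illustrative assumptions, not part of the repository's scripts; `inference.py` is the supported entry point and additionally post-processes whole volumes.

```python
import numpy as np
import torch

from unet import UNet

# same configuration as inference.py: 3 input channels, 1 output channel
model = UNet(in_channels=3, out_channels=1)
model.load_state_dict(torch.load("./weights/unet.pt", map_location="cpu"))
model.eval()

# placeholder for one 256x256 slice with 3 channels (channel 1 is FLAIR, per inference.py),
# assumed already cropped, resized, and normalized as in dataset.py
x = torch.from_numpy(np.random.rand(1, 3, 256, 256).astype(np.float32))

with torch.no_grad():
    y = model(x)  # sigmoid probabilities, shape (1, 1, 256, 256)

mask = (y.numpy()[0, 0] > 0.5).astype(np.uint8)  # illustrative binarization
```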