├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── data ├── README.md ├── convert.py ├── manifest-T2.tcia └── preprocess_requirements.txt ├── figs ├── DiceScore_comparision.png └── supervision_comparision.png ├── inference.py ├── models └── README.md ├── network └── unet2d5.py ├── requirements.txt ├── splits ├── split_fully_budget1.csv ├── split_fully_budget13.csv └── split_inextremis_budget1.csv ├── train.py └── utilities ├── figure_fully.py ├── focal.py ├── geodesics.py ├── losses.py ├── scores.py ├── utils.py └── write_geodesics.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "ScribbleDA"] 2 | path = ScribbleDA 3 | url = https://github.com/KCL-BMEIS/ScribbleDA/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 ReubenDo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inter Extreme Points Geodesics for End-to-End Weakly Supervised Image Segmentation 2 | 3 | Public PyTorch implementation for our paper [Inter Extreme Points Geodesics for End-to-End Weakly Supervised Image Segmentation](https://arxiv.org/abs/2107.00583), 4 | which was accepted for presentation at [MICCAI 2021](https://www.miccai2021.org). 5 | 6 | If you find this code useful for your research, please cite the following paper: 7 | 8 | ``` 9 | @article{InExtremIS2021Dorent, 10 | author={Dorent, Reuben and Joutard, Samuel and Shapey, Jonathan and 11 | Kujawa, Aaron and Modat, Marc and Ourselin, S\'ebastien and Vercauteren, Tom}, 12 | title={Inter Extreme Points Geodesics for End-to-End Weakly Supervised Image Segmentation}, 13 | journal={MICCAI}, 14 | year={2021}, 15 | } 16 | ``` 17 | 18 | ## Method Overview 19 | We introduce InExtremIS, a weakly supervised 3D approach to train a deep image segmentation network using particularly weak train-time annotations: only 6 extreme clicks at the boundary of the objects of interest. Our fully automatic method is trained end-to-end and does not require any test-time annotations. 20 | 21 | *Example of weak labels for our use case of Vestibular Schwannoma (VS) segmentation. Magenta: Background. Green: VS:* 22 |

23 | 24 |

25 | 26 | 27 | 28 | ## Virtual Environment Setup 29 | 30 | The code is implemented in Python 3.6 using the PyTorch library. 31 | Requirements: 32 | 33 | * Set up a virtual environment (e.g. conda or virtualenv) with Python >=3.6.9 34 | * Install all requirements using: 35 | 36 | ````pip install -r requirements.txt```` 37 | * Install the CUDA implementation of the permutohedral lattice and the CRF Loss. 38 | ```` 39 | cd ScribbleDA/Permutohedral_attention_module/PAM_cuda/ 40 | python3 setup.py build 41 | python3 setup.py install --user 42 | ```` 43 | 44 | 45 | ## Data 46 | 47 | The data and annotations are publicly available. Details are provided in [data](/data/). 48 | 49 | ## Running the code 50 | `train.py` is the main file for training the models. 51 | 52 | Example 1: Training InExtremIS with manual extreme points: 53 | ```` 54 | python3 train.py \ 55 | --model_dir ./models/manual_gradient_eucl_deep_crf/ \ 56 | --alpha 15 \ 57 | --beta 0.05 \ 58 | --weight_crf 0.0001 \ 59 | --mode extreme_points \ 60 | --label_postfix Extremes_man \ 61 | --img_gradient_postfix Sobel_man \ 62 | --path_data data/T2/ \ 63 | --path_labels data/extreme_points/manual/ \ 64 | --with_euclidean \ 65 | --with_prob 66 | ```` 67 | Example 2: Training InExtremIS with simulated extreme points: 68 | ```` 69 | python3 train.py \ 70 | --model_dir ./models/simulated_gradient_eucl_deep_crf/ \ 71 | --alpha 15 \ 72 | --beta 0.05 \ 73 | --weight_crf 0.0001 \ 74 | --mode extreme_points \ 75 | --label_postfix Extremes \ 76 | --img_gradient_postfix Sobel \ 77 | --path_data data/T2/ \ 78 | --path_labels data/extreme_points/simulated/ \ 79 | --with_euclidean \ 80 | --with_prob 81 | ```` 82 | 83 | `inference.py` is the main file for running the inference: 84 | ```` 85 | python3 inference.py \ 86 | --model_dir ./models/manual_gradient_eucl_deep_crf/ \ 87 | --path_data data/T2/ 88 | ```` 89 | 90 | ## Using the code with your own data 91 | 92 | If you want to use your own data, you just need to change the
source and target paths, 93 | the splits and potentially the modality used. 94 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Downloading the data used for the experiments 2 | 3 | In this work, we used a large (N=242) dataset for Vestibular Schwannoma segmentation. This dataset is publicly available on 4 | [TCIA](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=70229053). 5 | 6 | This readme explains how to download and pre-process the raw data from TCIA. We also provide open access to the extreme points and pre-computed geodesics used in this work. 7 | 8 | ## Download the fully annotated TCIA-VS dataset 9 | 10 | ### Option 1 - Downloading the T2 scans only and their segmentation maps (Recommended): 11 | 12 | Please follow these steps: 13 | 14 | **Step 1**: Download the NBIA Data Retriever: 15 | * Please follow the instructions [here](https://wiki.cancerimagingarchive.net/display/NBIA/Downloading+TCIA+Images). 16 | 17 | **Step 2**: Download the T2 scans only: 18 | * Open `manifest-T2.tcia` with NBIA Data Retriever and download the T2 images (DICOM, 6GB) with the "Descriptive Directory Name" format. 19 | 20 | **Step 3**: DICOM to Nifti conversion: 21 | * Install dependencies: `pip install -r preprocess_requirements.txt` 22 | * Execute the conversion script: 23 | `python3 convert.py --input <input_dir> --output <output_dir>` 24 | * `<input_dir>` is the directory containing the raw T2 images (e.g. `/home/admin/manifest-T2/Vestibular-Schwannoma-SEG/`). 25 | * `<output_dir>` is the directory in which the pre-processed data will be saved. 26 | 27 | **Step 4**: Download the fully annotated segmentation masks [here](https://zenodo.org/record/5081986/files/full_annotations.zip?download=1).
28 | 29 | ### Option 2 - Downloading the full dataset and manually converting contours into segmentation masks: 30 | Please follow the instructions from the [VS_Seg repository](https://github.com/KCL-BMEIS/VS_Seg/tree/master/preprocessing). 31 | 32 | ## Download the extreme points and pre-computed geodesics 33 | The manual and simulated extreme points can be found [here](https://zenodo.org/record/5081986/files/extreme_points.zip?download=1). 34 | The pre-computed geodesics using the image gradient information (`grad` folder) and with the additional Euclidean distance (`grad_eucl` folder) can be found [here](https://zenodo.org/record/5081986/files/precomputed_geodesics.zip?download=1). 35 | 36 | The 6 extreme points are the voxels with the values _1,2,3,4,5,6_. Specifically, the pairs of extreme points in the _x_, _y_ and _z_ axes are respectively _{1,2}_, _{3,4}_ and _{5,6}_. 37 | 38 | ## Citations 39 | If you use this VS data, please cite: 40 | 41 | ``` 42 | @article { ShapeyJNS21, 43 | author = "Jonathan Shapey and Guotai Wang and Reuben Dorent and Alexis Dimitriadis and Wenqi Li and Ian Paddick and Neil Kitchen and Sotirios Bisdas and Shakeel R.
Saeed and Sebastien Ourselin and Robert Bradford and Tom Vercauteren", 44 | title = "{An artificial intelligence framework for automatic segmentation and volumetry of vestibular schwannomas from contrast-enhanced T1-weighted and high-resolution T2-weighted MRI}", 45 | journal = "Journal of Neurosurgery JNS", 46 | year = "2021", 47 | publisher = "American Association of Neurological Surgeons", 48 | volume = "134", 49 | number = "1", 50 | doi = "10.3171/2019.9.JNS191949", 51 | pages= "171 - 179", 52 | url = "https://thejns.org/view/journals/j-neurosurg/134/1/article-p171.xml" 53 | } 54 | ``` 55 | 56 | If you use the extreme points, please additionally cite: 57 | 58 | ``` 59 | @article{InExtremIS2021Dorent, 60 | author={Dorent, Reuben and Joutard, Samuel and Shapey, Jonathan and 61 | Kujawa, Aaron and Modat, Marc and Ourselin, S\'ebastien and Vercauteren, Tom}, 62 | title={Inter Extreme Points Geodesics for End-to-End Weakly Supervised Image Segmentation}, 63 | journal={MICCAI}, 64 | year={2021}, 65 | } 66 | ``` 67 | ## Credits 68 | The conversion script is based on https://github.com/KCL-BMEIS/VS_Seg/tree/master/preprocessing. 
69 | -------------------------------------------------------------------------------- /data/convert.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/KCL-BMEIS/VS_Seg/blob/master/preprocessing/TCIA_data_convert_into_convenient_folder_structure.py 2 | 3 | #!/usr/bin/env python 4 | # coding: utf-8 5 | 6 | import os 7 | from glob import glob 8 | from natsort import natsorted 9 | import pydicom 10 | import SimpleITK as sitk 11 | import re 12 | import argparse 13 | from tqdm import tqdm 14 | 15 | parser = argparse.ArgumentParser(description='Convert the T2 scans from the TCIA dataset into the Nifti format') 16 | parser.add_argument('--input', type=str, help='(string) path to TCIA dataset, in "Descriptive Directory Name" format, for example /home/user/.../manifest-T2/Vestibular-Schwannoma-SEG') 17 | parser.add_argument('--output', type=str, help='(string) path to output folder') 18 | args = parser.parse_args() 19 | 20 | input_path = args.input 21 | output_path = args.output 22 | 23 | if not os.path.isdir(output_path): 24 | os.makedirs(output_path, exist_ok=True) 25 | 26 | cases = natsorted(glob(os.path.join(input_path, '*'))) 27 | 28 | for case in tqdm(cases): 29 | folders = glob(case+'/*/*') 30 | 31 | MRs = [] 32 | MRs_paths = [] 33 | 34 | # Test that the DICOM is a MRI 35 | for folder in folders: 36 | first_file = glob(folder+"/*")[0] 37 | dd = pydicom.read_file(first_file) 38 | 39 | if dd['Modality'].value == 'MR': 40 | MRs.append(dd) 41 | MRs_paths.append(first_file) 42 | 43 | else: 44 | raise Exception 45 | 46 | file_paths = None 47 | # Test that the DICOM is a T2 scan 48 | for MR, path in zip(MRs, MRs_paths): 49 | if "t2_" in MR['SeriesDescription'].value: 50 | MR_T2 = MR 51 | file_paths = path 52 | else: 53 | raise Exception 54 | 55 | # write files into new folder structure 56 | p = re.compile(r'VS-SEG-(\d+)') 57 | case_idx = int(p.findall(case)[0]) 58 | old_T2_folder = 
os.path.dirname(file_paths) 59 | 60 | # Output path 61 | new_T2_path = os.path.join(output_path, 'vs_gk_' + str(case_idx) +'_T2.nii.gz') 62 | 63 | # Conversion DICOM to NIFTI using SITK 64 | reader = sitk.ImageSeriesReader() 65 | dicom_names = reader.GetGDCMSeriesFileNames(old_T2_folder) 66 | reader.SetFileNames(dicom_names) 67 | image = reader.Execute() 68 | 69 | sitk.WriteImage(image, new_T2_path) 70 | -------------------------------------------------------------------------------- /data/manifest-T2.tcia: -------------------------------------------------------------------------------- 1 | downloadServerUrl=https://public.cancerimagingarchive.net/nbia-download/servlet/DownloadServlet 2 | includeAnnotation=true 3 | noOfrRetry=4 4 | databasketId=manifest-T2.tcia 5 | manifestVersion=3.0 6 | ListOfSeriesToDownload= 7 | 1.3.6.1.4.1.14519.5.2.1.97824612055862366318560427964793890998 8 | 1.3.6.1.4.1.14519.5.2.1.124502390760917755084744319866454742880 9 | 1.3.6.1.4.1.14519.5.2.1.336209512066065542527308757738778134743 10 | 1.3.6.1.4.1.14519.5.2.1.61686433061457847902102354630069993489 11 | 1.3.6.1.4.1.14519.5.2.1.285266602767317398976732799437378140608 12 | 1.3.6.1.4.1.14519.5.2.1.128575943230315285809456617735815081338 13 | 1.3.6.1.4.1.14519.5.2.1.119974278459672124882397556735745886618 14 | 1.3.6.1.4.1.14519.5.2.1.328411971284129245304472742027457770689 15 | 1.3.6.1.4.1.14519.5.2.1.205715002896735863602196046118617324323 16 | 1.3.6.1.4.1.14519.5.2.1.47368711747357463116341094276606315239 17 | 1.3.6.1.4.1.14519.5.2.1.277564943275913764149628784943750055165 18 | 1.3.6.1.4.1.14519.5.2.1.216916846757178094156064648053287356705 19 | 1.3.6.1.4.1.14519.5.2.1.292451296497326330780007043491927669965 20 | 1.3.6.1.4.1.14519.5.2.1.37407646254753171628888832676854647641 21 | 1.3.6.1.4.1.14519.5.2.1.2154096351050073340417051795616964329 22 | 1.3.6.1.4.1.14519.5.2.1.288822954068116720576409516209064805137 23 | 1.3.6.1.4.1.14519.5.2.1.157455919411531329445971228461834454864 24 | 
1.3.6.1.4.1.14519.5.2.1.300550246357357923662626736244214512354 25 | 1.3.6.1.4.1.14519.5.2.1.46875774646497052019561950116185371595 26 | 1.3.6.1.4.1.14519.5.2.1.265856471014025537322703012203178074916 27 | 1.3.6.1.4.1.14519.5.2.1.217554160929652517188403089378973632127 28 | 1.3.6.1.4.1.14519.5.2.1.81190427874393211914843328182235403675 29 | 1.3.6.1.4.1.14519.5.2.1.239006515845888908570518896552813145905 30 | 1.3.6.1.4.1.14519.5.2.1.226103095049025971830798326183529727687 31 | 1.3.6.1.4.1.14519.5.2.1.241519636120309329966889620612275264106 32 | 1.3.6.1.4.1.14519.5.2.1.250256352839480733008907732689438039992 33 | 1.3.6.1.4.1.14519.5.2.1.155625578034220411317764200158067934139 34 | 1.3.6.1.4.1.14519.5.2.1.223264893116803537954728806540487445907 35 | 1.3.6.1.4.1.14519.5.2.1.56803546825794638336125075839814402436 36 | 1.3.6.1.4.1.14519.5.2.1.168133623024284256524654755031136793033 37 | 1.3.6.1.4.1.14519.5.2.1.41051157220867917432644808566260695421 38 | 1.3.6.1.4.1.14519.5.2.1.13168934333778663511634424074788249440 39 | 1.3.6.1.4.1.14519.5.2.1.98519287807856841042726497993241875149 40 | 1.3.6.1.4.1.14519.5.2.1.67345468089650726208933617170352542940 41 | 1.3.6.1.4.1.14519.5.2.1.128016141184262134692211776899422909852 42 | 1.3.6.1.4.1.14519.5.2.1.261904727650229936032200059779168110546 43 | 1.3.6.1.4.1.14519.5.2.1.80306368584051877336147767062103188154 44 | 1.3.6.1.4.1.14519.5.2.1.81013456284548964428831720032766314407 45 | 1.3.6.1.4.1.14519.5.2.1.137851401856381638890659062774470307621 46 | 1.3.6.1.4.1.14519.5.2.1.274222328416232835281613446776194950887 47 | 1.3.6.1.4.1.14519.5.2.1.189932513571932107895723417841672250908 48 | 1.3.6.1.4.1.14519.5.2.1.218365832761211504858210426893853705350 49 | 1.3.6.1.4.1.14519.5.2.1.299478506810267521121655201175658435054 50 | 1.3.6.1.4.1.14519.5.2.1.175770890260169282489793361209219236540 51 | 1.3.6.1.4.1.14519.5.2.1.288937279368837050893511455153770225852 52 | 1.3.6.1.4.1.14519.5.2.1.137305226373921385879942188391264951989 53 | 
1.3.6.1.4.1.14519.5.2.1.159740827447177453253673680290849733064 54 | 1.3.6.1.4.1.14519.5.2.1.233605121987468319370647828457502185983 55 | 1.3.6.1.4.1.14519.5.2.1.29747780949545019036183341430656302425 56 | 1.3.6.1.4.1.14519.5.2.1.233973029399216582850133589760654139151 57 | 1.3.6.1.4.1.14519.5.2.1.9669015644716471476945977883357982292 58 | 1.3.6.1.4.1.14519.5.2.1.67158313966057965569167946723639338387 59 | 1.3.6.1.4.1.14519.5.2.1.78418533061355773065152791264029309750 60 | 1.3.6.1.4.1.14519.5.2.1.95190271536129325658493975898105509278 61 | 1.3.6.1.4.1.14519.5.2.1.8744922103190999130080234468699487701 62 | 1.3.6.1.4.1.14519.5.2.1.277261916620071690562868021692602577695 63 | 1.3.6.1.4.1.14519.5.2.1.333826283119726799662204205087684554536 64 | 1.3.6.1.4.1.14519.5.2.1.294465790050902079649170070800458333118 65 | 1.3.6.1.4.1.14519.5.2.1.230012513674586138089728143330133632463 66 | 1.3.6.1.4.1.14519.5.2.1.35307211638230668632830028521014104687 67 | 1.3.6.1.4.1.14519.5.2.1.337494328277457405561277816098934383813 68 | 1.3.6.1.4.1.14519.5.2.1.162851729691398294622797105490114199357 69 | 1.3.6.1.4.1.14519.5.2.1.100420693789734401722866581473132024737 70 | 1.3.6.1.4.1.14519.5.2.1.229384024656851675147876036353267258487 71 | 1.3.6.1.4.1.14519.5.2.1.304198865028657374537924002347255938106 72 | 1.3.6.1.4.1.14519.5.2.1.20670976059556792609125791949916925570 73 | 1.3.6.1.4.1.14519.5.2.1.229528759693718677584709692621291077626 74 | 1.3.6.1.4.1.14519.5.2.1.21877083602789524438704314882665653577 75 | 1.3.6.1.4.1.14519.5.2.1.86855432222058507463718198553899529003 76 | 1.3.6.1.4.1.14519.5.2.1.223475447957754453121941963938410079266 77 | 1.3.6.1.4.1.14519.5.2.1.117858331830377878190511865298944867098 78 | 1.3.6.1.4.1.14519.5.2.1.15167479224344864655434050143434357061 79 | 1.3.6.1.4.1.14519.5.2.1.297889053157145327024131887531428225347 80 | 1.3.6.1.4.1.14519.5.2.1.90465408233381138722540652694651320969 81 | 1.3.6.1.4.1.14519.5.2.1.302714639670083142058852903953837264697 82 | 
1.3.6.1.4.1.14519.5.2.1.16904271606903764104530936528625259781 83 | 1.3.6.1.4.1.14519.5.2.1.12764749088880384028924045444467299852 84 | 1.3.6.1.4.1.14519.5.2.1.267862471627412118753428341131185702764 85 | 1.3.6.1.4.1.14519.5.2.1.303624579553286757476664912625222198071 86 | 1.3.6.1.4.1.14519.5.2.1.82000297190327313616877429419376399973 87 | 1.3.6.1.4.1.14519.5.2.1.110063895628460115704321585826189095537 88 | 1.3.6.1.4.1.14519.5.2.1.336499584971688286794823333380374653244 89 | 1.3.6.1.4.1.14519.5.2.1.139392630476501192761585947339412212897 90 | 1.3.6.1.4.1.14519.5.2.1.113300425991115767575938977397298840060 91 | 1.3.6.1.4.1.14519.5.2.1.307568576320578693001905075814766671550 92 | 1.3.6.1.4.1.14519.5.2.1.282832052300721436992245894412315111912 93 | 1.3.6.1.4.1.14519.5.2.1.213400619248062551895416596109098398213 94 | 1.3.6.1.4.1.14519.5.2.1.36934820552005519735534479931332292846 95 | 1.3.6.1.4.1.14519.5.2.1.41216239343660574082540246837790859853 96 | 1.3.6.1.4.1.14519.5.2.1.63157081399858537808770660614894164681 97 | 1.3.6.1.4.1.14519.5.2.1.167260142371217243657393980232701016137 98 | 1.3.6.1.4.1.14519.5.2.1.61398984386978615487608646301170730391 99 | 1.3.6.1.4.1.14519.5.2.1.170913910290121505784426628803446970389 100 | 1.3.6.1.4.1.14519.5.2.1.208490847474630115886915444238291380368 101 | 1.3.6.1.4.1.14519.5.2.1.298671145848708339737222217998577755570 102 | 1.3.6.1.4.1.14519.5.2.1.9755483032379805473706391940166669036 103 | 1.3.6.1.4.1.14519.5.2.1.268718756853741766983354453912312801730 104 | 1.3.6.1.4.1.14519.5.2.1.261237309343469689065576631723962333840 105 | 1.3.6.1.4.1.14519.5.2.1.224003143144336561123702641959818284842 106 | 1.3.6.1.4.1.14519.5.2.1.163015487469502642386883094241440581072 107 | 1.3.6.1.4.1.14519.5.2.1.170347731217393158076277121980482188304 108 | 1.3.6.1.4.1.14519.5.2.1.199464597880518074498752987949770064952 109 | 1.3.6.1.4.1.14519.5.2.1.20347614015017630361319142934914011596 110 | 1.3.6.1.4.1.14519.5.2.1.9721409353341461465501650910815656993 111 
| 1.3.6.1.4.1.14519.5.2.1.14262170121308979137577589173619571563 112 | 1.3.6.1.4.1.14519.5.2.1.311333098893435802624860311753091976002 113 | 1.3.6.1.4.1.14519.5.2.1.220840979774567365812993669047801644120 114 | 1.3.6.1.4.1.14519.5.2.1.77558762586485924195080424411910919171 115 | 1.3.6.1.4.1.14519.5.2.1.168507932534014014653293784243131332894 116 | 1.3.6.1.4.1.14519.5.2.1.116723325586106252707395359498797460520 117 | 1.3.6.1.4.1.14519.5.2.1.29910714713700110800496924621325713622 118 | 1.3.6.1.4.1.14519.5.2.1.52208737328500500547869186262562161680 119 | 1.3.6.1.4.1.14519.5.2.1.41341777065628435310650949856777228905 120 | 1.3.6.1.4.1.14519.5.2.1.251606799383580969844189799392861321861 121 | 1.3.6.1.4.1.14519.5.2.1.282485957146535605228825771941012962357 122 | 1.3.6.1.4.1.14519.5.2.1.317359270245808302145533856585672960883 123 | 1.3.6.1.4.1.14519.5.2.1.135514198492981058482845815249880889690 124 | 1.3.6.1.4.1.14519.5.2.1.96562875876973827961400745245219169962 125 | 1.3.6.1.4.1.14519.5.2.1.256648491507353229813182334221467374181 126 | 1.3.6.1.4.1.14519.5.2.1.70317311370767903368272174720906718342 127 | 1.3.6.1.4.1.14519.5.2.1.127703226653324327717860855502627722306 128 | 1.3.6.1.4.1.14519.5.2.1.17891221028856069257173473170851398654 129 | 1.3.6.1.4.1.14519.5.2.1.30354814832219495751611811177811844787 130 | 1.3.6.1.4.1.14519.5.2.1.196082117235151381832692400871564748236 131 | 1.3.6.1.4.1.14519.5.2.1.334508876733181239165112956951269361398 132 | 1.3.6.1.4.1.14519.5.2.1.137213097494387692044497332427005289056 133 | 1.3.6.1.4.1.14519.5.2.1.66602819257976265659201750697700497901 134 | 1.3.6.1.4.1.14519.5.2.1.254907007513982459959439589947824133569 135 | 1.3.6.1.4.1.14519.5.2.1.277932022685732692159807028685158125644 136 | 1.3.6.1.4.1.14519.5.2.1.75869570711830604911027500287084681088 137 | 1.3.6.1.4.1.14519.5.2.1.43394638547031237271853527848743247298 138 | 1.3.6.1.4.1.14519.5.2.1.219920872738792573764887676715372489467 139 | 
1.3.6.1.4.1.14519.5.2.1.304966211758865415213459139643277133250 140 | 1.3.6.1.4.1.14519.5.2.1.59097307935533647142368110661201694954 141 | 1.3.6.1.4.1.14519.5.2.1.104342378995544082321527649818115719381 142 | 1.3.6.1.4.1.14519.5.2.1.275976308826344111988529521225074469122 143 | 1.3.6.1.4.1.14519.5.2.1.76966817086725830177633239884425707885 144 | 1.3.6.1.4.1.14519.5.2.1.37500502292073832653905111102576102168 145 | 1.3.6.1.4.1.14519.5.2.1.15051876378477273432527967382885690193 146 | 1.3.6.1.4.1.14519.5.2.1.75023595377115501952746333684971209569 147 | 1.3.6.1.4.1.14519.5.2.1.10244709163143315037706780774420781635 148 | 1.3.6.1.4.1.14519.5.2.1.315702370575273964468589032564024942002 149 | 1.3.6.1.4.1.14519.5.2.1.232581844696793658570866321285325476363 150 | 1.3.6.1.4.1.14519.5.2.1.255106474671713017130022360202121610175 151 | 1.3.6.1.4.1.14519.5.2.1.184438972748249848205519432640579417556 152 | 1.3.6.1.4.1.14519.5.2.1.15594630066063018098965121607394131685 153 | 1.3.6.1.4.1.14519.5.2.1.77895049258647801574952424776894897832 154 | 1.3.6.1.4.1.14519.5.2.1.194222281136456272002461195297286978806 155 | 1.3.6.1.4.1.14519.5.2.1.292361225108641180764247877135850676426 156 | 1.3.6.1.4.1.14519.5.2.1.84787388055021215524629025646692795407 157 | 1.3.6.1.4.1.14519.5.2.1.126181735854342571807747278015493395427 158 | 1.3.6.1.4.1.14519.5.2.1.50857688276531731282813856253493358728 159 | 1.3.6.1.4.1.14519.5.2.1.158363608374533284432503372099530292444 160 | 1.3.6.1.4.1.14519.5.2.1.221009419007488062284911758544674914233 161 | 1.3.6.1.4.1.14519.5.2.1.214587129890999158719654764602468482956 162 | 1.3.6.1.4.1.14519.5.2.1.301199852589293592180667435413243749421 163 | 1.3.6.1.4.1.14519.5.2.1.91307143982593757658329925269213107444 164 | 1.3.6.1.4.1.14519.5.2.1.115916457443958592970450702751155898672 165 | 1.3.6.1.4.1.14519.5.2.1.251387290435470956632040740033833126539 166 | 1.3.6.1.4.1.14519.5.2.1.35592985115515755107463670280222959193 167 | 
1.3.6.1.4.1.14519.5.2.1.189790696754423307903604808866635619814 168 | 1.3.6.1.4.1.14519.5.2.1.278447060082661819769695941196761676970 169 | 1.3.6.1.4.1.14519.5.2.1.74127510988834203379779920116706931829 170 | 1.3.6.1.4.1.14519.5.2.1.158837645871384737652904421894110435515 171 | 1.3.6.1.4.1.14519.5.2.1.261027331158345315128560572348660712113 172 | 1.3.6.1.4.1.14519.5.2.1.320208736433890307354533957518098015251 173 | 1.3.6.1.4.1.14519.5.2.1.255721252600754855960977522233110250936 174 | 1.3.6.1.4.1.14519.5.2.1.265562325801308771480973706098784037904 175 | 1.3.6.1.4.1.14519.5.2.1.59496075545541835714082685052623985780 176 | 1.3.6.1.4.1.14519.5.2.1.29707760275931807980484755409216479658 177 | 1.3.6.1.4.1.14519.5.2.1.71847268882604614946395654926696180481 178 | 1.3.6.1.4.1.14519.5.2.1.139899967629592458992273833701974367027 179 | 1.3.6.1.4.1.14519.5.2.1.60645495344224097120659188750733663433 180 | 1.3.6.1.4.1.14519.5.2.1.76234766989041933862422924921980946094 181 | 1.3.6.1.4.1.14519.5.2.1.54647801429999999880279745055162166396 182 | 1.3.6.1.4.1.14519.5.2.1.245881267132473546456095715670497260366 183 | 1.3.6.1.4.1.14519.5.2.1.271500782689034842482795482397599708503 184 | 1.3.6.1.4.1.14519.5.2.1.260858848998229883842233413025317121641 185 | 1.3.6.1.4.1.14519.5.2.1.138787690545499304848411355088748136195 186 | 1.3.6.1.4.1.14519.5.2.1.138114624312317833838264235708033684106 187 | 1.3.6.1.4.1.14519.5.2.1.190044300314624874115076527729396088731 188 | 1.3.6.1.4.1.14519.5.2.1.327102410004294511180793158899656978209 189 | 1.3.6.1.4.1.14519.5.2.1.290140223859345390541699250506135777329 190 | 1.3.6.1.4.1.14519.5.2.1.235533134899035107929505275392119110300 191 | 1.3.6.1.4.1.14519.5.2.1.256223606579513472138624134007553794517 192 | 1.3.6.1.4.1.14519.5.2.1.292025009655148971742680910641710654260 193 | 1.3.6.1.4.1.14519.5.2.1.278019916823314020426350685915747751677 194 | 1.3.6.1.4.1.14519.5.2.1.119338114603372415979087053769916795752 195 | 
1.3.6.1.4.1.14519.5.2.1.154846187556208082175515830387781463718 196 | 1.3.6.1.4.1.14519.5.2.1.241158227314558271467289218638320805912 197 | 1.3.6.1.4.1.14519.5.2.1.229985221915959115066070872369878709015 198 | 1.3.6.1.4.1.14519.5.2.1.129069453057329278334463872222670450244 199 | 1.3.6.1.4.1.14519.5.2.1.130306192263501254884259911459312278931 200 | 1.3.6.1.4.1.14519.5.2.1.99099341904959222126543542841859021660 201 | 1.3.6.1.4.1.14519.5.2.1.226937211758235683777950388828543487213 202 | 1.3.6.1.4.1.14519.5.2.1.222046637334875981449965463114585543169 203 | 1.3.6.1.4.1.14519.5.2.1.246909769442555605040748585982394591148 204 | 1.3.6.1.4.1.14519.5.2.1.197011928085586949947985979972960177494 205 | 1.3.6.1.4.1.14519.5.2.1.210688939016236904061492475261436050128 206 | 1.3.6.1.4.1.14519.5.2.1.194203897008456569475098794782196805092 207 | 1.3.6.1.4.1.14519.5.2.1.243028238309713788501554716738850686675 208 | 1.3.6.1.4.1.14519.5.2.1.333979424673348586669984373810626623120 209 | 1.3.6.1.4.1.14519.5.2.1.195824806623977377348042506440300253900 210 | 1.3.6.1.4.1.14519.5.2.1.86663105982361050637985901375777674119 211 | 1.3.6.1.4.1.14519.5.2.1.80596214383930081973613702686831725427 212 | 1.3.6.1.4.1.14519.5.2.1.297601975245157269521546365597109246887 213 | 1.3.6.1.4.1.14519.5.2.1.175508249861315421813970648276848747955 214 | 1.3.6.1.4.1.14519.5.2.1.197835772424729511288285229448608628340 215 | 1.3.6.1.4.1.14519.5.2.1.216583092100832104191093532137558942068 216 | 1.3.6.1.4.1.14519.5.2.1.323720174552274881744647199360391559339 217 | 1.3.6.1.4.1.14519.5.2.1.76123013711237089310162869568540894767 218 | 1.3.6.1.4.1.14519.5.2.1.57055430988757827707470998009719003163 219 | 1.3.6.1.4.1.14519.5.2.1.282176336559963455064991060259622459376 220 | 1.3.6.1.4.1.14519.5.2.1.27477011593957184467450523730643807258 221 | 1.3.6.1.4.1.14519.5.2.1.210671159041526701646380148425994405952 222 | 1.3.6.1.4.1.14519.5.2.1.251865584911266925837724605787007632031 223 | 
1.3.6.1.4.1.14519.5.2.1.191494736838268150020432657327243028009 224 | 1.3.6.1.4.1.14519.5.2.1.44323770382390486457593204813672763488 225 | 1.3.6.1.4.1.14519.5.2.1.156769265159911961821857178602133057319 226 | 1.3.6.1.4.1.14519.5.2.1.216294443593141688683154232645411947601 227 | 1.3.6.1.4.1.14519.5.2.1.201429277212342505564727134166942685942 228 | 1.3.6.1.4.1.14519.5.2.1.232657652330222096161712085710696008345 229 | 1.3.6.1.4.1.14519.5.2.1.232111980308888243130020179540191986379 230 | 1.3.6.1.4.1.14519.5.2.1.171450547646265436894896475413907994571 231 | 1.3.6.1.4.1.14519.5.2.1.260208429721371945913593488747008392402 232 | 1.3.6.1.4.1.14519.5.2.1.96074972317210856257940726753393079402 233 | 1.3.6.1.4.1.14519.5.2.1.3281197893335564763282849900767310827 234 | 1.3.6.1.4.1.14519.5.2.1.33496187979283913017539607570678513263 235 | 1.3.6.1.4.1.14519.5.2.1.67049266716817140881910601189350295111 236 | 1.3.6.1.4.1.14519.5.2.1.53214686551225675208337966913982664260 237 | 1.3.6.1.4.1.14519.5.2.1.23925966695469658554360021791190910677 238 | 1.3.6.1.4.1.14519.5.2.1.249041531104210908961143956340899475303 239 | 1.3.6.1.4.1.14519.5.2.1.233826382362237990205162021534559156502 240 | 1.3.6.1.4.1.14519.5.2.1.144422009854692416651143605689152425184 241 | 1.3.6.1.4.1.14519.5.2.1.106786035069646005027604098072287812951 242 | 1.3.6.1.4.1.14519.5.2.1.236752444493873818319999694516257268975 243 | 1.3.6.1.4.1.14519.5.2.1.238358624802117233802379047360345454460 244 | 1.3.6.1.4.1.14519.5.2.1.300748886239012525308783476240258125999 245 | 1.3.6.1.4.1.14519.5.2.1.271262780142356626547695550879331444566 246 | 1.3.6.1.4.1.14519.5.2.1.83992889315124485147505742018229332342 247 | 1.3.6.1.4.1.14519.5.2.1.288810963554028825092258346944809083760 248 | 1.3.6.1.4.1.14519.5.2.1.173896487730572575760488822814514876976 249 | -------------------------------------------------------------------------------- /data/preprocess_requirements.txt: 
-------------------------------------------------------------------------------- 1 | SimpleITK 2 | pydicom 3 | tqdm 4 | natsort -------------------------------------------------------------------------------- /figs/DiceScore_comparision.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReubenDo/InExtremIS/1512ddf9b8c11c4d9f0ebd465d904ef3d539d350/figs/DiceScore_comparision.png -------------------------------------------------------------------------------- /figs/supervision_comparision.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReubenDo/InExtremIS/1512ddf9b8c11c4d9f0ebd465d904ef3d539d350/figs/supervision_comparision.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import os 6 | from tqdm import tqdm 7 | 8 | import pandas as pd 9 | 10 | import torch 11 | from torch import nn 12 | 13 | from monai.data import DataLoader, Dataset, NiftiSaver 14 | from monai.transforms import ( 15 | Compose, 16 | LoadNiftid, 17 | AddChanneld, 18 | NormalizeIntensityd, 19 | Orientationd, 20 | ToTensord, 21 | ) 22 | from monai.utils import set_determinism 23 | from monai.inferers import sliding_window_inference 24 | 25 | from network.unet2d5 import UNet2D5 26 | 27 | # Define training and patches sampling parameters 28 | SPATIAL_SHAPE = (224,224,48) 29 | 30 | NB_CLASSES = 2 31 | 32 | # Number of worker 33 | workers = 20 34 | 35 | # Training parameters 36 | val_eval_criterion_alpha = 0.95 37 | train_loss_MA_alpha = 0.95 38 | nb_patience = 10 39 | patience_lr = 5 40 | weight_decay = 1e-5 41 | 42 | PHASES = ['training', 'validation', 'inference'] 43 | 44 | def infinite_iterable(i): 45 | while True: 46 | yield from i 47 | 48 | def 
inference(paths_dict, model, transform_inference, device, opt): 49 | 50 | # Define transforms for data normalization and augmentation 51 | dataloaders = dict() 52 | subjects_dataset = dict() 53 | 54 | checkpoint_path = os.path.join(opt.model_dir,'models', './CP_{}.pth') 55 | checkpoint_path = checkpoint_path.format(opt.epoch_inf) 56 | assert os.path.isfile(checkpoint_path), 'no checkpoint found' 57 | print(checkpoint_path) 58 | model.load_state_dict(torch.load(checkpoint_path)) 59 | 60 | model = model.to(device) 61 | 62 | for phase in ['inference']: 63 | subjects_dataset[phase] = Dataset(paths_dict, transform=transform_inference) 64 | dataloaders[phase] = DataLoader(subjects_dataset[phase], batch_size=1, shuffle=False) 65 | 66 | 67 | model.eval() # Set model to evaluate mode 68 | 69 | fold_name = 'output_pred' 70 | # Iterate over data 71 | with torch.no_grad(): 72 | saver = NiftiSaver(output_dir=os.path.join(opt.model_dir,fold_name)) 73 | for batch in tqdm(dataloaders['inference']): 74 | inputs = batch['img'].to(device) 75 | 76 | pred = sliding_window_inference(inputs, opt.spatial_shape, 1, model, mode='gaussian') 77 | 78 | pred = pred.argmax(1, keepdim=True).detach() 79 | saver.save_batch(pred, batch["img_meta_dict"]) 80 | 81 | 82 | def main(): 83 | opt = parsing_data() 84 | 85 | set_determinism(seed=2) 86 | 87 | if torch.cuda.is_available(): 88 | print('[INFO] GPU available.') 89 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 90 | else: 91 | raise Exception( 92 | "[INFO] No GPU found.") 93 | 94 | 95 | print("[INFO] Reading data") 96 | # PHASES 97 | split_path = os.path.join(opt.dataset_split) 98 | df_split = pd.read_csv(split_path,header =None) 99 | list_file = dict() 100 | for phase in PHASES: # list of patient name associated to each phase 101 | list_file[phase] = df_split[df_split[1].isin([phase])][0].tolist() 102 | 103 | # CREATING DICT FOR DATASET 104 | mod_ext = "_T2.nii.gz" 105 | paths_dict = {split:[] for split in PHASES} 106 | 
for split in PHASES: 107 | for subject in list_file[split]: 108 | subject_data = dict() 109 | if os.path.exists(os.path.join(opt.path_data,subject+mod_ext)): 110 | subject_data["img"] = os.path.join(opt.path_data,subject+mod_ext) 111 | paths_dict[split].append(subject_data) 112 | print(f"Nb patients in {split} data: {len(paths_dict[split])}") 113 | 114 | # Logging hyperparameters 115 | print("[INFO] Hyperparameters") 116 | print('Spatial shape: {}'.format(opt.spatial_shape)) 117 | print(f"Inference on the {opt.phase} set") 118 | 119 | # PREPROCESSING 120 | all_keys = ["img"] 121 | test_transforms = Compose( 122 | ( 123 | LoadNiftid(keys=all_keys), 124 | AddChanneld(keys=all_keys), 125 | Orientationd(keys=all_keys, axcodes="RAS"), 126 | NormalizeIntensityd(keys=all_keys), 127 | ToTensord(keys=all_keys) 128 | ) 129 | ) 130 | 131 | # MODEL 132 | norm_op_kwargs = {"eps": 1e-5, "affine": True} 133 | net_nonlin = nn.LeakyReLU 134 | net_nonlin_kwargs = {"negative_slope": 1e-2, "inplace": True} 135 | 136 | model= UNet2D5(input_channels=1, 137 | base_num_features=16, 138 | num_classes=NB_CLASSES, 139 | num_pool=4, 140 | conv_op=nn.Conv3d, 141 | norm_op=nn.InstanceNorm3d, 142 | norm_op_kwargs=norm_op_kwargs, 143 | nonlin=net_nonlin, 144 | nonlin_kwargs=net_nonlin_kwargs).to(device) 145 | 146 | print("[INFO] Inference") 147 | inference(paths_dict[opt.phase], model, test_transforms, device, opt) 148 | 149 | 150 | def parsing_data(): 151 | parser = argparse.ArgumentParser( 152 | description='Performing inference') 153 | 154 | 155 | parser.add_argument('--model_dir', 156 | type=str) 157 | 158 | parser.add_argument("--dataset_split", 159 | type=str, 160 | default="splits/split_inextremis_budget1.csv") 161 | 162 | parser.add_argument("--path_data", 163 | type=str, 164 | default="data/VS_MICCAI21/T2/") 165 | 166 | parser.add_argument('--phase', 167 | type=str, 168 | default='inference') 169 | 170 | parser.add_argument('--spatial_shape', 171 | type=int, 172 | nargs="+", 173 | 
default=(224,224,48)) 174 | 175 | parser.add_argument('--epoch_inf', 176 | type=str, 177 | default='best') 178 | 179 | opt = parser.parse_args() 180 | 181 | return opt 182 | 183 | if __name__ == '__main__': 184 | main() 185 | 186 | 187 | 188 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | The pre-trained models can be downloaded [here](https://zenodo.org/record/5081163/files/models.zip?download=1). 2 | -------------------------------------------------------------------------------- /network/unet2d5.py: -------------------------------------------------------------------------------- 1 | # CODE ADAPTED FROM: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/network_architecture/generic_UNet.py 2 | 3 | from torch import nn 4 | import torch 5 | import numpy as np 6 | import torch.nn.functional 7 | 8 | 9 | class InitWeights_He(object): 10 | def __init__(self, neg_slope=1e-2): 11 | self.neg_slope = neg_slope 12 | 13 | def __call__(self, module): 14 | if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d): 15 | module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope) 16 | if module.bias is not None: 17 | module.bias = nn.init.constant_(module.bias, 0) 18 | 19 | 20 | class ConvNormNonlinBlock(nn.Module): 21 | def __init__( 22 | self, 23 | input_channels, 24 | output_channels, 25 | conv_op=nn.Conv3d, 26 | conv_kwargs=None, 27 | norm_op=nn.InstanceNorm3d, 28 | norm_op_kwargs=None, 29 | nonlin=nn.LeakyReLU, 30 | nonlin_kwargs=None): 31 | 32 | """ 33 | Block: Conv->Norm->Activation->Conv->Norm->Activation 34 | """ 35 | 36 | super(ConvNormNonlinBlock, self).__init__() 37 | 38 | self.nonlin_kwargs = nonlin_kwargs 39 | self.nonlin = nonlin 40 | self.conv_op = conv_op 41 | self.norm_op = norm_op 42 | self.norm_op_kwargs = 
norm_op_kwargs 43 | self.conv_kwargs = conv_kwargs 44 | self.output_channels = output_channels 45 | 46 | self.first_conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs) 47 | self.first_norm = self.norm_op(output_channels, **self.norm_op_kwargs) 48 | self.first_acti = self.nonlin(**self.nonlin_kwargs) 49 | 50 | self.second_conv = self.conv_op(output_channels, output_channels, **self.conv_kwargs) 51 | self.second_norm = self.norm_op(output_channels, **self.norm_op_kwargs) 52 | self.second_acti = self.nonlin(**self.nonlin_kwargs) 53 | 54 | self.block = nn.Sequential( 55 | self.first_conv, 56 | self.first_norm, 57 | self.first_acti, 58 | self.second_conv, 59 | self.second_norm, 60 | self.second_acti 61 | ) 62 | 63 | 64 | def forward(self, x): 65 | return self.block(x) 66 | 67 | 68 | 69 | class Upsample(nn.Module): 70 | def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False): 71 | super(Upsample, self).__init__() 72 | self.align_corners = align_corners 73 | self.mode = mode 74 | self.scale_factor = scale_factor 75 | self.size = size 76 | 77 | def forward(self, x): 78 | return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners) 79 | 80 | 81 | class UNet2D5(nn.Module): 82 | 83 | def __init__( 84 | self, 85 | input_channels, 86 | base_num_features, 87 | num_classes, 88 | num_pool, 89 | conv_op=nn.Conv3d, 90 | conv_kernel_sizes=None, 91 | norm_op=nn.InstanceNorm3d, 92 | norm_op_kwargs=None, 93 | nonlin=nn.LeakyReLU, 94 | nonlin_kwargs=None, 95 | weightInitializer=InitWeights_He(1e-2)): 96 | """ 97 | 2.5D CNN combining 2D and 3D convolutions to dealwith the low through-plane resolution. 98 | The first two stages have 2D convolutions while the others have 3D convolutions. 99 | 100 | Architecture inspired by: 101 | Wang,et al: Automatic segmentation of vestibular schwannoma from t2-weighted mri 102 | by deep spatial attention with hardness-weighted loss. 
MICCAI 2019. 103 | """ 104 | super(UNet2D5, self).__init__() 105 | 106 | 107 | if nonlin_kwargs is None: 108 | nonlin_kwargs = {'negative_slope':1e-2, 'inplace':True} 109 | 110 | if norm_op_kwargs is None: 111 | norm_op_kwargs = {'eps':1e-5, 'affine':True, 'momentum':0.1} 112 | 113 | self.conv_kwargs = {'stride':1, 'dilation':1, 'bias':True} 114 | 115 | self.nonlin = nonlin 116 | self.nonlin_kwargs = nonlin_kwargs 117 | self.norm_op_kwargs = norm_op_kwargs 118 | self.weightInitializer = weightInitializer 119 | self.conv_op = conv_op 120 | self.norm_op = norm_op 121 | self.num_classes = num_classes 122 | 123 | upsample_mode = 'trilinear' 124 | pool_op = nn.MaxPool3d 125 | pool_op_kernel_sizes = [(2, 2, 2)] * num_pool 126 | if conv_kernel_sizes is None: 127 | conv_kernel_sizes = [(3, 3, 1)] * 2 + [(3,3,3)]*(num_pool - 1) 128 | 129 | 130 | self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64) 131 | self.pool_op_kernel_sizes = pool_op_kernel_sizes 132 | self.conv_kernel_sizes = conv_kernel_sizes 133 | 134 | self.conv_pad_sizes = [] 135 | for krnl in self.conv_kernel_sizes: 136 | self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl]) 137 | 138 | 139 | 140 | self.conv_blocks_context = [] 141 | self.conv_blocks_localization = [] 142 | self.td = [] 143 | self.tu = [] 144 | self.seg_outputs = [] 145 | 146 | 147 | input_features = input_channels 148 | output_features = base_num_features 149 | 150 | 151 | for d in range(num_pool): 152 | self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d] 153 | self.conv_kwargs['padding'] = self.conv_pad_sizes[d] 154 | # add convolutions 155 | 156 | self.conv_blocks_context.append(ConvNormNonlinBlock(input_features, output_features, 157 | self.conv_op, self.conv_kwargs, self.norm_op, 158 | self.norm_op_kwargs, self.nonlin, self.nonlin_kwargs)) 159 | 160 | self.td.append(pool_op(pool_op_kernel_sizes[d])) 161 | input_features = output_features 162 | output_features = 2* output_features # Number 
of kernel increases by a factor 2 after each pooling 163 | 164 | 165 | final_num_features = self.conv_blocks_context[-1].output_channels 166 | self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool] 167 | self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool] 168 | self.conv_blocks_context.append(ConvNormNonlinBlock(input_features, final_num_features, 169 | self.conv_op, self.conv_kwargs, self.norm_op, 170 | self.norm_op_kwargs, self.nonlin, self.nonlin_kwargs)) 171 | 172 | 173 | # now lets build the localization pathway 174 | for u in range(num_pool): 175 | nfeatures_from_skip = self.conv_blocks_context[-(2 + u)].output_channels # self.conv_blocks_context[-1] is bottleneck, so start with -2 176 | n_features_after_tu_and_concat = nfeatures_from_skip * 2 177 | 178 | # the first conv reduces the number of features to match those of skip 179 | # the following convs work on that number of features 180 | # if not convolutional upsampling then the final conv reduces the num of features again 181 | if u != num_pool - 1: 182 | final_num_features = self.conv_blocks_context[-(3 + u)].output_channels 183 | else: 184 | final_num_features = nfeatures_from_skip 185 | 186 | self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u+1)], mode=upsample_mode)) 187 | 188 | 189 | self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[-(u+1)] 190 | self.conv_kwargs['padding'] = self.conv_pad_sizes[-(u+1)] 191 | self.conv_blocks_localization.append(ConvNormNonlinBlock(n_features_after_tu_and_concat, final_num_features, 192 | self.conv_op, self.conv_kwargs, self.norm_op, 193 | self.norm_op_kwargs, self.nonlin, self.nonlin_kwargs)) 194 | 195 | 196 | self.final_conv = conv_op(self.conv_blocks_localization[-1].output_channels, num_classes, 1, 1, 0, 1, 1, False) 197 | 198 | 199 | 200 | # register all modules properly 201 | self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization) 202 | self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context) 
203 | self.td = nn.ModuleList(self.td) 204 | self.tu = nn.ModuleList(self.tu) 205 | 206 | if self.weightInitializer is not None: 207 | self.apply(self.weightInitializer) 208 | 209 | def forward(self, x): 210 | skips = [] 211 | seg_outputs = [] 212 | for d in range(len(self.conv_blocks_context) - 1): 213 | x = self.conv_blocks_context[d](x) 214 | skips.append(x) 215 | x = self.td[d](x) 216 | 217 | x = self.conv_blocks_context[-1](x) 218 | 219 | for u in range(len(self.tu)): 220 | x = self.tu[u](x) 221 | x = torch.cat((x, skips[-(u + 1)]), dim=1) 222 | x = self.conv_blocks_localization[u](x) 223 | 224 | output = self.final_conv(x) 225 | return output 226 | 227 | 228 | 229 | 230 | 231 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | monai==0.4.0 2 | nibabel==3.1.1 3 | torch==1.7.1 4 | dijkstra3d==1.8.0 5 | pandas 6 | medpy 7 | tqdm 8 | natsort -------------------------------------------------------------------------------- /splits/split_fully_budget1.csv: -------------------------------------------------------------------------------- 1 | vs_gk_13,training 2 | vs_gk_24,training 3 | vs_gk_170,training 4 | vs_gk_26,training 5 | vs_gk_36,training 6 | vs_gk_14,training 7 | vs_gk_66,training 8 | vs_gk_142,training 9 | vs_gk_114,training 10 | vs_gk_50,training 11 | vs_gk_125,training 12 | vs_gk_4,training 13 | vs_gk_7,training 14 | vs_gk_200,validation 15 | vs_gk_182,validation 16 | vs_gk_202,testing 17 | vs_gk_203,testing 18 | vs_gk_204,testing 19 | vs_gk_205,testing 20 | vs_gk_206,testing 21 | vs_gk_207,testing 22 | vs_gk_209,testing 23 | vs_gk_210,testing 24 | vs_gk_211,testing 25 | vs_gk_212,testing 26 | vs_gk_213,testing 27 | vs_gk_214,testing 28 | vs_gk_215,testing 29 | vs_gk_216,testing 30 | vs_gk_217,testing 31 | vs_gk_218,testing 32 | vs_gk_220,testing 33 | vs_gk_221,testing 34 | vs_gk_222,testing 35 | vs_gk_223,testing 36 | 
vs_gk_224,testing 37 | vs_gk_225,testing 38 | vs_gk_226,testing 39 | vs_gk_228,testing 40 | vs_gk_229,testing 41 | vs_gk_230,testing 42 | vs_gk_231,testing 43 | vs_gk_232,testing 44 | vs_gk_233,testing 45 | vs_gk_234,testing 46 | vs_gk_235,testing 47 | vs_gk_236,testing 48 | vs_gk_237,testing 49 | vs_gk_238,testing 50 | vs_gk_239,testing 51 | vs_gk_240,testing 52 | vs_gk_241,testing 53 | vs_gk_242,testing 54 | vs_gk_243,testing 55 | vs_gk_244,testing 56 | vs_gk_245,testing 57 | vs_gk_246,testing 58 | vs_gk_247,testing 59 | vs_gk_248,testing 60 | vs_gk_249,testing 61 | vs_gk_250,testing 62 | -------------------------------------------------------------------------------- /splits/split_fully_budget13.csv: -------------------------------------------------------------------------------- 1 | vs_gk_1,training 2 | vs_gk_2,training 3 | vs_gk_3,training 4 | vs_gk_4,training 5 | vs_gk_5,training 6 | vs_gk_6,training 7 | vs_gk_7,training 8 | vs_gk_8,training 9 | vs_gk_9,training 10 | vs_gk_10,training 11 | vs_gk_11,training 12 | vs_gk_12,training 13 | vs_gk_13,training 14 | vs_gk_14,training 15 | vs_gk_15,training 16 | vs_gk_16,training 17 | vs_gk_17,training 18 | vs_gk_18,training 19 | vs_gk_19,training 20 | vs_gk_20,training 21 | vs_gk_21,training 22 | vs_gk_22,training 23 | vs_gk_23,training 24 | vs_gk_24,training 25 | vs_gk_25,training 26 | vs_gk_26,training 27 | vs_gk_27,training 28 | vs_gk_28,training 29 | vs_gk_29,training 30 | vs_gk_30,training 31 | vs_gk_31,training 32 | vs_gk_32,training 33 | vs_gk_33,training 34 | vs_gk_34,training 35 | vs_gk_35,training 36 | vs_gk_36,training 37 | vs_gk_37,training 38 | vs_gk_38,training 39 | vs_gk_40,training 40 | vs_gk_41,training 41 | vs_gk_42,training 42 | vs_gk_43,training 43 | vs_gk_44,training 44 | vs_gk_45,training 45 | vs_gk_46,training 46 | vs_gk_47,training 47 | vs_gk_48,training 48 | vs_gk_49,training 49 | vs_gk_50,training 50 | vs_gk_51,training 51 | vs_gk_52,training 52 | vs_gk_53,training 53 | vs_gk_54,training 54 | 
vs_gk_55,training 55 | vs_gk_56,training 56 | vs_gk_57,training 57 | vs_gk_58,training 58 | vs_gk_59,training 59 | vs_gk_60,training 60 | vs_gk_61,training 61 | vs_gk_62,training 62 | vs_gk_63,training 63 | vs_gk_64,training 64 | vs_gk_65,training 65 | vs_gk_66,training 66 | vs_gk_67,training 67 | vs_gk_68,training 68 | vs_gk_69,training 69 | vs_gk_70,training 70 | vs_gk_71,training 71 | vs_gk_72,training 72 | vs_gk_73,training 73 | vs_gk_74,training 74 | vs_gk_75,training 75 | vs_gk_76,training 76 | vs_gk_77,training 77 | vs_gk_78,training 78 | vs_gk_79,training 79 | vs_gk_80,training 80 | vs_gk_81,training 81 | vs_gk_82,training 82 | vs_gk_83,training 83 | vs_gk_84,training 84 | vs_gk_85,training 85 | vs_gk_86,training 86 | vs_gk_87,training 87 | vs_gk_88,training 88 | vs_gk_89,training 89 | vs_gk_90,training 90 | vs_gk_91,training 91 | vs_gk_92,training 92 | vs_gk_93,training 93 | vs_gk_94,training 94 | vs_gk_95,training 95 | vs_gk_96,training 96 | vs_gk_98,training 97 | vs_gk_99,training 98 | vs_gk_100,training 99 | vs_gk_101,training 100 | vs_gk_102,training 101 | vs_gk_103,training 102 | vs_gk_104,training 103 | vs_gk_105,training 104 | vs_gk_106,training 105 | vs_gk_107,training 106 | vs_gk_108,training 107 | vs_gk_109,training 108 | vs_gk_110,training 109 | vs_gk_111,training 110 | vs_gk_112,training 111 | vs_gk_113,training 112 | vs_gk_114,training 113 | vs_gk_115,training 114 | vs_gk_116,training 115 | vs_gk_117,training 116 | vs_gk_118,training 117 | vs_gk_119,training 118 | vs_gk_120,training 119 | vs_gk_121,training 120 | vs_gk_122,training 121 | vs_gk_123,training 122 | vs_gk_124,training 123 | vs_gk_125,training 124 | vs_gk_126,training 125 | vs_gk_127,training 126 | vs_gk_128,training 127 | vs_gk_129,training 128 | vs_gk_131,training 129 | vs_gk_132,training 130 | vs_gk_133,training 131 | vs_gk_134,training 132 | vs_gk_135,training 133 | vs_gk_136,training 134 | vs_gk_137,training 135 | vs_gk_138,training 136 | vs_gk_139,training 137 | 
vs_gk_140,training 138 | vs_gk_141,training 139 | vs_gk_142,training 140 | vs_gk_143,training 141 | vs_gk_144,training 142 | vs_gk_145,training 143 | vs_gk_146,training 144 | vs_gk_147,training 145 | vs_gk_148,training 146 | vs_gk_149,training 147 | vs_gk_150,training 148 | vs_gk_151,training 149 | vs_gk_152,training 150 | vs_gk_153,training 151 | vs_gk_154,training 152 | vs_gk_155,training 153 | vs_gk_156,training 154 | vs_gk_157,training 155 | vs_gk_158,training 156 | vs_gk_159,training 157 | vs_gk_161,training 158 | vs_gk_162,training 159 | vs_gk_163,training 160 | vs_gk_164,training 161 | vs_gk_165,training 162 | vs_gk_166,training 163 | vs_gk_167,training 164 | vs_gk_169,training 165 | vs_gk_170,training 166 | vs_gk_171,training 167 | vs_gk_172,training 168 | vs_gk_173,training 169 | vs_gk_174,training 170 | vs_gk_175,training 171 | vs_gk_176,training 172 | vs_gk_177,training 173 | vs_gk_178,training 174 | vs_gk_179,training 175 | vs_gk_180,training 176 | vs_gk_181,training 177 | vs_gk_182,validation 178 | vs_gk_183,validation 179 | vs_gk_184,validation 180 | vs_gk_185,validation 181 | vs_gk_186,validation 182 | vs_gk_187,validation 183 | vs_gk_188,validation 184 | vs_gk_189,validation 185 | vs_gk_190,validation 186 | vs_gk_191,validation 187 | vs_gk_192,validation 188 | vs_gk_193,validation 189 | vs_gk_194,validation 190 | vs_gk_195,validation 191 | vs_gk_196,validation 192 | vs_gk_197,validation 193 | vs_gk_198,validation 194 | vs_gk_199,validation 195 | vs_gk_200,validation 196 | vs_gk_201,validation 197 | vs_gk_202,inference 198 | vs_gk_203,inference 199 | vs_gk_204,inference 200 | vs_gk_205,inference 201 | vs_gk_206,inference 202 | vs_gk_207,inference 203 | vs_gk_209,inference 204 | vs_gk_210,inference 205 | vs_gk_211,inference 206 | vs_gk_212,inference 207 | vs_gk_213,inference 208 | vs_gk_214,inference 209 | vs_gk_215,inference 210 | vs_gk_216,inference 211 | vs_gk_217,inference 212 | vs_gk_218,inference 213 | vs_gk_220,inference 214 | 
vs_gk_221,inference 215 | vs_gk_222,inference 216 | vs_gk_223,inference 217 | vs_gk_224,inference 218 | vs_gk_225,inference 219 | vs_gk_226,inference 220 | vs_gk_228,inference 221 | vs_gk_229,inference 222 | vs_gk_230,inference 223 | vs_gk_231,inference 224 | vs_gk_232,inference 225 | vs_gk_233,inference 226 | vs_gk_234,inference 227 | vs_gk_235,inference 228 | vs_gk_236,inference 229 | vs_gk_237,inference 230 | vs_gk_238,inference 231 | vs_gk_239,inference 232 | vs_gk_240,inference 233 | vs_gk_241,inference 234 | vs_gk_242,inference 235 | vs_gk_243,inference 236 | vs_gk_244,inference 237 | vs_gk_245,inference 238 | vs_gk_246,inference 239 | vs_gk_247,inference 240 | vs_gk_248,inference 241 | vs_gk_249,inference 242 | vs_gk_250,inference 243 | -------------------------------------------------------------------------------- /splits/split_inextremis_budget1.csv: -------------------------------------------------------------------------------- 1 | vs_gk_1,training 2 | vs_gk_2,training 3 | vs_gk_3,training 4 | vs_gk_4,training 5 | vs_gk_5,training 6 | vs_gk_6,training 7 | vs_gk_7,training 8 | vs_gk_8,training 9 | vs_gk_9,training 10 | vs_gk_10,training 11 | vs_gk_11,training 12 | vs_gk_12,training 13 | vs_gk_13,training 14 | vs_gk_14,training 15 | vs_gk_15,training 16 | vs_gk_16,training 17 | vs_gk_17,training 18 | vs_gk_18,training 19 | vs_gk_19,training 20 | vs_gk_20,training 21 | vs_gk_21,training 22 | vs_gk_22,training 23 | vs_gk_23,training 24 | vs_gk_24,training 25 | vs_gk_25,training 26 | vs_gk_26,training 27 | vs_gk_27,training 28 | vs_gk_28,training 29 | vs_gk_29,training 30 | vs_gk_30,training 31 | vs_gk_31,training 32 | vs_gk_32,training 33 | vs_gk_33,training 34 | vs_gk_34,training 35 | vs_gk_35,training 36 | vs_gk_36,training 37 | vs_gk_37,training 38 | vs_gk_38,training 39 | vs_gk_40,training 40 | vs_gk_41,training 41 | vs_gk_42,training 42 | vs_gk_43,training 43 | vs_gk_44,training 44 | vs_gk_45,training 45 | vs_gk_46,training 46 | vs_gk_47,training 47 | 
vs_gk_48,training 48 | vs_gk_49,training 49 | vs_gk_50,training 50 | vs_gk_51,training 51 | vs_gk_52,training 52 | vs_gk_53,training 53 | vs_gk_54,training 54 | vs_gk_55,training 55 | vs_gk_56,training 56 | vs_gk_57,training 57 | vs_gk_58,training 58 | vs_gk_59,training 59 | vs_gk_60,training 60 | vs_gk_61,training 61 | vs_gk_62,training 62 | vs_gk_63,training 63 | vs_gk_64,training 64 | vs_gk_65,training 65 | vs_gk_66,training 66 | vs_gk_67,training 67 | vs_gk_68,training 68 | vs_gk_69,training 69 | vs_gk_70,training 70 | vs_gk_71,training 71 | vs_gk_72,training 72 | vs_gk_73,training 73 | vs_gk_74,training 74 | vs_gk_75,training 75 | vs_gk_76,training 76 | vs_gk_77,training 77 | vs_gk_78,training 78 | vs_gk_79,training 79 | vs_gk_80,training 80 | vs_gk_81,training 81 | vs_gk_82,training 82 | vs_gk_83,training 83 | vs_gk_84,training 84 | vs_gk_85,training 85 | vs_gk_86,training 86 | vs_gk_87,training 87 | vs_gk_88,training 88 | vs_gk_89,training 89 | vs_gk_90,training 90 | vs_gk_91,training 91 | vs_gk_92,training 92 | vs_gk_93,training 93 | vs_gk_94,training 94 | vs_gk_95,training 95 | vs_gk_96,training 96 | vs_gk_98,training 97 | vs_gk_99,training 98 | vs_gk_100,training 99 | vs_gk_101,training 100 | vs_gk_102,training 101 | vs_gk_103,training 102 | vs_gk_104,training 103 | vs_gk_105,training 104 | vs_gk_106,training 105 | vs_gk_107,training 106 | vs_gk_108,training 107 | vs_gk_109,training 108 | vs_gk_110,training 109 | vs_gk_111,training 110 | vs_gk_112,training 111 | vs_gk_113,training 112 | vs_gk_114,training 113 | vs_gk_115,training 114 | vs_gk_116,training 115 | vs_gk_117,training 116 | vs_gk_118,training 117 | vs_gk_119,training 118 | vs_gk_120,training 119 | vs_gk_121,training 120 | vs_gk_122,training 121 | vs_gk_123,training 122 | vs_gk_124,training 123 | vs_gk_125,training 124 | vs_gk_126,training 125 | vs_gk_127,training 126 | vs_gk_128,training 127 | vs_gk_129,training 128 | vs_gk_131,training 129 | vs_gk_132,training 130 | vs_gk_133,training 131 | 
vs_gk_134,training 132 | vs_gk_135,training 133 | vs_gk_136,training 134 | vs_gk_137,training 135 | vs_gk_138,training 136 | vs_gk_139,training 137 | vs_gk_140,training 138 | vs_gk_141,training 139 | vs_gk_142,training 140 | vs_gk_143,training 141 | vs_gk_144,training 142 | vs_gk_145,training 143 | vs_gk_146,training 144 | vs_gk_147,training 145 | vs_gk_148,training 146 | vs_gk_149,training 147 | vs_gk_150,training 148 | vs_gk_151,training 149 | vs_gk_152,training 150 | vs_gk_153,training 151 | vs_gk_154,training 152 | vs_gk_155,training 153 | vs_gk_156,training 154 | vs_gk_157,training 155 | vs_gk_158,training 156 | vs_gk_159,training 157 | vs_gk_161,training 158 | vs_gk_162,training 159 | vs_gk_163,training 160 | vs_gk_164,training 161 | vs_gk_165,training 162 | vs_gk_166,training 163 | vs_gk_167,training 164 | vs_gk_169,training 165 | vs_gk_170,training 166 | vs_gk_171,training 167 | vs_gk_172,training 168 | vs_gk_173,training 169 | vs_gk_174,training 170 | vs_gk_175,training 171 | vs_gk_176,training 172 | vs_gk_177,training 173 | vs_gk_178,training 174 | vs_gk_179,training 175 | vs_gk_180,training 176 | vs_gk_181,training 177 | vs_gk_182,validation 178 | vs_gk_183,validation 179 | vs_gk_184,validation 180 | vs_gk_185,validation 181 | vs_gk_186,validation 182 | vs_gk_187,validation 183 | vs_gk_188,validation 184 | vs_gk_189,validation 185 | vs_gk_190,validation 186 | vs_gk_191,validation 187 | vs_gk_192,validation 188 | vs_gk_193,validation 189 | vs_gk_194,validation 190 | vs_gk_195,validation 191 | vs_gk_196,validation 192 | vs_gk_197,validation 193 | vs_gk_198,validation 194 | vs_gk_199,validation 195 | vs_gk_200,validation 196 | vs_gk_201,validation 197 | vs_gk_202,inference 198 | vs_gk_203,inference 199 | vs_gk_204,inference 200 | vs_gk_205,inference 201 | vs_gk_206,inference 202 | vs_gk_207,inference 203 | vs_gk_209,inference 204 | vs_gk_210,inference 205 | vs_gk_211,inference 206 | vs_gk_212,inference 207 | vs_gk_213,inference 208 | vs_gk_214,inference 209 
| vs_gk_215,inference 210 | vs_gk_216,inference 211 | vs_gk_217,inference 212 | vs_gk_218,inference 213 | vs_gk_220,inference 214 | vs_gk_221,inference 215 | vs_gk_222,inference 216 | vs_gk_223,inference 217 | vs_gk_224,inference 218 | vs_gk_225,inference 219 | vs_gk_226,inference 220 | vs_gk_228,inference 221 | vs_gk_229,inference 222 | vs_gk_230,inference 223 | vs_gk_231,inference 224 | vs_gk_232,inference 225 | vs_gk_233,inference 226 | vs_gk_234,inference 227 | vs_gk_235,inference 228 | vs_gk_236,inference 229 | vs_gk_237,inference 230 | vs_gk_238,inference 231 | vs_gk_239,inference 232 | vs_gk_240,inference 233 | vs_gk_241,inference 234 | vs_gk_242,inference 235 | vs_gk_243,inference 236 | vs_gk_244,inference 237 | vs_gk_245,inference 238 | vs_gk_246,inference 239 | vs_gk_247,inference 240 | vs_gk_248,inference 241 | vs_gk_249,inference 242 | vs_gk_250,inference 243 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | import argparse 5 | import time 6 | import os 7 | from tqdm import tqdm 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | import torch 13 | from torch import nn 14 | 15 | from monai.inferers import sliding_window_inference 16 | from monai.utils import set_determinism 17 | from monai.data import DataLoader, Dataset 18 | from monai.transforms import ( 19 | Compose, 20 | LoadNiftid, 21 | AddChanneld, 22 | SpatialPadd, 23 | NormalizeIntensityd, 24 | RandFlipd, 25 | RandSpatialCropd, 26 | Orientationd, 27 | ToTensord, 28 | ) 29 | 30 | from utilities.losses import DC, DC_CE_Focal, PartialLoss 31 | from utilities.utils import ( 32 | create_logger, 33 | poly_lr, 34 | infinite_iterable) 35 | from utilities.geodesics import generate_geodesics 36 | 37 | from ScribbleDA.scribbleDALoss import CRFLoss 38 | from network.unet2d5 import UNet2D5 39 | 40 | 41 | # Define 
training and patches sampling parameters 42 | NB_CLASSES = 2 43 | PHASES = ["training", "validation"] 44 | MAX_EPOCHS = 300 45 | 46 | # Training parameters 47 | weight_decay = 3e-5 48 | 49 | def train(paths_dict, model, transformation, criterion, device, save_path, logger, opt): 50 | 51 | since = time.time() 52 | 53 | # Define transforms for data normalization and augmentation 54 | subjects_train = Dataset( 55 | paths_dict["training"], 56 | transform=transformation["training"]) 57 | 58 | subjects_val = Dataset( 59 | paths_dict["validation"], 60 | transform=transformation["validation"]) 61 | 62 | # Dataloaders 63 | dataloaders = dict() 64 | dataloaders["training"] = infinite_iterable( 65 | DataLoader(subjects_train, batch_size=opt.batch_size, num_workers=2, shuffle=True) 66 | ) 67 | dataloaders["validation"] = infinite_iterable( 68 | DataLoader(subjects_val, batch_size=1, num_workers=2) 69 | ) 70 | 71 | nb_batches = { 72 | "training": 30, # One image patch per epoch for the full dataset 73 | "validation": len(paths_dict["validation"]) 74 | } 75 | 76 | # Training parameters are saved 77 | df_path = os.path.join(opt.model_dir,"log.csv") 78 | if os.path.isfile(df_path): # If the training already started 79 | df = pd.read_csv(df_path, index_col=False) 80 | epoch = df.iloc[-1]["epoch"] 81 | best_epoch = df.iloc[-1]["best_epoch"] 82 | best_val = df.iloc[-1]["best_val"] 83 | initial_lr = df.iloc[-1]["lr"] 84 | model.load_state_dict(torch.load(save_path.format("best"))) 85 | 86 | else: # If training from scratch 87 | columns=["epoch","best_epoch", "MA", "best_MA", "lr", "timeit"] 88 | df = pd.DataFrame(columns=columns) 89 | best_val = None 90 | best_epoch = 0 91 | epoch = 0 92 | initial_lr = opt.learning_rate 93 | 94 | 95 | # Optimisation policy mimicking nnUnet training policy 96 | optimizer = torch.optim.SGD(model.parameters(), initial_lr, 97 | weight_decay=weight_decay, momentum=0.99, nesterov=True) 98 | 99 | # CRF Loss initialisation 100 | crf_l = 
CRFLoss(alpha=opt.alpha, beta=opt.beta, is_da=False, use_norm=False) 101 | 102 | # Training loop 103 | continue_training = True 104 | while continue_training: 105 | epoch+=1 106 | logger.info("-" * 10) 107 | logger.info("Epoch {}/".format(epoch)) 108 | logger.info 109 | for param_group in optimizer.param_groups: 110 | logger.info("Current learning rate is: {}".format(param_group["lr"])) 111 | 112 | # Each epoch has a training and validation phase 113 | for phase in PHASES: 114 | if phase == "training": 115 | model.train() # Set model to training mode 116 | else: 117 | model.eval() # Set model to evaluate mode 118 | 119 | # Initializing the statistics 120 | running_loss = 0.0 121 | running_loss_reg = 0.0 122 | running_loss_seg = 0.0 123 | epoch_samples = 0 124 | running_time = 0.0 125 | 126 | # Iterate over data 127 | for _ in tqdm(range(nb_batches[phase])): 128 | batch = next(dataloaders[phase]) 129 | inputs = batch["img"].to(device) # T2 images 130 | if opt.mode == "extreme_points": 131 | extremes = batch["label"].to(device) # Extreme points 132 | img_gradients = batch["img_gradient"].to(device) # Pre-Computed Sobel map 133 | else: 134 | labels = batch["label"].to(device) 135 | 136 | # zero the parameter gradients 137 | optimizer.zero_grad() 138 | 139 | with torch.set_grad_enabled(phase == "training"): 140 | if phase=="training": # Random patch predictions 141 | outputs = model(inputs) 142 | else: # if validation, Inference on the full image 143 | outputs = sliding_window_inference( 144 | inputs=inputs, 145 | roi_size=opt.spatial_shape, 146 | sw_batch_size=1, 147 | predictor=model, 148 | mode="gaussian", 149 | ) 150 | 151 | if opt.mode == "extreme_points": # Generate geodesics 152 | init_time_geodesics = time.time() 153 | geodesics = [] 154 | nb_target = outputs.shape[0] 155 | for i in range(nb_target): 156 | geodesics_i = generate_geodesics( 157 | extreme=extremes[i,...], 158 | img_gradient=img_gradients[i,...], 159 | prob=outputs[i,...], 160 | 
with_prob=opt.with_prob, 161 | with_euclidean=opt.with_euclidean 162 | ) 163 | geodesics.append(geodesics_i.to(device)) 164 | labels = torch.cat(geodesics,0) 165 | time_geodesics = time.time() - init_time_geodesics 166 | else: 167 | time_geodesics = 0. 168 | 169 | # Segmentation loss 170 | loss_seg = criterion(outputs, labels, phase) 171 | 172 | # CRF regularisation (training only) 173 | if (opt.beta>0 or opt.alpha>0) and phase == "training" and opt.mode == "extreme_points": 174 | reg = opt.weight_crf/np.prod(opt.spatial_shape)*crf_l(inputs, outputs) 175 | loss = loss_seg + reg 176 | else: 177 | reg = 0.0 178 | loss = loss_seg 179 | 180 | if phase == "training": 181 | loss.backward() 182 | optimizer.step() 183 | 184 | # Iteration statistics 185 | epoch_samples += 1 186 | running_loss += loss.item() 187 | running_loss_reg += reg 188 | running_loss_seg += loss_seg 189 | running_time += time_geodesics 190 | 191 | # Epoch statistcs 192 | epoch_loss = running_loss / epoch_samples 193 | epoch_loss_reg = running_loss_reg / epoch_samples 194 | epoch_loss_seg = running_loss_seg / epoch_samples 195 | if phase == "training": 196 | epoch_time = running_time / epoch_samples 197 | 198 | logger.info("{} Loss Reg: {:.4f}".format( 199 | phase, epoch_loss_reg)) 200 | logger.info("{} Loss Seg: {:.4f}".format( 201 | phase, epoch_loss_seg)) 202 | if phase == "training": 203 | logger.info("{} Time Geodesics: {:.4f}".format( 204 | phase, epoch_time)) 205 | 206 | # Saving best model on the validation set 207 | if phase == "validation": 208 | if best_val is None: # first iteration 209 | best_val = epoch_loss 210 | torch.save(model.state_dict(), save_path.format("best")) 211 | 212 | if epoch_loss <= best_val: 213 | best_val = epoch_loss 214 | best_epoch = epoch 215 | torch.save(model.state_dict(), save_path.format("best")) 216 | 217 | df = df.append( 218 | {"epoch":epoch, 219 | "best_epoch":best_epoch, 220 | "best_val":best_val, 221 | "lr":param_group["lr"], 222 | "timeit":epoch_time}, 223 
| ignore_index=True) 224 | df.to_csv(df_path, index=False) 225 | 226 | optimizer.param_groups[0]["lr"] = poly_lr(epoch, MAX_EPOCHS, opt.learning_rate, 0.9) 227 | 228 | # Early stopping performed when full annotations are used (training set may be small) 229 | if opt.mode == "full_annotations" and epoch-best_epoch>70: 230 | torch.save(model.state_dict(), save_path.format("final")) 231 | continue_training=False 232 | 233 | if epoch == MAX_EPOCHS: 234 | torch.save(model.state_dict(), save_path.format("final")) 235 | continue_training=False 236 | 237 | time_elapsed = time.time() - since 238 | logger.info("[INFO] Training completed in {:.0f}m {:.0f}s".format( 239 | time_elapsed // 60, time_elapsed % 60)) 240 | logger.info(f"[INFO] Best validation epoch is {best_epoch}") 241 | 242 | 243 | def main(): 244 | set_determinism(seed=2) 245 | 246 | opt = parsing_data() 247 | 248 | # FOLDERS 249 | fold_dir = opt.model_dir 250 | fold_dir_model = os.path.join(fold_dir,"models") 251 | if not os.path.exists(fold_dir_model): 252 | os.makedirs(fold_dir_model) 253 | save_path = os.path.join(fold_dir_model,"./CP_{}.pth") 254 | 255 | if opt.path_labels is None: 256 | opt.path_labels = opt.path_data 257 | 258 | logger = create_logger(fold_dir) 259 | logger.info("[INFO] Hyperparameters") 260 | logger.info(f"Alpha: {opt.alpha}") 261 | logger.info(f"Beta: {opt.beta}") 262 | logger.info(f"Weight Reg: {opt.weight_crf}") 263 | logger.info(f"Batch size: {opt.batch_size}") 264 | logger.info(f"Spatial shape: {opt.spatial_shape}") 265 | logger.info(f"Initial lr: {opt.learning_rate}") 266 | logger.info(f"Postfix img gradients: {opt.img_gradient_postfix}") 267 | logger.info(f"Postfix labels: {opt.label_postfix}") 268 | logger.info(f"With euclidean: {opt.with_euclidean}") 269 | logger.info(f"With probs: {opt.with_prob}") 270 | 271 | # GPU CHECKING 272 | if torch.cuda.is_available(): 273 | logger.info("[INFO] GPU available.") 274 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
275 | else: 276 | raise logger.error( 277 | "[INFO] No GPU found") 278 | 279 | # SPLIT 280 | assert os.path.isfile(opt.dataset_split), logger.error("[ERROR] Invalid split") 281 | df_split = pd.read_csv(opt.dataset_split,header =None) 282 | list_file = dict() 283 | for split in PHASES: 284 | list_file[split] = df_split[df_split[1].isin([split])][0].tolist() 285 | 286 | 287 | # CREATING DICT FOR CACHEDATASET 288 | mod_ext = "_T2.nii.gz" 289 | grad_ext = f"_{opt.img_gradient_postfix}.nii.gz" 290 | extreme_ext = f"_{opt.label_postfix}.nii.gz" 291 | paths_dict = {split:[] for split in PHASES} 292 | 293 | for split in PHASES: 294 | for subject in list_file[split]: 295 | subject_data = dict() 296 | 297 | img_path = os.path.join(opt.path_data,subject+mod_ext) 298 | img_grad_path = os.path.join(opt.path_labels,subject+grad_ext) 299 | lab_path = os.path.join(opt.path_labels,subject+extreme_ext) 300 | 301 | if os.path.exists(img_path) and os.path.exists(lab_path): 302 | subject_data["img"] = img_path 303 | subject_data["label"] = lab_path 304 | 305 | if opt.mode == "extreme_points": 306 | if os.path.exists(img_grad_path): 307 | subject_data["img_gradient"] = img_grad_path 308 | paths_dict[split].append(subject_data) 309 | else: 310 | paths_dict[split].append(subject_data) 311 | 312 | logger.info(f"Nb patients in {split} data: {len(paths_dict[split])}") 313 | 314 | 315 | # PREPROCESSING 316 | transforms = dict() 317 | all_keys = ["img", "label"] 318 | if opt.mode == "extreme_points": 319 | all_keys.append("img_gradient") 320 | 321 | transforms_training = ( 322 | LoadNiftid(keys=all_keys), 323 | AddChanneld(keys=all_keys), 324 | Orientationd(keys=all_keys, axcodes="RAS"), 325 | NormalizeIntensityd(keys=["img"]), 326 | SpatialPadd(keys=all_keys, spatial_size=opt.spatial_shape), 327 | RandFlipd(keys=all_keys, prob=0.5, spatial_axis=0), 328 | RandSpatialCropd(keys=all_keys, roi_size=opt.spatial_shape, random_center=True, random_size=False), 329 | ToTensord(keys=all_keys), 330 | ) 
331 | transforms["training"] = Compose(transforms_training) 332 | 333 | transforms_validation = ( 334 | LoadNiftid(keys=all_keys), 335 | AddChanneld(keys=all_keys), 336 | Orientationd(keys=all_keys, axcodes="RAS"), 337 | NormalizeIntensityd(keys=["img"]), 338 | SpatialPadd(keys=all_keys, spatial_size=opt.spatial_shape), 339 | ToTensord(keys=all_keys) 340 | ) 341 | transforms["validation"] = Compose(transforms_validation) 342 | 343 | # MODEL 344 | logger.info("[INFO] Building model") 345 | norm_op_kwargs = {"eps": 1e-5, "affine": True} 346 | net_nonlin = nn.LeakyReLU 347 | net_nonlin_kwargs = {"negative_slope": 1e-2, "inplace": True} 348 | 349 | model= UNet2D5(input_channels=1, 350 | base_num_features=16, 351 | num_classes=NB_CLASSES, 352 | num_pool=4, 353 | conv_op=nn.Conv3d, 354 | norm_op=nn.InstanceNorm3d, 355 | norm_op_kwargs=norm_op_kwargs, 356 | nonlin=net_nonlin, 357 | nonlin_kwargs=net_nonlin_kwargs).to(device) 358 | 359 | 360 | logger.info("[INFO] Training") 361 | if opt.mode == "full_annotations": 362 | dice = DC(NB_CLASSES) 363 | criterion = lambda pred, grnd, phase: dice(pred, grnd) 364 | 365 | elif opt.mode == "extreme_points" or opt.mode == "geodesics": 366 | dice_ce_focal = DC_CE_Focal(NB_CLASSES) 367 | criterion = PartialLoss(dice_ce_focal) 368 | 369 | train(paths_dict, 370 | model, 371 | transforms, 372 | criterion, 373 | device, 374 | save_path, 375 | logger, 376 | opt) 377 | 378 | def parsing_data(): 379 | parser = argparse.ArgumentParser( 380 | description="Script to train the models using extreme points as supervision") 381 | 382 | parser.add_argument("--model_dir", 383 | type=str, 384 | help="Path to the model directory") 385 | 386 | parser.add_argument("--mode", 387 | type=str, 388 | help="Choice of the supervision mode", 389 | choices=["full_annotations", "extreme_points", "geodesics"], 390 | default="extreme_points") 391 | 392 | parser.add_argument("--weight_crf", 393 | type=float, 394 | default=0.1) 395 | 396 | 
parser.add_argument("--alpha", 397 | type=float, 398 | default=15) 399 | 400 | parser.add_argument("--beta", 401 | type=float, 402 | default=0.05) 403 | 404 | parser.add_argument("--batch_size", 405 | type=int, 406 | default=6, 407 | help="Size of the batch size (default: 6)") 408 | 409 | parser.add_argument("--dataset_split", 410 | type=str, 411 | default="splits/split_inextremis_budget1.csv", 412 | help="Path to split file") 413 | 414 | parser.add_argument("--path_data", 415 | type=str, 416 | default="data/VS_MICCAI21/T2/", 417 | help="Path to the T2 scans") 418 | 419 | parser.add_argument("--path_labels", 420 | type=str, 421 | default=None, 422 | help="Path to the extreme points") 423 | 424 | parser.add_argument("--learning_rate", 425 | type=float, 426 | default=1e-2, 427 | help="Initial learning rate") 428 | 429 | parser.add_argument("--label_postfix", 430 | type=str, 431 | default="", 432 | help="Postfix of the Labels points") 433 | 434 | parser.add_argument("--img_gradient_postfix", 435 | type=str, 436 | default="", 437 | help="Postfix of the gradient images") 438 | 439 | parser.add_argument("--spatial_shape", 440 | type=int, 441 | nargs="+", 442 | default=(224,224,48), 443 | help="Size of the window patch") 444 | 445 | parser.add_argument("--with_prob", 446 | action="store_true", 447 | help="Add Deep probabilities") 448 | 449 | parser.add_argument("--with_euclidean", 450 | action="store_true", 451 | help="Add Euclidean distance") 452 | 453 | opt = parser.parse_args() 454 | 455 | return opt 456 | 457 | if __name__ == "__main__": 458 | main() 459 | 460 | 461 | 462 | -------------------------------------------------------------------------------- /utilities/figure_fully.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | 7 | score_dice = [] 8 | path_model = 
'models/full_annotations/fully_{}/results_full.csv' 9 | score_hd = [] 10 | for i in [1,3,5,7,9,11,13]: 11 | score_dice.append(100*pd.read_csv(path_model.format(i))['dice'].mean()) 12 | score_hd.append(pd.read_csv(path_model.format(i))['hd95'].mean()) 13 | 14 | dice_InExtremIS = 100*pd.read_csv('models/extreme_points/manual_gradient_eucl_deep_crf/results_full.csv')['dice'].mean() 15 | hd_InExtremIS = pd.read_csv('models/extreme_points/manual_gradient_eucl_deep_crf/results_full.csv')['hd95'].mean() 16 | times = [2,6,10,14,18,22,26] 17 | 18 | scores = { 19 | 'Dice Score':{'unit':'%', 'scores':[score_dice, dice_InExtremIS]}, 20 | '95th-percentile Hausdorff Distance':{'unit':'mm', 'scores':[score_hd, hd_InExtremIS]} 21 | } 22 | 23 | for metric, score in scores.items(): 24 | unit = score['unit'] 25 | supervised_scores, InExtreMIS_score = score['scores'] 26 | fig = plt.figure() 27 | plt.title('Full supervision accuracy given an annotation time budget') 28 | plt.ylabel(f'{metric} ({unit})') 29 | plt.xlabel('Annotation time budget in hours') 30 | 31 | plt.plot(times, supervised_scores, '-ok', label='Fully Supervised') 32 | plt.scatter([2], [InExtreMIS_score], c='blue', label='Extreme points Supervision') 33 | plt.xticks(np.arange(0, 26+1, 2.0)) 34 | 35 | if metric == 'Dice Score': 36 | shift = 0.1 37 | sign = 1 38 | else: 39 | shift = 0.01 40 | sign = -1 41 | 42 | diff_same_budget = round(InExtreMIS_score-supervised_scores[0],1) 43 | diff_same_score = times[[n for n,i in enumerate(supervised_scores) if sign*i>sign*InExtreMIS_score][0]] - 2 44 | 45 | plt.annotate(s='', xy=(diff_same_score+2.1,InExtreMIS_score), xytext=(2.1,InExtreMIS_score), arrowprops=dict(arrowstyle='<->', linestyle="--",linewidth=2, color='purple')) 46 | plt.annotate(s='', xy=(2,supervised_scores[0]), xytext=(2,InExtreMIS_score), arrowprops=dict(arrowstyle='<->', linestyle="--",linewidth=2, color='purple')) 47 | 48 | plt.text(diff_same_score/2, InExtreMIS_score+5*shift, f'{diff_same_score} hours', 
ha='left', va='center',color='purple') 49 | plt.text(1.2, InExtreMIS_score-diff_same_budget/2, f'{abs(diff_same_budget)}{unit} {metric}', ha='left',rotation=90, va='center', color='purple') 50 | plt.legend() 51 | plt.grid() 52 | plt.show() 53 | fig.savefig(f"figs/{metric.replace(' ','')}_comparision.pdf",bbox_inches='tight') -------------------------------------------------------------------------------- /utilities/focal.py: -------------------------------------------------------------------------------- 1 | ## ADDAPTED from https://github.com/Project-MONAI/MONAI/blob/12f267c98eabdcd566dff11a3daf931201f04da4/monai/losses/focal_loss.py ### 2 | 3 | 4 | # Copyright 2020 - 2021 MONAI Consortium 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Optional, Union 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch.nn.modules.loss import _WeightedLoss 20 | 21 | from monai.utils import LossReduction 22 | 23 | class FocalLoss(_WeightedLoss): 24 | """ 25 | Reimplementation of the Focal Loss described in: 26 | 27 | - "Focal Loss for Dense Object Detection", T. 
Lin et al., ICCV 2017 28 | - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy", 29 | Zhu et al., Medical Physics 2018 30 | """ 31 | 32 | def __init__( 33 | self, 34 | gamma: float = 2.0, 35 | weight: Optional[torch.Tensor] = None, 36 | reduction: Union[LossReduction, str] = LossReduction.MEAN, 37 | ) -> None: 38 | """ 39 | Args: 40 | gamma: value of the exponent gamma in the definition of the Focal loss. 41 | weight: weights to apply to the voxels of each class. If None no weights are applied. 42 | This corresponds to the weights `\alpha` in [1]. 43 | reduction: {``"none"``, ``"mean"``, ``"sum"``} 44 | Specifies the reduction to apply to the output. Defaults to ``"mean"``. 45 | 46 | - ``"none"``: no reduction will be applied. 47 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output. 48 | - ``"sum"``: the output will be summed. 49 | 50 | Example: 51 | .. code-block:: python 52 | 53 | import torch 54 | from monai.losses import FocalLoss 55 | 56 | pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32) 57 | grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64) 58 | fl = FocalLoss() 59 | fl(pred, grnd) 60 | 61 | """ 62 | super(FocalLoss, self).__init__(weight=weight, reduction=LossReduction(reduction).value) 63 | self.gamma = gamma 64 | self.weight: Optional[torch.Tensor] = None 65 | 66 | def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 67 | """ 68 | Args: 69 | logits: the shape should be BCH[WD]. 70 | where C (greater than 1) is the number of classes. 71 | Softmax over the logits is integrated in this module for improved numerical stability. 72 | target: the shape should be B1H[WD] or BCH[WD]. 73 | If the target's shape is B1H[WD], the target that this loss expects should be a class index 74 | in the range [0, C-1] where C is the number of classes. 75 | 76 | Raises: 77 | ValueError: When ``target`` ndim differs from ``logits``. 
78 | ValueError: When ``target`` channel is not 1 and ``target`` shape differs from ``logits``. 79 | ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. 80 | 81 | """ 82 | i = logits 83 | t = target 84 | 85 | if i.ndimension() != t.ndimension(): 86 | raise ValueError(f"logits and target ndim must match, got logits={i.ndimension()} target={t.ndimension()}.") 87 | 88 | if t.shape[1] != 1 and t.shape[1] != i.shape[1]: 89 | raise ValueError( 90 | "target must have one channel or have the same shape as the logits. " 91 | "If it has one channel, it should be a class index in the range [0, C-1] " 92 | f"where C is the number of classes inferred from 'logits': C={i.shape[1]}. " 93 | ) 94 | if i.shape[1] == 1: 95 | raise NotImplementedError("Single-channel predictions not supported.") 96 | 97 | # Change the shape of logits and target to 98 | # num_batch x num_class x num_voxels. 99 | if i.dim() > 2: 100 | i = i.view(i.size(0), i.size(1), -1) # N,C,H,W => N,C,H*W 101 | t = t.view(t.size(0), t.size(1), -1) # N,1,H,W => N,1,H*W or N,C,H*W 102 | else: # Compatibility with classification. 103 | i = i.unsqueeze(2) # N,C => N,C,1 104 | t = t.unsqueeze(2) # N,1 => N,1,1 or N,C,1 105 | 106 | # Compute the log proba (more stable numerically than softmax). 107 | logpt = F.log_softmax(i, dim=1) # N,C,H*W 108 | # Keep only log proba values of the ground truth class for each voxel. 109 | if target.shape[1] == 1: 110 | logpt = logpt.gather(1, t.long()) # N,C,H*W => N,1,H*W 111 | logpt = torch.squeeze(logpt, dim=1) # N,1,H*W => N,H*W 112 | 113 | # Get the proba 114 | pt = torch.exp(logpt) # N,H*W or N,C,H*W 115 | 116 | if self.weight is not None: 117 | self.weight = self.weight.to(i) 118 | # Convert the weight to a map in which each voxel 119 | # has the weight associated with the ground-truth label 120 | # associated with this voxel in target. 
121 | at = self.weight[None, :, None] # C => 1,C,1 122 | at = at.expand((t.size(0), -1, t.size(2))) # 1,C,1 => N,C,H*W 123 | if target.shape[1] == 1: 124 | at = at.gather(1, t.long()) # selection of the weights => N,1,H*W 125 | at = torch.squeeze(at, dim=1) # N,1,H*W => N,H*W 126 | # Multiply the log proba by their weights. 127 | logpt = logpt * at 128 | 129 | # Compute the loss mini-batch. 130 | weight = torch.pow(-pt + 1.0, self.gamma) 131 | if target.shape[1] == 1: 132 | loss = -weight * logpt # N 133 | else: 134 | loss = -weight * t * logpt # N,C 135 | 136 | if self.reduction == LossReduction.SUM.value: 137 | return loss.sum() 138 | if self.reduction == LossReduction.NONE.value: 139 | return loss 140 | if self.reduction == LossReduction.MEAN.value: 141 | return loss.mean() 142 | raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') -------------------------------------------------------------------------------- /utilities/geodesics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import dijkstra3d 4 | 5 | def normalize(data): 6 | return (data-data.min())/(data.max()-data.min()) 7 | 8 | def generate_geodesics(extreme, img_gradient, prob, inside_bb_value=12, with_prob=True, with_euclidean=True): 9 | # 10 | extreme = extreme.squeeze() 11 | img_gradient = img_gradient.squeeze() 12 | prob = prob.squeeze() 13 | 14 | if (extreme==inside_bb_value).sum()>0: # If the image patch is not only background 15 | # Corners of the bounding boxes 16 | bb_x_min = torch.where(extreme==inside_bb_value)[0].min().item() 17 | bb_x_max = torch.where(extreme==inside_bb_value)[0].max().item() + 1 18 | bb_y_min = torch.where(extreme==inside_bb_value)[1].min().item() 19 | bb_y_max = torch.where(extreme==inside_bb_value)[1].max().item() + 1 20 | bb_z_min = torch.where(extreme==inside_bb_value)[2].min().item() 21 | bb_z_max = 
torch.where(extreme==inside_bb_value)[2].max().item() + 1 22 | 23 | # Only paths within the non-relaxed bounding box are considered 24 | img_gradient_crop = img_gradient[bb_x_min:bb_x_max,bb_y_min:bb_y_max,bb_z_min:bb_z_max].cpu().numpy() 25 | prob_crop = prob[:,bb_x_min:bb_x_max,bb_y_min:bb_y_max,bb_z_min:bb_z_max].detach() 26 | prob_crop = torch.nn.Softmax(0)(prob_crop)[0,...].cpu().numpy() # Probability of the background 27 | 28 | # Extreme points 29 | ex_x_min = torch.where(extreme==1) 30 | ex_x_max = torch.where(extreme==2) 31 | ex_y_min = torch.where(extreme==3) 32 | ex_y_max = torch.where(extreme==4) 33 | ex_z_min = torch.where(extreme==5) 34 | ex_z_max = torch.where(extreme==6) 35 | 36 | # Identifying the pairs of extreme points to join (Extreme points may miss --> patch based approach) 37 | couples = [] 38 | if ex_x_min[0].shape[0]>0 and ex_x_max[0].shape[0]>0: # Extreme points in the x dimension 39 | couples.append([[k[0].item() for k in ex_x_min], [k[0].item() for k in ex_x_max]]) 40 | 41 | if ex_y_min[0].shape[0]>0 and ex_y_max[0].shape[0]>0: # Extreme points in the y dimension 42 | couples.append([[k[0].item() for k in ex_y_min], [k[0].item() for k in ex_y_max]]) 43 | 44 | if ex_z_min[0].shape[0]>0 and ex_z_max[0].shape[0]>0: # Extreme points in the z dimension 45 | couples.append([[k[0].item() for k in ex_z_min], [k[0].item() for k in ex_z_max]]) 46 | 47 | couples_crop = [[[k[0]-bb_x_min, k[1]-bb_y_min,k[2]-bb_z_min] for k in couple] for couple in couples] 48 | 49 | # Calculating the geodesics using the dijkstra3d 50 | output_crop = inside_bb_value + np.zeros(img_gradient_crop.shape) 51 | for source, target in couples_crop: 52 | weights = img_gradient_crop.copy() # Image gradient term 53 | 54 | if with_prob: 55 | weights+=prob_crop # Deep background probability term 56 | 57 | if with_euclidean: # Normalized distance map to the target 58 | x, y, z = np.ogrid[0:img_gradient_crop.shape[0], 0:img_gradient_crop.shape[1], 0:img_gradient_crop.shape[2]] 59 | 
distances = np.sqrt((x-target[0])**2+(y-target[1])**2+(z-target[2])**2) 60 | distances = normalize(distances) 61 | weights+=distances 62 | 63 | path = dijkstra3d.dijkstra(weights, source, target, connectivity=26) 64 | for k in path: 65 | x,y,z = k 66 | output_crop[x,y,z] = 1 67 | 68 | 69 | output = torch.zeros(extreme.shape) 70 | output[bb_x_min:bb_x_max,bb_y_min:bb_y_max,bb_z_min:bb_z_max] = torch.from_numpy(output_crop.astype(int)) 71 | return output[None,None,...] #Adding batch and channel 72 | else: 73 | # No geodesics 74 | for k in range(1,7): 75 | extreme[extreme==k] = 1 76 | return extreme[None,None,...] #Adding batch and channel 77 | -------------------------------------------------------------------------------- /utilities/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from utilities.focal import FocalLoss 4 | 5 | 6 | class PartialLoss(nn.Module): 7 | def __init__(self, criterion): 8 | super(PartialLoss, self).__init__() 9 | 10 | self.criterion = criterion 11 | self.nb_classes = self.criterion.nb_classes 12 | 13 | def forward(self, outputs, partial_target, phase='training'): 14 | nb_target = outputs.shape[0] 15 | loss_target = 0.0 16 | total = 0 17 | 18 | for i in range(nb_target): 19 | partial_i = partial_target[i,...].reshape(-1) 20 | outputs_i = outputs[i,...].reshape(self.nb_classes, -1).unsqueeze(0) 21 | outputs_i = outputs_i[:,:,partial_i0: 25 | outputs_i= outputs_i.reshape(1,self.nb_classes,1,1,nb_annotated) # Reshape to a 5D tensor 26 | partial_i = partial_i[partial_i0: 31 | return loss_target/total 32 | else: 33 | return 0.0 34 | 35 | 36 | class DC(nn.Module): 37 | def __init__(self,nb_classes): 38 | super(DC, self).__init__() 39 | 40 | self.softmax = nn.Softmax(1) 41 | self.nb_classes = nb_classes 42 | 43 | @staticmethod 44 | def onehot(gt,shape): 45 | shp_y = gt.shape 46 | gt = gt.long() 47 | y_onehot = torch.zeros(shape) 48 | y_onehot = y_onehot.cuda() 49 
| y_onehot.scatter_(1, gt, 1) 50 | return y_onehot 51 | 52 | def reshape(self,output, target): 53 | batch_size = output.shape[0] 54 | 55 | if not all([i == j for i, j in zip(output.shape, target.shape)]): 56 | target = self.onehot(target, output.shape) 57 | 58 | target = target.permute(0,2,3,4,1) 59 | output = output.permute(0,2,3,4,1) 60 | print(target.shape,output.shape) 61 | return output, target 62 | 63 | 64 | def dice(self, output, target): 65 | output = self.softmax(output) 66 | if not all([i == j for i, j in zip(output.shape, target.shape)]): 67 | target = self.onehot(target, output.shape) 68 | 69 | sum_axis = list(range(2,len(target.shape))) 70 | 71 | s = (10e-20) 72 | intersect = torch.sum(output * target,sum_axis) 73 | dice = (2 * intersect) / (torch.sum(output,sum_axis) + torch.sum(target,sum_axis) + s) 74 | #dice shape is (batch_size, nb_classes) 75 | return 1.0 - dice.mean() 76 | 77 | 78 | def forward(self, output, target): 79 | result = self.dice(output, target) 80 | return result 81 | 82 | 83 | class DC_CE_Focal(DC): 84 | def __init__(self,nb_classes): 85 | super(DC_CE_Focal, self).__init__(nb_classes) 86 | 87 | self.ce = nn.CrossEntropyLoss(reduction='mean') 88 | self.fl = FocalLoss(reduction="none") 89 | 90 | def focal(self, pred, grnd, phase="training"): 91 | score = self.fl(pred, grnd).reshape(-1) 92 | 93 | if phase=="training": # class-balanced focal loss 94 | output = 0.0 95 | nb_classes = 0 96 | for cl in range(self.nb_classes): 97 | if (grnd==cl).sum().item()>0: 98 | output+=score[grnd.reshape(-1)==cl].mean() 99 | nb_classes+=1 100 | 101 | if nb_classes>0: 102 | return output/nb_classes 103 | else: 104 | return 0.0 105 | 106 | else: # class-balanced focal loss 107 | return score.mean() 108 | 109 | def forward(self, output, target, phase="training"): 110 | # Dice term 111 | dc_loss = self.dice(output, target) 112 | 113 | # Focal term 114 | focal_loss = self.focal(output, target, phase) 115 | 116 | # Cross entropy 117 | output = 
output.permute(0,2,3,4,1).contiguous().view(-1,self.nb_classes) 118 | target = target.view(-1,).long().cuda() 119 | ce_loss = self.ce(output, target) 120 | 121 | result = ce_loss + dc_loss + focal_loss 122 | return result 123 | 124 | 125 | -------------------------------------------------------------------------------- /utilities/scores.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from tqdm import tqdm 4 | from medpy.metric.binary import dc, precision, hd95 5 | import numpy as np 6 | import nibabel as nib 7 | import pandas as pd 8 | from natsort import natsorted 9 | 10 | opt = argparse.ArgumentParser( 11 | description="Computing scores") 12 | 13 | opt.add_argument("--model_dir", 14 | type=str, 15 | help="Path to the model directory") 16 | opt.add_argument("--path_data", 17 | type=str, 18 | default="../data/VS_MICCAI21/T2/", 19 | help="Path to the labels") 20 | 21 | opt = opt.parse_args() 22 | 23 | name_folder = os.path.basename(os.path.normpath(opt.model_dir)) 24 | 25 | df_split = pd.read_csv('splits/split_inextremis_budget1.csv',header =None) 26 | list_patient = natsorted(df_split[df_split[1].isin(['inference'])][0].tolist()) 27 | 28 | list_dice = [] 29 | list_hd = [] 30 | list_precision = [] 31 | 32 | df_scores= {'name':[],'dice':[],'hd95':[],'precision':[]} 33 | for patient in tqdm(list_patient): 34 | path_gt = os.path.join(opt.path_data, patient+"_Label.nii.gz") 35 | path_pred = os.path.join(opt.model_dir,'output_pred',f"{patient}_T2",f"{patient}_T2_seg.nii.gz") 36 | gt = nib.funcs.as_closest_canonical(nib.load(path_gt)).get_fdata().squeeze() 37 | pred = nib.funcs.as_closest_canonical(nib.load(path_pred)).get_fdata().squeeze() 38 | affine = nib.funcs.as_closest_canonical(nib.load(path_gt)).affine 39 | 40 | voxel_spacing = [abs(affine[k,k]) for k in range(3)] 41 | dice_score = dc(pred, gt) 42 | if np.sum(pred)>0: 43 | hd_score = 0.0 44 | hd_score = hd95(pred, gt, 
voxelspacing=voxel_spacing) 45 | precision_score = precision(pred, gt) 46 | 47 | list_dice.append(100*dice_score) 48 | list_hd.append(hd_score) 49 | list_precision.append(100*precision_score) 50 | 51 | df_scores['name'].append(patient) 52 | df_scores['dice'].append(dice_score) 53 | df_scores['hd95'].append(hd_score) 54 | df_scores['precision'].append(precision_score) 55 | 56 | df_scores = pd.DataFrame(df_scores) 57 | df_scores.to_csv(os.path.join(opt.model_dir, "results_full.csv")) 58 | 59 | 60 | mean_dice = np.round(np.mean(list_dice),1) 61 | std_dice = np.round(np.std(list_dice),1) 62 | mean_hd = np.round(np.mean(list_hd),1) 63 | std_hd = np.round(np.std(list_hd),1) 64 | mean_precision = np.round(np.mean(list_precision),1) 65 | std_precision = np.round(np.std(list_precision),1) 66 | 67 | print(name_folder) 68 | print(f"{mean_dice} ({std_dice}) & {mean_hd} ({std_hd}) & {mean_precision} ({std_precision}) \\") 69 | 70 | -------------------------------------------------------------------------------- /utilities/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | def create_logger(folder): 6 | """Create a logger to save logs.""" 7 | compt = 0 8 | while os.path.exists(os.path.join(folder,f"logs_{compt}.txt")): 9 | compt+=1 10 | logname = os.path.join(folder,f"logs_{compt}.txt") 11 | 12 | logger = logging.getLogger() 13 | fileHandler = logging.FileHandler(logname, mode="w") 14 | consoleHandler = logging.StreamHandler() 15 | logger.addHandler(fileHandler) 16 | logger.addHandler(consoleHandler) 17 | formatter = logging.Formatter("%(message)s") 18 | fileHandler.setFormatter(formatter) 19 | consoleHandler.setFormatter(formatter) 20 | logger.setLevel(logging.INFO) 21 | return logger 22 | 23 | 24 | def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9): 25 | """Learning rate policy used in nnUNet.""" 26 | return initial_lr * (1 - epoch / max_epochs)**exponent 27 | 28 | 29 | def 
infinite_iterable(i): 30 | while True: 31 | yield from i 32 | 33 | 34 | -------------------------------------------------------------------------------- /utilities/write_geodesics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | import argparse 5 | import nibabel as nib 6 | import pandas as pd 7 | from geodesics import generate_geodesics 8 | 9 | PHASES = ["training", "validation"] 10 | 11 | def main(): 12 | opt = parsing_data() 13 | 14 | # FOLDERS 15 | fold_dir = opt.output_folder 16 | if not os.path.exists(fold_dir): 17 | os.makedirs(fold_dir) 18 | 19 | # SPLIT 20 | assert os.path.isfile(opt.dataset_split), print("[ERROR] Invalid split file") 21 | df_split = pd.read_csv(opt.dataset_split,header =None) 22 | list_file = dict() 23 | for split in PHASES: 24 | list_file[split] = df_split[df_split[1].isin([split])][0].tolist() 25 | 26 | mod_ext = "_T2.nii.gz" 27 | grad_ext = f"_{opt.img_gradient_postfix}.nii.gz" 28 | extreme_ext = f"_{opt.label_postfix}.nii.gz" 29 | paths_dict = {split:[] for split in PHASES} 30 | 31 | print(f"Using the Euclidean distance: {opt.with_euclidean}") 32 | for split in PHASES: 33 | score = [] 34 | for subject in tqdm(list_file[split]): 35 | subject_data = dict() 36 | 37 | img_path = os.path.join(opt.path_data,subject+mod_ext) 38 | img_grad_path = os.path.join(opt.path_extremes,subject+grad_ext) 39 | lab_path = os.path.join(opt.path_extremes,subject+extreme_ext) 40 | output_path = os.path.join(opt.output_folder,subject+'_PartLabel.nii.gz') 41 | 42 | if os.path.exists(img_path) and os.path.exists(lab_path) and os.path.exists(img_grad_path): 43 | extreme = nib.load(lab_path) 44 | affine = extreme.affine 45 | extreme_data = torch.from_numpy(extreme.get_fdata()) 46 | 47 | grad_data = torch.from_numpy(nib.load(img_grad_path).get_fdata()) 48 | 49 | geodesics = generate_geodesics( 50 | extreme=extreme_data, 51 | img_gradient=grad_data, 52 | prob=None, 53 
| with_prob=False, 54 | with_euclidean=opt.with_euclidean).numpy().squeeze() 55 | 56 | nib.Nifti1Image(geodesics,affine).to_filename(output_path) 57 | 58 | 59 | def parsing_data(): 60 | parser = argparse.ArgumentParser( 61 | description="Script to generate (non-deep) geodesics using extreme points") 62 | 63 | parser.add_argument("--output_folder", 64 | type=str, 65 | default="geodesics_folder", 66 | help="Path to the model directory") 67 | 68 | parser.add_argument("--dataset_split", 69 | type=str, 70 | default="splits/split_inextremis_budget1.csv", 71 | help="Path to split file") 72 | 73 | parser.add_argument("--path_data", 74 | type=str, 75 | default="../data/VS_MICCAI21/T2/", 76 | help="Path to the T2 scans") 77 | 78 | parser.add_argument("--path_extremes", 79 | type=str, 80 | default="../data/VS_MICCAI21/extremes_manual/", 81 | help="Path to the extreme points") 82 | 83 | parser.add_argument("--label_postfix", 84 | type=str, 85 | default="Extremes_man", 86 | help="Postfix of the Labels points") 87 | 88 | parser.add_argument("--img_gradient_postfix", 89 | type=str, 90 | default="Sobel_man", 91 | help="Postfix of the gradient images") 92 | 93 | parser.add_argument("--with_euclidean", 94 | action="store_true", 95 | help="Add Euclidean distance") 96 | 97 | opt = parser.parse_args() 98 | 99 | return opt 100 | 101 | if __name__ == "__main__": 102 | main() 103 | --------------------------------------------------------------------------------