├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── data
│   ├── README.md
│   ├── convert.py
│   ├── manifest-T2.tcia
│   └── preprocess_requirements.txt
├── figs
│   ├── DiceScore_comparision.png
│   └── supervision_comparision.png
├── inference.py
├── models
│   └── README.md
├── network
│   └── unet2d5.py
├── requirements.txt
├── splits
│   ├── split_fully_budget1.csv
│   ├── split_fully_budget13.csv
│   └── split_inextremis_budget1.csv
├── train.py
└── utilities
    ├── figure_fully.py
    ├── focal.py
    ├── geodesics.py
    ├── losses.py
    ├── scores.py
    ├── utils.py
    └── write_geodesics.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "ScribbleDA"]
2 | path = ScribbleDA
3 | url = https://github.com/KCL-BMEIS/ScribbleDA/
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 ReubenDo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Inter Extreme Points Geodesics for End-to-End Weakly Supervised Image Segmentation
2 |
3 | Public PyTorch implementation for our paper [Inter Extreme Points Geodesics for End-to-End Weakly Supervised Image Segmentation](https://arxiv.org/abs/2107.00583),
4 | which was accepted for presentation at [MICCAI 2021](https://www.miccai2021.org).
5 |
6 | If you find this code useful for your research, please cite the following paper:
7 |
8 | ```
9 | @article{InExtremIS2021Dorent,
10 | author={Dorent, Reuben and Joutard, Samuel and Shapey, Jonathan and
11 | Kujawa, Aaron and Modat, Marc and Ourselin, S\'ebastien and Vercauteren, Tom},
12 | title={Inter Extreme Points Geodesics for End-to-End Weakly Supervised Image Segmentation},
13 | journal={MICCAI},
14 | year={2021},
15 | }
16 | ```
17 |
18 | ## Method Overview
19 | We introduce InExtremIS, a weakly supervised 3D approach to train a deep image segmentation network using particularly weak train-time annotations: only 6 extreme clicks at the boundary of the objects of interest. Our fully automatic method is trained end-to-end and does not require any test-time annotations.
20 |
21 | *Example of weak labels for our use case of Vestibular Schwannoma (VS) segmentation. Magenta: Background. Green: VS.*
22 |
23 | ![Supervision comparison](figs/supervision_comparision.png)
24 |
25 |
26 |
27 |
28 | ## Virtual Environment Setup
29 |
30 | The code is implemented in Python 3.6 using the PyTorch library.
31 | Requirements:
32 |
33 | * Set up a virtual environment (e.g. conda or virtualenv) with Python >=3.6.9
34 | * Install all requirements using:
35 |
36 | ````pip install -r requirements.txt````
37 | * Install the CUDA implementation of the permutohedral lattice and the CRF loss (building this extension requires a CUDA toolkit compatible with your PyTorch installation):
38 | ````
39 | cd ScribbleDA/Permutohedral_attention_module/PAM_cuda/
40 | python3 setup.py build
41 | python3 setup.py install --user
42 | ````
43 |
44 |
45 | ## Data
46 |
47 | The data and annotations are publicly available. Details are provided in [data](/data/).
48 |
49 | ## Running the code
50 | `train.py` is the main file for training the models.
51 |
52 | Example 1: Training InExtremIS with manual extreme points:
53 | ````
54 | python3 train.py \
55 | --model_dir ./models/manual_gradient_eucl_deep_crf/ \
56 | --alpha 15 \
57 | --beta 0.05 \
58 | --weight_crf 0.0001 \
59 | --mode extreme_points \
60 | --label_postfix Extremes_man \
61 | --img_gradient_postfix Sobel_man \
62 | --path_data data/T2/ \
63 | --path_labels data/extreme_points/manual/ \
64 | --with_euclidean \
65 | --with_prob
66 | ````
67 | Example 2: Training InExtremIS with simulated extreme points:
68 | ````
69 | python3 train.py \
70 | --model_dir ./models/simulated_gradient_eucl_deep_crf/ \
71 | --alpha 15 \
72 | --beta 0.05 \
73 | --weight_crf 0.0001 \
74 | --mode extreme_points \
75 | --label_postfix Extremes \
76 | --img_gradient_postfix Sobel \
77 | --path_data data/T2/ \
78 | --path_labels data/extreme_points/simulated/ \
79 | --with_euclidean \
80 | --with_prob
81 | ````
82 |
83 | `inference.py` is the main file for running inference:
84 | ````
85 | python3 inference.py \
86 | --model_dir ./models/manual_gradient_eucl_deep_crf/ \
87 | --path_data data/T2/
88 | ````
89 |
90 | ## Using the code with your own data
91 |
92 | If you want to use your own data, you just need to change the data and label paths,
93 | the split files and, potentially, the modality used. The expected split-file format is shown below.
94 |
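For reference, the split files in [splits](/splits/) are header-less, two-column CSV files mapping each subject identifier to a phase (`training`, `validation` or `inference`). A minimal sketch, with illustrative subject IDs:

````
subject_001,training
subject_002,validation
subject_003,inference
````

The subject IDs must match the image filenames up to the modality postfix (e.g. `subject_001_T2.nii.gz` for the T2 modality).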
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Downloading the data used for the experiments
2 |
3 | In this work, we used a large (N=242) dataset for Vestibular Schwannoma segmentation. This dataset is publicly available on
4 | [TCIA](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=70229053).
5 |
6 | This readme explains how to download and pre-process the raw data from TCIA. We also provide open access to the extreme points and the pre-computed geodesics used in this work.
7 |
8 | ## Download the fully annotated TCIA-VS dataset
9 |
10 | ### Option 1 - Downloading the T2 scans only and their segmentation maps (Recommended):
11 |
12 | Please follow these steps:
13 |
14 | **Step 1**: Download the NBIA Data Retriever:
15 | * Please follow the instructions [here](https://wiki.cancerimagingarchive.net/display/NBIA/Downloading+TCIA+Images).
16 |
17 | **Step 2**: Download the T2 scans only:
18 | * Open `manifest-T2.tcia` with NBIA Data Retriever and download the T2 images (DICOM, 6GB) with the "Descriptive Directory Name" format.
19 |
20 | **Step 3**: DICOM to NIfTI conversion:
21 | * Install dependencies: `pip install -r preprocess_requirements.txt`
22 | * Execute the conversion script:
23 | `python3 convert.py --input <input_folder> --output <output_folder>`
24 | * `<input_folder>` is the directory containing the raw T2 images (e.g. `/home/admin/manifest-T2/Vestibular-Schwannoma-SEG/`).
25 | * `<output_folder>` is the directory in which the pre-processed scans will be saved (one `vs_gk_<case>_T2.nii.gz` file per case).
26 |
27 | **Step 4**: Download the fully annotated segmentation masks [here](https://zenodo.org/record/5081986/files/full_annotations.zip?download=1).
28 |
29 | ### Option 2 - Downloading the full dataset and manually converting contours into segmentation masks:
30 | Please follow the instructions from the [VS_Seg repository](https://github.com/KCL-BMEIS/VS_Seg/tree/master/preprocessing).
31 |
32 | ## Download the extreme points and pre-computed geodesics
33 | The manual and simulated extreme points can be found [here](https://zenodo.org/record/5081986/files/extreme_points.zip?download=1).
34 | The pre-computed geodesics using the image gradient information (`grad` folder) and with the additional Euclidean distance (`grad_eucl` folder) can be found [here](https://zenodo.org/record/5081986/files/precomputed_geodesics.zip?download=1).
35 |
36 | The 6 extreme points are the voxels with the values _1,2,3,4,5,6_. Specifically, the pairs of extreme points along the _x_, _y_ and _z_ axes are respectively _{1,2}_, _{3,4}_ and _{5,6}_.
37 |
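As an illustration, the snippet below shows one way to read these points and trace a geodesic between a pair of them with `dijkstra3d` (listed in `requirements.txt`). This is a minimal sketch, not the exact pipeline (see `utilities/geodesics.py` and `utilities/write_geodesics.py` for the implementation used in this work): the filenames are illustrative and the cost map is a plain gradient magnitude.

```python
import dijkstra3d
import nibabel as nib
import numpy as np

# Illustrative paths; adapt them to where you stored the downloaded data.
img = nib.load("T2/vs_gk_1_T2.nii.gz").get_fdata()
points = nib.load("extreme_points/simulated/vs_gk_1_Extremes.nii.gz").get_fdata()

# Each extreme point is a single voxel labelled 1..6.
coords = {v: tuple(np.argwhere(points == v)[0]) for v in range(1, 7)}

# Geodesic between the x-axis pair {1,2}: shortest path through a cost
# volume, here a simple image-gradient magnitude.
cost = np.linalg.norm(np.stack(np.gradient(img)), axis=0)
path = dijkstra3d.dijkstra(cost, coords[1], coords[2])  # (N, 3) array of voxel indices
```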
38 | ## Citations
39 | If you use this VS data, please cite:
40 |
41 | ```
42 | @article { ShapeyJNS21,
43 | author = "Jonathan Shapey and Guotai Wang and Reuben Dorent and Alexis Dimitriadis and Wenqi Li and Ian Paddick and Neil Kitchen and Sotirios Bisdas and Shakeel R. Saeed and Sebastien Ourselin and Robert Bradford and Tom Vercauteren",
44 | title = "{An artificial intelligence framework for automatic segmentation and volumetry of vestibular schwannomas from contrast-enhanced T1-weighted and high-resolution T2-weighted MRI}",
45 | journal = "Journal of Neurosurgery JNS",
46 | year = "2021",
47 | publisher = "American Association of Neurological Surgeons",
48 | volume = "134",
49 | number = "1",
50 | doi = "10.3171/2019.9.JNS191949",
51 | pages= "171 - 179",
52 | url = "https://thejns.org/view/journals/j-neurosurg/134/1/article-p171.xml"
53 | }
54 | ```
55 |
56 | If you use the extreme points, please additionally cite:
57 |
58 | ```
59 | @article{InExtremIS2021Dorent,
60 | author={Dorent, Reuben and Joutard, Samuel and Shapey, Jonathan and
61 | Kujawa, Aaron and Modat, Marc and Ourselin, S\'ebastien and Vercauteren, Tom},
62 | title={Inter Extreme Points Geodesics for End-to-End Weakly Supervised Image Segmentation},
63 | journal={MICCAI},
64 | year={2021},
65 | }
66 | ```
67 | ## Credits
68 | The conversion script is based on https://github.com/KCL-BMEIS/VS_Seg/tree/master/preprocessing.
69 |
--------------------------------------------------------------------------------
/data/convert.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding: utf-8
# Adapted from: https://github.com/KCL-BMEIS/VS_Seg/blob/master/preprocessing/TCIA_data_convert_into_convenient_folder_structure.py

import os
from glob import glob
from natsort import natsorted
import pydicom
import SimpleITK as sitk
import re
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser(description='Convert the T2 scans from the TCIA dataset into the Nifti format')
parser.add_argument('--input', type=str, help='(string) path to TCIA dataset, in "Descriptive Directory Name" format, for example /home/user/.../manifest-T2/Vestibular-Schwannoma-SEG')
parser.add_argument('--output', type=str, help='(string) path to output folder')
args = parser.parse_args()

input_path = args.input
output_path = args.output

if not os.path.isdir(output_path):
    os.makedirs(output_path, exist_ok=True)

cases = natsorted(glob(os.path.join(input_path, '*')))

for case in tqdm(cases):
    folders = glob(case + '/*/*')

    MRs = []
    MRs_paths = []

    # Check that each DICOM series is an MR scan
    for folder in folders:
        first_file = glob(folder + "/*")[0]
        dd = pydicom.read_file(first_file)

        if dd['Modality'].value == 'MR':
            MRs.append(dd)
            MRs_paths.append(first_file)
        else:
            raise Exception(f'Unexpected non-MR series in {folder}')

    file_paths = None
    # Check that each MR series is a T2 scan
    for MR, path in zip(MRs, MRs_paths):
        if "t2_" in MR['SeriesDescription'].value:
            MR_T2 = MR
            file_paths = path
        else:
            raise Exception(f'Unexpected non-T2 series: {path}')

    # Write files into the new folder structure
    p = re.compile(r'VS-SEG-(\d+)')
    case_idx = int(p.findall(case)[0])
    old_T2_folder = os.path.dirname(file_paths)

    # Output path
    new_T2_path = os.path.join(output_path, 'vs_gk_' + str(case_idx) + '_T2.nii.gz')

    # DICOM to NIfTI conversion using SimpleITK
    reader = sitk.ImageSeriesReader()
    dicom_names = reader.GetGDCMSeriesFileNames(old_T2_folder)
    reader.SetFileNames(dicom_names)
    image = reader.Execute()

    sitk.WriteImage(image, new_T2_path)
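
# Example usage (paths are illustrative; see data/README.md):
#   python3 convert.py --input /home/admin/manifest-T2/Vestibular-Schwannoma-SEG/ --output data/T2/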
--------------------------------------------------------------------------------
/data/manifest-T2.tcia:
--------------------------------------------------------------------------------
1 | downloadServerUrl=https://public.cancerimagingarchive.net/nbia-download/servlet/DownloadServlet
2 | includeAnnotation=true
3 | noOfrRetry=4
4 | databasketId=manifest-T2.tcia
5 | manifestVersion=3.0
6 | ListOfSeriesToDownload=
7 | 1.3.6.1.4.1.14519.5.2.1.97824612055862366318560427964793890998
8 | 1.3.6.1.4.1.14519.5.2.1.124502390760917755084744319866454742880
9 | 1.3.6.1.4.1.14519.5.2.1.336209512066065542527308757738778134743
10 | 1.3.6.1.4.1.14519.5.2.1.61686433061457847902102354630069993489
11 | 1.3.6.1.4.1.14519.5.2.1.285266602767317398976732799437378140608
12 | 1.3.6.1.4.1.14519.5.2.1.128575943230315285809456617735815081338
13 | 1.3.6.1.4.1.14519.5.2.1.119974278459672124882397556735745886618
14 | 1.3.6.1.4.1.14519.5.2.1.328411971284129245304472742027457770689
15 | 1.3.6.1.4.1.14519.5.2.1.205715002896735863602196046118617324323
16 | 1.3.6.1.4.1.14519.5.2.1.47368711747357463116341094276606315239
17 | 1.3.6.1.4.1.14519.5.2.1.277564943275913764149628784943750055165
18 | 1.3.6.1.4.1.14519.5.2.1.216916846757178094156064648053287356705
19 | 1.3.6.1.4.1.14519.5.2.1.292451296497326330780007043491927669965
20 | 1.3.6.1.4.1.14519.5.2.1.37407646254753171628888832676854647641
21 | 1.3.6.1.4.1.14519.5.2.1.2154096351050073340417051795616964329
22 | 1.3.6.1.4.1.14519.5.2.1.288822954068116720576409516209064805137
23 | 1.3.6.1.4.1.14519.5.2.1.157455919411531329445971228461834454864
24 | 1.3.6.1.4.1.14519.5.2.1.300550246357357923662626736244214512354
25 | 1.3.6.1.4.1.14519.5.2.1.46875774646497052019561950116185371595
26 | 1.3.6.1.4.1.14519.5.2.1.265856471014025537322703012203178074916
27 | 1.3.6.1.4.1.14519.5.2.1.217554160929652517188403089378973632127
28 | 1.3.6.1.4.1.14519.5.2.1.81190427874393211914843328182235403675
29 | 1.3.6.1.4.1.14519.5.2.1.239006515845888908570518896552813145905
30 | 1.3.6.1.4.1.14519.5.2.1.226103095049025971830798326183529727687
31 | 1.3.6.1.4.1.14519.5.2.1.241519636120309329966889620612275264106
32 | 1.3.6.1.4.1.14519.5.2.1.250256352839480733008907732689438039992
33 | 1.3.6.1.4.1.14519.5.2.1.155625578034220411317764200158067934139
34 | 1.3.6.1.4.1.14519.5.2.1.223264893116803537954728806540487445907
35 | 1.3.6.1.4.1.14519.5.2.1.56803546825794638336125075839814402436
36 | 1.3.6.1.4.1.14519.5.2.1.168133623024284256524654755031136793033
37 | 1.3.6.1.4.1.14519.5.2.1.41051157220867917432644808566260695421
38 | 1.3.6.1.4.1.14519.5.2.1.13168934333778663511634424074788249440
39 | 1.3.6.1.4.1.14519.5.2.1.98519287807856841042726497993241875149
40 | 1.3.6.1.4.1.14519.5.2.1.67345468089650726208933617170352542940
41 | 1.3.6.1.4.1.14519.5.2.1.128016141184262134692211776899422909852
42 | 1.3.6.1.4.1.14519.5.2.1.261904727650229936032200059779168110546
43 | 1.3.6.1.4.1.14519.5.2.1.80306368584051877336147767062103188154
44 | 1.3.6.1.4.1.14519.5.2.1.81013456284548964428831720032766314407
45 | 1.3.6.1.4.1.14519.5.2.1.137851401856381638890659062774470307621
46 | 1.3.6.1.4.1.14519.5.2.1.274222328416232835281613446776194950887
47 | 1.3.6.1.4.1.14519.5.2.1.189932513571932107895723417841672250908
48 | 1.3.6.1.4.1.14519.5.2.1.218365832761211504858210426893853705350
49 | 1.3.6.1.4.1.14519.5.2.1.299478506810267521121655201175658435054
50 | 1.3.6.1.4.1.14519.5.2.1.175770890260169282489793361209219236540
51 | 1.3.6.1.4.1.14519.5.2.1.288937279368837050893511455153770225852
52 | 1.3.6.1.4.1.14519.5.2.1.137305226373921385879942188391264951989
53 | 1.3.6.1.4.1.14519.5.2.1.159740827447177453253673680290849733064
54 | 1.3.6.1.4.1.14519.5.2.1.233605121987468319370647828457502185983
55 | 1.3.6.1.4.1.14519.5.2.1.29747780949545019036183341430656302425
56 | 1.3.6.1.4.1.14519.5.2.1.233973029399216582850133589760654139151
57 | 1.3.6.1.4.1.14519.5.2.1.9669015644716471476945977883357982292
58 | 1.3.6.1.4.1.14519.5.2.1.67158313966057965569167946723639338387
59 | 1.3.6.1.4.1.14519.5.2.1.78418533061355773065152791264029309750
60 | 1.3.6.1.4.1.14519.5.2.1.95190271536129325658493975898105509278
61 | 1.3.6.1.4.1.14519.5.2.1.8744922103190999130080234468699487701
62 | 1.3.6.1.4.1.14519.5.2.1.277261916620071690562868021692602577695
63 | 1.3.6.1.4.1.14519.5.2.1.333826283119726799662204205087684554536
64 | 1.3.6.1.4.1.14519.5.2.1.294465790050902079649170070800458333118
65 | 1.3.6.1.4.1.14519.5.2.1.230012513674586138089728143330133632463
66 | 1.3.6.1.4.1.14519.5.2.1.35307211638230668632830028521014104687
67 | 1.3.6.1.4.1.14519.5.2.1.337494328277457405561277816098934383813
68 | 1.3.6.1.4.1.14519.5.2.1.162851729691398294622797105490114199357
69 | 1.3.6.1.4.1.14519.5.2.1.100420693789734401722866581473132024737
70 | 1.3.6.1.4.1.14519.5.2.1.229384024656851675147876036353267258487
71 | 1.3.6.1.4.1.14519.5.2.1.304198865028657374537924002347255938106
72 | 1.3.6.1.4.1.14519.5.2.1.20670976059556792609125791949916925570
73 | 1.3.6.1.4.1.14519.5.2.1.229528759693718677584709692621291077626
74 | 1.3.6.1.4.1.14519.5.2.1.21877083602789524438704314882665653577
75 | 1.3.6.1.4.1.14519.5.2.1.86855432222058507463718198553899529003
76 | 1.3.6.1.4.1.14519.5.2.1.223475447957754453121941963938410079266
77 | 1.3.6.1.4.1.14519.5.2.1.117858331830377878190511865298944867098
78 | 1.3.6.1.4.1.14519.5.2.1.15167479224344864655434050143434357061
79 | 1.3.6.1.4.1.14519.5.2.1.297889053157145327024131887531428225347
80 | 1.3.6.1.4.1.14519.5.2.1.90465408233381138722540652694651320969
81 | 1.3.6.1.4.1.14519.5.2.1.302714639670083142058852903953837264697
82 | 1.3.6.1.4.1.14519.5.2.1.16904271606903764104530936528625259781
83 | 1.3.6.1.4.1.14519.5.2.1.12764749088880384028924045444467299852
84 | 1.3.6.1.4.1.14519.5.2.1.267862471627412118753428341131185702764
85 | 1.3.6.1.4.1.14519.5.2.1.303624579553286757476664912625222198071
86 | 1.3.6.1.4.1.14519.5.2.1.82000297190327313616877429419376399973
87 | 1.3.6.1.4.1.14519.5.2.1.110063895628460115704321585826189095537
88 | 1.3.6.1.4.1.14519.5.2.1.336499584971688286794823333380374653244
89 | 1.3.6.1.4.1.14519.5.2.1.139392630476501192761585947339412212897
90 | 1.3.6.1.4.1.14519.5.2.1.113300425991115767575938977397298840060
91 | 1.3.6.1.4.1.14519.5.2.1.307568576320578693001905075814766671550
92 | 1.3.6.1.4.1.14519.5.2.1.282832052300721436992245894412315111912
93 | 1.3.6.1.4.1.14519.5.2.1.213400619248062551895416596109098398213
94 | 1.3.6.1.4.1.14519.5.2.1.36934820552005519735534479931332292846
95 | 1.3.6.1.4.1.14519.5.2.1.41216239343660574082540246837790859853
96 | 1.3.6.1.4.1.14519.5.2.1.63157081399858537808770660614894164681
97 | 1.3.6.1.4.1.14519.5.2.1.167260142371217243657393980232701016137
98 | 1.3.6.1.4.1.14519.5.2.1.61398984386978615487608646301170730391
99 | 1.3.6.1.4.1.14519.5.2.1.170913910290121505784426628803446970389
100 | 1.3.6.1.4.1.14519.5.2.1.208490847474630115886915444238291380368
101 | 1.3.6.1.4.1.14519.5.2.1.298671145848708339737222217998577755570
102 | 1.3.6.1.4.1.14519.5.2.1.9755483032379805473706391940166669036
103 | 1.3.6.1.4.1.14519.5.2.1.268718756853741766983354453912312801730
104 | 1.3.6.1.4.1.14519.5.2.1.261237309343469689065576631723962333840
105 | 1.3.6.1.4.1.14519.5.2.1.224003143144336561123702641959818284842
106 | 1.3.6.1.4.1.14519.5.2.1.163015487469502642386883094241440581072
107 | 1.3.6.1.4.1.14519.5.2.1.170347731217393158076277121980482188304
108 | 1.3.6.1.4.1.14519.5.2.1.199464597880518074498752987949770064952
109 | 1.3.6.1.4.1.14519.5.2.1.20347614015017630361319142934914011596
110 | 1.3.6.1.4.1.14519.5.2.1.9721409353341461465501650910815656993
111 | 1.3.6.1.4.1.14519.5.2.1.14262170121308979137577589173619571563
112 | 1.3.6.1.4.1.14519.5.2.1.311333098893435802624860311753091976002
113 | 1.3.6.1.4.1.14519.5.2.1.220840979774567365812993669047801644120
114 | 1.3.6.1.4.1.14519.5.2.1.77558762586485924195080424411910919171
115 | 1.3.6.1.4.1.14519.5.2.1.168507932534014014653293784243131332894
116 | 1.3.6.1.4.1.14519.5.2.1.116723325586106252707395359498797460520
117 | 1.3.6.1.4.1.14519.5.2.1.29910714713700110800496924621325713622
118 | 1.3.6.1.4.1.14519.5.2.1.52208737328500500547869186262562161680
119 | 1.3.6.1.4.1.14519.5.2.1.41341777065628435310650949856777228905
120 | 1.3.6.1.4.1.14519.5.2.1.251606799383580969844189799392861321861
121 | 1.3.6.1.4.1.14519.5.2.1.282485957146535605228825771941012962357
122 | 1.3.6.1.4.1.14519.5.2.1.317359270245808302145533856585672960883
123 | 1.3.6.1.4.1.14519.5.2.1.135514198492981058482845815249880889690
124 | 1.3.6.1.4.1.14519.5.2.1.96562875876973827961400745245219169962
125 | 1.3.6.1.4.1.14519.5.2.1.256648491507353229813182334221467374181
126 | 1.3.6.1.4.1.14519.5.2.1.70317311370767903368272174720906718342
127 | 1.3.6.1.4.1.14519.5.2.1.127703226653324327717860855502627722306
128 | 1.3.6.1.4.1.14519.5.2.1.17891221028856069257173473170851398654
129 | 1.3.6.1.4.1.14519.5.2.1.30354814832219495751611811177811844787
130 | 1.3.6.1.4.1.14519.5.2.1.196082117235151381832692400871564748236
131 | 1.3.6.1.4.1.14519.5.2.1.334508876733181239165112956951269361398
132 | 1.3.6.1.4.1.14519.5.2.1.137213097494387692044497332427005289056
133 | 1.3.6.1.4.1.14519.5.2.1.66602819257976265659201750697700497901
134 | 1.3.6.1.4.1.14519.5.2.1.254907007513982459959439589947824133569
135 | 1.3.6.1.4.1.14519.5.2.1.277932022685732692159807028685158125644
136 | 1.3.6.1.4.1.14519.5.2.1.75869570711830604911027500287084681088
137 | 1.3.6.1.4.1.14519.5.2.1.43394638547031237271853527848743247298
138 | 1.3.6.1.4.1.14519.5.2.1.219920872738792573764887676715372489467
139 | 1.3.6.1.4.1.14519.5.2.1.304966211758865415213459139643277133250
140 | 1.3.6.1.4.1.14519.5.2.1.59097307935533647142368110661201694954
141 | 1.3.6.1.4.1.14519.5.2.1.104342378995544082321527649818115719381
142 | 1.3.6.1.4.1.14519.5.2.1.275976308826344111988529521225074469122
143 | 1.3.6.1.4.1.14519.5.2.1.76966817086725830177633239884425707885
144 | 1.3.6.1.4.1.14519.5.2.1.37500502292073832653905111102576102168
145 | 1.3.6.1.4.1.14519.5.2.1.15051876378477273432527967382885690193
146 | 1.3.6.1.4.1.14519.5.2.1.75023595377115501952746333684971209569
147 | 1.3.6.1.4.1.14519.5.2.1.10244709163143315037706780774420781635
148 | 1.3.6.1.4.1.14519.5.2.1.315702370575273964468589032564024942002
149 | 1.3.6.1.4.1.14519.5.2.1.232581844696793658570866321285325476363
150 | 1.3.6.1.4.1.14519.5.2.1.255106474671713017130022360202121610175
151 | 1.3.6.1.4.1.14519.5.2.1.184438972748249848205519432640579417556
152 | 1.3.6.1.4.1.14519.5.2.1.15594630066063018098965121607394131685
153 | 1.3.6.1.4.1.14519.5.2.1.77895049258647801574952424776894897832
154 | 1.3.6.1.4.1.14519.5.2.1.194222281136456272002461195297286978806
155 | 1.3.6.1.4.1.14519.5.2.1.292361225108641180764247877135850676426
156 | 1.3.6.1.4.1.14519.5.2.1.84787388055021215524629025646692795407
157 | 1.3.6.1.4.1.14519.5.2.1.126181735854342571807747278015493395427
158 | 1.3.6.1.4.1.14519.5.2.1.50857688276531731282813856253493358728
159 | 1.3.6.1.4.1.14519.5.2.1.158363608374533284432503372099530292444
160 | 1.3.6.1.4.1.14519.5.2.1.221009419007488062284911758544674914233
161 | 1.3.6.1.4.1.14519.5.2.1.214587129890999158719654764602468482956
162 | 1.3.6.1.4.1.14519.5.2.1.301199852589293592180667435413243749421
163 | 1.3.6.1.4.1.14519.5.2.1.91307143982593757658329925269213107444
164 | 1.3.6.1.4.1.14519.5.2.1.115916457443958592970450702751155898672
165 | 1.3.6.1.4.1.14519.5.2.1.251387290435470956632040740033833126539
166 | 1.3.6.1.4.1.14519.5.2.1.35592985115515755107463670280222959193
167 | 1.3.6.1.4.1.14519.5.2.1.189790696754423307903604808866635619814
168 | 1.3.6.1.4.1.14519.5.2.1.278447060082661819769695941196761676970
169 | 1.3.6.1.4.1.14519.5.2.1.74127510988834203379779920116706931829
170 | 1.3.6.1.4.1.14519.5.2.1.158837645871384737652904421894110435515
171 | 1.3.6.1.4.1.14519.5.2.1.261027331158345315128560572348660712113
172 | 1.3.6.1.4.1.14519.5.2.1.320208736433890307354533957518098015251
173 | 1.3.6.1.4.1.14519.5.2.1.255721252600754855960977522233110250936
174 | 1.3.6.1.4.1.14519.5.2.1.265562325801308771480973706098784037904
175 | 1.3.6.1.4.1.14519.5.2.1.59496075545541835714082685052623985780
176 | 1.3.6.1.4.1.14519.5.2.1.29707760275931807980484755409216479658
177 | 1.3.6.1.4.1.14519.5.2.1.71847268882604614946395654926696180481
178 | 1.3.6.1.4.1.14519.5.2.1.139899967629592458992273833701974367027
179 | 1.3.6.1.4.1.14519.5.2.1.60645495344224097120659188750733663433
180 | 1.3.6.1.4.1.14519.5.2.1.76234766989041933862422924921980946094
181 | 1.3.6.1.4.1.14519.5.2.1.54647801429999999880279745055162166396
182 | 1.3.6.1.4.1.14519.5.2.1.245881267132473546456095715670497260366
183 | 1.3.6.1.4.1.14519.5.2.1.271500782689034842482795482397599708503
184 | 1.3.6.1.4.1.14519.5.2.1.260858848998229883842233413025317121641
185 | 1.3.6.1.4.1.14519.5.2.1.138787690545499304848411355088748136195
186 | 1.3.6.1.4.1.14519.5.2.1.138114624312317833838264235708033684106
187 | 1.3.6.1.4.1.14519.5.2.1.190044300314624874115076527729396088731
188 | 1.3.6.1.4.1.14519.5.2.1.327102410004294511180793158899656978209
189 | 1.3.6.1.4.1.14519.5.2.1.290140223859345390541699250506135777329
190 | 1.3.6.1.4.1.14519.5.2.1.235533134899035107929505275392119110300
191 | 1.3.6.1.4.1.14519.5.2.1.256223606579513472138624134007553794517
192 | 1.3.6.1.4.1.14519.5.2.1.292025009655148971742680910641710654260
193 | 1.3.6.1.4.1.14519.5.2.1.278019916823314020426350685915747751677
194 | 1.3.6.1.4.1.14519.5.2.1.119338114603372415979087053769916795752
195 | 1.3.6.1.4.1.14519.5.2.1.154846187556208082175515830387781463718
196 | 1.3.6.1.4.1.14519.5.2.1.241158227314558271467289218638320805912
197 | 1.3.6.1.4.1.14519.5.2.1.229985221915959115066070872369878709015
198 | 1.3.6.1.4.1.14519.5.2.1.129069453057329278334463872222670450244
199 | 1.3.6.1.4.1.14519.5.2.1.130306192263501254884259911459312278931
200 | 1.3.6.1.4.1.14519.5.2.1.99099341904959222126543542841859021660
201 | 1.3.6.1.4.1.14519.5.2.1.226937211758235683777950388828543487213
202 | 1.3.6.1.4.1.14519.5.2.1.222046637334875981449965463114585543169
203 | 1.3.6.1.4.1.14519.5.2.1.246909769442555605040748585982394591148
204 | 1.3.6.1.4.1.14519.5.2.1.197011928085586949947985979972960177494
205 | 1.3.6.1.4.1.14519.5.2.1.210688939016236904061492475261436050128
206 | 1.3.6.1.4.1.14519.5.2.1.194203897008456569475098794782196805092
207 | 1.3.6.1.4.1.14519.5.2.1.243028238309713788501554716738850686675
208 | 1.3.6.1.4.1.14519.5.2.1.333979424673348586669984373810626623120
209 | 1.3.6.1.4.1.14519.5.2.1.195824806623977377348042506440300253900
210 | 1.3.6.1.4.1.14519.5.2.1.86663105982361050637985901375777674119
211 | 1.3.6.1.4.1.14519.5.2.1.80596214383930081973613702686831725427
212 | 1.3.6.1.4.1.14519.5.2.1.297601975245157269521546365597109246887
213 | 1.3.6.1.4.1.14519.5.2.1.175508249861315421813970648276848747955
214 | 1.3.6.1.4.1.14519.5.2.1.197835772424729511288285229448608628340
215 | 1.3.6.1.4.1.14519.5.2.1.216583092100832104191093532137558942068
216 | 1.3.6.1.4.1.14519.5.2.1.323720174552274881744647199360391559339
217 | 1.3.6.1.4.1.14519.5.2.1.76123013711237089310162869568540894767
218 | 1.3.6.1.4.1.14519.5.2.1.57055430988757827707470998009719003163
219 | 1.3.6.1.4.1.14519.5.2.1.282176336559963455064991060259622459376
220 | 1.3.6.1.4.1.14519.5.2.1.27477011593957184467450523730643807258
221 | 1.3.6.1.4.1.14519.5.2.1.210671159041526701646380148425994405952
222 | 1.3.6.1.4.1.14519.5.2.1.251865584911266925837724605787007632031
223 | 1.3.6.1.4.1.14519.5.2.1.191494736838268150020432657327243028009
224 | 1.3.6.1.4.1.14519.5.2.1.44323770382390486457593204813672763488
225 | 1.3.6.1.4.1.14519.5.2.1.156769265159911961821857178602133057319
226 | 1.3.6.1.4.1.14519.5.2.1.216294443593141688683154232645411947601
227 | 1.3.6.1.4.1.14519.5.2.1.201429277212342505564727134166942685942
228 | 1.3.6.1.4.1.14519.5.2.1.232657652330222096161712085710696008345
229 | 1.3.6.1.4.1.14519.5.2.1.232111980308888243130020179540191986379
230 | 1.3.6.1.4.1.14519.5.2.1.171450547646265436894896475413907994571
231 | 1.3.6.1.4.1.14519.5.2.1.260208429721371945913593488747008392402
232 | 1.3.6.1.4.1.14519.5.2.1.96074972317210856257940726753393079402
233 | 1.3.6.1.4.1.14519.5.2.1.3281197893335564763282849900767310827
234 | 1.3.6.1.4.1.14519.5.2.1.33496187979283913017539607570678513263
235 | 1.3.6.1.4.1.14519.5.2.1.67049266716817140881910601189350295111
236 | 1.3.6.1.4.1.14519.5.2.1.53214686551225675208337966913982664260
237 | 1.3.6.1.4.1.14519.5.2.1.23925966695469658554360021791190910677
238 | 1.3.6.1.4.1.14519.5.2.1.249041531104210908961143956340899475303
239 | 1.3.6.1.4.1.14519.5.2.1.233826382362237990205162021534559156502
240 | 1.3.6.1.4.1.14519.5.2.1.144422009854692416651143605689152425184
241 | 1.3.6.1.4.1.14519.5.2.1.106786035069646005027604098072287812951
242 | 1.3.6.1.4.1.14519.5.2.1.236752444493873818319999694516257268975
243 | 1.3.6.1.4.1.14519.5.2.1.238358624802117233802379047360345454460
244 | 1.3.6.1.4.1.14519.5.2.1.300748886239012525308783476240258125999
245 | 1.3.6.1.4.1.14519.5.2.1.271262780142356626547695550879331444566
246 | 1.3.6.1.4.1.14519.5.2.1.83992889315124485147505742018229332342
247 | 1.3.6.1.4.1.14519.5.2.1.288810963554028825092258346944809083760
248 | 1.3.6.1.4.1.14519.5.2.1.173896487730572575760488822814514876976
249 |
--------------------------------------------------------------------------------
/data/preprocess_requirements.txt:
--------------------------------------------------------------------------------
1 | SimpleITK
2 | pydicom
3 | tqdm
4 | natsort
--------------------------------------------------------------------------------
/figs/DiceScore_comparision.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReubenDo/InExtremIS/1512ddf9b8c11c4d9f0ebd465d904ef3d539d350/figs/DiceScore_comparision.png
--------------------------------------------------------------------------------
/figs/supervision_comparision.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ReubenDo/InExtremIS/1512ddf9b8c11c4d9f0ebd465d904ef3d539d350/figs/supervision_comparision.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import os
from tqdm import tqdm

import pandas as pd

import torch
from torch import nn

from monai.data import DataLoader, Dataset, NiftiSaver
from monai.transforms import (
    Compose,
    LoadNiftid,
    AddChanneld,
    NormalizeIntensityd,
    Orientationd,
    ToTensord,
)
from monai.utils import set_determinism
from monai.inferers import sliding_window_inference

from network.unet2d5 import UNet2D5

# Default patch size for the sliding-window inference
SPATIAL_SHAPE = (224, 224, 48)

NB_CLASSES = 2

# Number of workers
workers = 20

# Training parameters (unused in this inference script)
val_eval_criterion_alpha = 0.95
train_loss_MA_alpha = 0.95
nb_patience = 10
patience_lr = 5
weight_decay = 1e-5

PHASES = ['training', 'validation', 'inference']

def infinite_iterable(i):
    while True:
        yield from i

def inference(paths_dict, model, transform_inference, device, opt):

    # Datasets and dataloaders (normalization is handled by transform_inference)
    dataloaders = dict()
    subjects_dataset = dict()

    checkpoint_path = os.path.join(opt.model_dir, 'models', 'CP_{}.pth')
    checkpoint_path = checkpoint_path.format(opt.epoch_inf)
    assert os.path.isfile(checkpoint_path), 'no checkpoint found'
    print(checkpoint_path)
    model.load_state_dict(torch.load(checkpoint_path))

    model = model.to(device)

    for phase in ['inference']:
        subjects_dataset[phase] = Dataset(paths_dict, transform=transform_inference)
        dataloaders[phase] = DataLoader(subjects_dataset[phase], batch_size=1, shuffle=False)

    model.eval()  # Set model to evaluate mode

    fold_name = 'output_pred'
    # Iterate over data
    with torch.no_grad():
        saver = NiftiSaver(output_dir=os.path.join(opt.model_dir, fold_name))
        for batch in tqdm(dataloaders['inference']):
            inputs = batch['img'].to(device)

            pred = sliding_window_inference(inputs, opt.spatial_shape, 1, model, mode='gaussian')

            pred = pred.argmax(1, keepdim=True).detach()
            saver.save_batch(pred, batch["img_meta_dict"])


def main():
    opt = parsing_data()

    set_determinism(seed=2)

    if torch.cuda.is_available():
        print('[INFO] GPU available.')
        device = torch.device("cuda:0")
    else:
        raise Exception("[INFO] No GPU found.")

    print("[INFO] Reading data")
    # PHASES
    split_path = os.path.join(opt.dataset_split)
    df_split = pd.read_csv(split_path, header=None)
    list_file = dict()
    for phase in PHASES:  # list of patient names associated with each phase
        list_file[phase] = df_split[df_split[1].isin([phase])][0].tolist()

    # CREATING DICT FOR DATASET
    mod_ext = "_T2.nii.gz"
    paths_dict = {split: [] for split in PHASES}
    for split in PHASES:
        for subject in list_file[split]:
            subject_data = dict()
            if os.path.exists(os.path.join(opt.path_data, subject + mod_ext)):
                subject_data["img"] = os.path.join(opt.path_data, subject + mod_ext)
                paths_dict[split].append(subject_data)
        print(f"Nb patients in {split} data: {len(paths_dict[split])}")

    # Logging hyperparameters
    print("[INFO] Hyperparameters")
    print('Spatial shape: {}'.format(opt.spatial_shape))
    print(f"Inference on the {opt.phase} set")

    # PREPROCESSING
    all_keys = ["img"]
    test_transforms = Compose(
        (
            LoadNiftid(keys=all_keys),
            AddChanneld(keys=all_keys),
            Orientationd(keys=all_keys, axcodes="RAS"),
            NormalizeIntensityd(keys=all_keys),
            ToTensord(keys=all_keys)
        )
    )

    # MODEL
    norm_op_kwargs = {"eps": 1e-5, "affine": True}
    net_nonlin = nn.LeakyReLU
    net_nonlin_kwargs = {"negative_slope": 1e-2, "inplace": True}

    model = UNet2D5(input_channels=1,
                    base_num_features=16,
                    num_classes=NB_CLASSES,
                    num_pool=4,
                    conv_op=nn.Conv3d,
                    norm_op=nn.InstanceNorm3d,
                    norm_op_kwargs=norm_op_kwargs,
                    nonlin=net_nonlin,
                    nonlin_kwargs=net_nonlin_kwargs).to(device)

    print("[INFO] Inference")
    inference(paths_dict[opt.phase], model, test_transforms, device, opt)


def parsing_data():
    parser = argparse.ArgumentParser(description='Performing inference')

    parser.add_argument('--model_dir',
                        type=str)

    parser.add_argument("--dataset_split",
                        type=str,
                        default="splits/split_inextremis_budget1.csv")

    parser.add_argument("--path_data",
                        type=str,
                        default="data/VS_MICCAI21/T2/")

    parser.add_argument('--phase',
                        type=str,
                        default='inference')

    parser.add_argument('--spatial_shape',
                        type=int,
                        nargs="+",
                        default=(224, 224, 48))

    parser.add_argument('--epoch_inf',
                        type=str,
                        default='best')

    opt = parser.parse_args()

    return opt

if __name__ == '__main__':
    main()
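
# Example usage (a sketch; see README.md for the full command):
#   python3 inference.py --model_dir ./models/manual_gradient_eucl_deep_crf/ --path_data data/T2/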
--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------
1 | The pre-trained models can be downloaded [here](https://zenodo.org/record/5081163/files/models.zip?download=1).
2 |
--------------------------------------------------------------------------------
/network/unet2d5.py:
--------------------------------------------------------------------------------
# CODE ADAPTED FROM: https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/network_architecture/generic_UNet.py

from torch import nn
import torch
import numpy as np
import torch.nn.functional


class InitWeights_He(object):
    def __init__(self, neg_slope=1e-2):
        self.neg_slope = neg_slope

    def __call__(self, module):
        if isinstance(module, nn.Conv3d) or isinstance(module, nn.Conv2d) or isinstance(module, nn.ConvTranspose2d) or isinstance(module, nn.ConvTranspose3d):
            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
            if module.bias is not None:
                module.bias = nn.init.constant_(module.bias, 0)


class ConvNormNonlinBlock(nn.Module):
    def __init__(
            self,
            input_channels,
            output_channels,
            conv_op=nn.Conv3d,
            conv_kwargs=None,
            norm_op=nn.InstanceNorm3d,
            norm_op_kwargs=None,
            nonlin=nn.LeakyReLU,
            nonlin_kwargs=None):
        """
        Block: Conv->Norm->Activation->Conv->Norm->Activation
        """
        super(ConvNormNonlinBlock, self).__init__()

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.output_channels = output_channels

        self.first_conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
        self.first_norm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.first_acti = self.nonlin(**self.nonlin_kwargs)

        self.second_conv = self.conv_op(output_channels, output_channels, **self.conv_kwargs)
        self.second_norm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.second_acti = self.nonlin(**self.nonlin_kwargs)

        self.block = nn.Sequential(
            self.first_conv,
            self.first_norm,
            self.first_acti,
            self.second_conv,
            self.second_norm,
            self.second_acti
        )

    def forward(self, x):
        return self.block(x)


class Upsample(nn.Module):
    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False):
        super(Upsample, self).__init__()
        self.align_corners = align_corners
        self.mode = mode
        self.scale_factor = scale_factor
        self.size = size

    def forward(self, x):
        return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)


class UNet2D5(nn.Module):

    def __init__(
            self,
            input_channels,
            base_num_features,
            num_classes,
            num_pool,
            conv_op=nn.Conv3d,
            conv_kernel_sizes=None,
            norm_op=nn.InstanceNorm3d,
            norm_op_kwargs=None,
            nonlin=nn.LeakyReLU,
            nonlin_kwargs=None,
            weightInitializer=InitWeights_He(1e-2)):
        """
        2.5D CNN combining 2D and 3D convolutions to deal with the low through-plane resolution.
        The first two stages use 2D convolutions while the others use 3D convolutions.

        Architecture inspired by:
        Wang et al.: Automatic segmentation of vestibular schwannoma from T2-weighted MRI
        by deep spatial attention with hardness-weighted loss. MICCAI 2019.
        """
        super(UNet2D5, self).__init__()

        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}

        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}

        self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True}

        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.weightInitializer = weightInitializer
        self.conv_op = conv_op
        self.norm_op = norm_op
        self.num_classes = num_classes

        upsample_mode = 'trilinear'
        pool_op = nn.MaxPool3d
        pool_op_kernel_sizes = [(2, 2, 2)] * num_pool
        if conv_kernel_sizes is None:
            # 2D kernels (3, 3, 1) for the first two stages, 3D kernels (3, 3, 3) afterwards
            conv_kernel_sizes = [(3, 3, 1)] * 2 + [(3, 3, 3)] * (num_pool - 1)

        self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64)
        self.pool_op_kernel_sizes = pool_op_kernel_sizes
        self.conv_kernel_sizes = conv_kernel_sizes

        self.conv_pad_sizes = []
        for krnl in self.conv_kernel_sizes:
            self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])

        self.conv_blocks_context = []
        self.conv_blocks_localization = []
        self.td = []
        self.tu = []
        self.seg_outputs = []

        input_features = input_channels
        output_features = base_num_features

        for d in range(num_pool):
            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[d]
            # add convolutions
            self.conv_blocks_context.append(ConvNormNonlinBlock(input_features, output_features,
                                                                self.conv_op, self.conv_kwargs, self.norm_op,
                                                                self.norm_op_kwargs, self.nonlin, self.nonlin_kwargs))
            self.td.append(pool_op(pool_op_kernel_sizes[d]))
            input_features = output_features
            output_features = 2 * output_features  # the number of kernels doubles after each pooling

        final_num_features = self.conv_blocks_context[-1].output_channels
        self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool]
        self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool]
        self.conv_blocks_context.append(ConvNormNonlinBlock(input_features, final_num_features,
                                                            self.conv_op, self.conv_kwargs, self.norm_op,
                                                            self.norm_op_kwargs, self.nonlin, self.nonlin_kwargs))

        # now lets build the localization pathway
        for u in range(num_pool):
            nfeatures_from_skip = self.conv_blocks_context[-(2 + u)].output_channels  # self.conv_blocks_context[-1] is bottleneck, so start with -2
            n_features_after_tu_and_concat = nfeatures_from_skip * 2

            # the first conv reduces the number of features to match those of skip
            # the following convs work on that number of features
            # if not convolutional upsampling then the final conv reduces the num of features again
            if u != num_pool - 1:
                final_num_features = self.conv_blocks_context[-(3 + u)].output_channels
            else:
                final_num_features = nfeatures_from_skip

            self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode))

            self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[-(u + 1)]
            self.conv_kwargs['padding'] = self.conv_pad_sizes[-(u + 1)]
            self.conv_blocks_localization.append(ConvNormNonlinBlock(n_features_after_tu_and_concat, final_num_features,
                                                                     self.conv_op, self.conv_kwargs, self.norm_op,
                                                                     self.norm_op_kwargs, self.nonlin, self.nonlin_kwargs))

        self.final_conv = conv_op(self.conv_blocks_localization[-1].output_channels, num_classes, 1, 1, 0, 1, 1, False)

        # register all modules properly
        self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization)
        self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context)
        self.td = nn.ModuleList(self.td)
        self.tu = nn.ModuleList(self.tu)

        if self.weightInitializer is not None:
            self.apply(self.weightInitializer)

    def forward(self, x):
        skips = []
        for d in range(len(self.conv_blocks_context) - 1):
            x = self.conv_blocks_context[d](x)
            skips.append(x)
            x = self.td[d](x)

        x = self.conv_blocks_context[-1](x)

        for u in range(len(self.tu)):
            x = self.tu[u](x)
            x = torch.cat((x, skips[-(u + 1)]), dim=1)
            x = self.conv_blocks_localization[u](x)

        output = self.final_conv(x)
        return output
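
# Minimal smoke test (an illustrative addition, not part of the training
# pipeline): builds the network with the hyper-parameters used in inference.py
# and checks the output shape on a dummy patch. Spatial dims must be divisible
# by 2**num_pool = 16.
if __name__ == "__main__":
    net = UNet2D5(input_channels=1,
                  base_num_features=16,
                  num_classes=2,
                  num_pool=4)
    with torch.no_grad():
        dummy = torch.zeros(1, 1, 32, 32, 16)  # (batch, channel, x, y, z)
        out = net(dummy)
    print(out.shape)  # expected: torch.Size([1, 2, 32, 32, 16])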
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | monai==0.4.0
2 | nibabel==3.1.1
3 | torch==1.7.1
4 | dijkstra3d==1.8.0
5 | pandas
6 | medpy
7 | tqdm
8 | natsort
--------------------------------------------------------------------------------
/splits/split_fully_budget1.csv:
--------------------------------------------------------------------------------
1 | vs_gk_13,training
2 | vs_gk_24,training
3 | vs_gk_170,training
4 | vs_gk_26,training
5 | vs_gk_36,training
6 | vs_gk_14,training
7 | vs_gk_66,training
8 | vs_gk_142,training
9 | vs_gk_114,training
10 | vs_gk_50,training
11 | vs_gk_125,training
12 | vs_gk_4,training
13 | vs_gk_7,training
14 | vs_gk_200,validation
15 | vs_gk_182,validation
16 | vs_gk_202,testing
17 | vs_gk_203,testing
18 | vs_gk_204,testing
19 | vs_gk_205,testing
20 | vs_gk_206,testing
21 | vs_gk_207,testing
22 | vs_gk_209,testing
23 | vs_gk_210,testing
24 | vs_gk_211,testing
25 | vs_gk_212,testing
26 | vs_gk_213,testing
27 | vs_gk_214,testing
28 | vs_gk_215,testing
29 | vs_gk_216,testing
30 | vs_gk_217,testing
31 | vs_gk_218,testing
32 | vs_gk_220,testing
33 | vs_gk_221,testing
34 | vs_gk_222,testing
35 | vs_gk_223,testing
36 | vs_gk_224,testing
37 | vs_gk_225,testing
38 | vs_gk_226,testing
39 | vs_gk_228,testing
40 | vs_gk_229,testing
41 | vs_gk_230,testing
42 | vs_gk_231,testing
43 | vs_gk_232,testing
44 | vs_gk_233,testing
45 | vs_gk_234,testing
46 | vs_gk_235,testing
47 | vs_gk_236,testing
48 | vs_gk_237,testing
49 | vs_gk_238,testing
50 | vs_gk_239,testing
51 | vs_gk_240,testing
52 | vs_gk_241,testing
53 | vs_gk_242,testing
54 | vs_gk_243,testing
55 | vs_gk_244,testing
56 | vs_gk_245,testing
57 | vs_gk_246,testing
58 | vs_gk_247,testing
59 | vs_gk_248,testing
60 | vs_gk_249,testing
61 | vs_gk_250,testing
62 |
--------------------------------------------------------------------------------
/splits/split_fully_budget13.csv:
--------------------------------------------------------------------------------
1 | vs_gk_1,training
2 | vs_gk_2,training
3 | vs_gk_3,training
4 | vs_gk_4,training
5 | vs_gk_5,training
6 | vs_gk_6,training
7 | vs_gk_7,training
8 | vs_gk_8,training
9 | vs_gk_9,training
10 | vs_gk_10,training
11 | vs_gk_11,training
12 | vs_gk_12,training
13 | vs_gk_13,training
14 | vs_gk_14,training
15 | vs_gk_15,training
16 | vs_gk_16,training
17 | vs_gk_17,training
18 | vs_gk_18,training
19 | vs_gk_19,training
20 | vs_gk_20,training
21 | vs_gk_21,training
22 | vs_gk_22,training
23 | vs_gk_23,training
24 | vs_gk_24,training
25 | vs_gk_25,training
26 | vs_gk_26,training
27 | vs_gk_27,training
28 | vs_gk_28,training
29 | vs_gk_29,training
30 | vs_gk_30,training
31 | vs_gk_31,training
32 | vs_gk_32,training
33 | vs_gk_33,training
34 | vs_gk_34,training
35 | vs_gk_35,training
36 | vs_gk_36,training
37 | vs_gk_37,training
38 | vs_gk_38,training
39 | vs_gk_40,training
40 | vs_gk_41,training
41 | vs_gk_42,training
42 | vs_gk_43,training
43 | vs_gk_44,training
44 | vs_gk_45,training
45 | vs_gk_46,training
46 | vs_gk_47,training
47 | vs_gk_48,training
48 | vs_gk_49,training
49 | vs_gk_50,training
50 | vs_gk_51,training
51 | vs_gk_52,training
52 | vs_gk_53,training
53 | vs_gk_54,training
54 | vs_gk_55,training
55 | vs_gk_56,training
56 | vs_gk_57,training
57 | vs_gk_58,training
58 | vs_gk_59,training
59 | vs_gk_60,training
60 | vs_gk_61,training
61 | vs_gk_62,training
62 | vs_gk_63,training
63 | vs_gk_64,training
64 | vs_gk_65,training
65 | vs_gk_66,training
66 | vs_gk_67,training
67 | vs_gk_68,training
68 | vs_gk_69,training
69 | vs_gk_70,training
70 | vs_gk_71,training
71 | vs_gk_72,training
72 | vs_gk_73,training
73 | vs_gk_74,training
74 | vs_gk_75,training
75 | vs_gk_76,training
76 | vs_gk_77,training
77 | vs_gk_78,training
78 | vs_gk_79,training
79 | vs_gk_80,training
80 | vs_gk_81,training
81 | vs_gk_82,training
82 | vs_gk_83,training
83 | vs_gk_84,training
84 | vs_gk_85,training
85 | vs_gk_86,training
86 | vs_gk_87,training
87 | vs_gk_88,training
88 | vs_gk_89,training
89 | vs_gk_90,training
90 | vs_gk_91,training
91 | vs_gk_92,training
92 | vs_gk_93,training
93 | vs_gk_94,training
94 | vs_gk_95,training
95 | vs_gk_96,training
96 | vs_gk_98,training
97 | vs_gk_99,training
98 | vs_gk_100,training
99 | vs_gk_101,training
100 | vs_gk_102,training
101 | vs_gk_103,training
102 | vs_gk_104,training
103 | vs_gk_105,training
104 | vs_gk_106,training
105 | vs_gk_107,training
106 | vs_gk_108,training
107 | vs_gk_109,training
108 | vs_gk_110,training
109 | vs_gk_111,training
110 | vs_gk_112,training
111 | vs_gk_113,training
112 | vs_gk_114,training
113 | vs_gk_115,training
114 | vs_gk_116,training
115 | vs_gk_117,training
116 | vs_gk_118,training
117 | vs_gk_119,training
118 | vs_gk_120,training
119 | vs_gk_121,training
120 | vs_gk_122,training
121 | vs_gk_123,training
122 | vs_gk_124,training
123 | vs_gk_125,training
124 | vs_gk_126,training
125 | vs_gk_127,training
126 | vs_gk_128,training
127 | vs_gk_129,training
128 | vs_gk_131,training
129 | vs_gk_132,training
130 | vs_gk_133,training
131 | vs_gk_134,training
132 | vs_gk_135,training
133 | vs_gk_136,training
134 | vs_gk_137,training
135 | vs_gk_138,training
136 | vs_gk_139,training
137 | vs_gk_140,training
138 | vs_gk_141,training
139 | vs_gk_142,training
140 | vs_gk_143,training
141 | vs_gk_144,training
142 | vs_gk_145,training
143 | vs_gk_146,training
144 | vs_gk_147,training
145 | vs_gk_148,training
146 | vs_gk_149,training
147 | vs_gk_150,training
148 | vs_gk_151,training
149 | vs_gk_152,training
150 | vs_gk_153,training
151 | vs_gk_154,training
152 | vs_gk_155,training
153 | vs_gk_156,training
154 | vs_gk_157,training
155 | vs_gk_158,training
156 | vs_gk_159,training
157 | vs_gk_161,training
158 | vs_gk_162,training
159 | vs_gk_163,training
160 | vs_gk_164,training
161 | vs_gk_165,training
162 | vs_gk_166,training
163 | vs_gk_167,training
164 | vs_gk_169,training
165 | vs_gk_170,training
166 | vs_gk_171,training
167 | vs_gk_172,training
168 | vs_gk_173,training
169 | vs_gk_174,training
170 | vs_gk_175,training
171 | vs_gk_176,training
172 | vs_gk_177,training
173 | vs_gk_178,training
174 | vs_gk_179,training
175 | vs_gk_180,training
176 | vs_gk_181,training
177 | vs_gk_182,validation
178 | vs_gk_183,validation
179 | vs_gk_184,validation
180 | vs_gk_185,validation
181 | vs_gk_186,validation
182 | vs_gk_187,validation
183 | vs_gk_188,validation
184 | vs_gk_189,validation
185 | vs_gk_190,validation
186 | vs_gk_191,validation
187 | vs_gk_192,validation
188 | vs_gk_193,validation
189 | vs_gk_194,validation
190 | vs_gk_195,validation
191 | vs_gk_196,validation
192 | vs_gk_197,validation
193 | vs_gk_198,validation
194 | vs_gk_199,validation
195 | vs_gk_200,validation
196 | vs_gk_201,validation
197 | vs_gk_202,inference
198 | vs_gk_203,inference
199 | vs_gk_204,inference
200 | vs_gk_205,inference
201 | vs_gk_206,inference
202 | vs_gk_207,inference
203 | vs_gk_209,inference
204 | vs_gk_210,inference
205 | vs_gk_211,inference
206 | vs_gk_212,inference
207 | vs_gk_213,inference
208 | vs_gk_214,inference
209 | vs_gk_215,inference
210 | vs_gk_216,inference
211 | vs_gk_217,inference
212 | vs_gk_218,inference
213 | vs_gk_220,inference
214 | vs_gk_221,inference
215 | vs_gk_222,inference
216 | vs_gk_223,inference
217 | vs_gk_224,inference
218 | vs_gk_225,inference
219 | vs_gk_226,inference
220 | vs_gk_228,inference
221 | vs_gk_229,inference
222 | vs_gk_230,inference
223 | vs_gk_231,inference
224 | vs_gk_232,inference
225 | vs_gk_233,inference
226 | vs_gk_234,inference
227 | vs_gk_235,inference
228 | vs_gk_236,inference
229 | vs_gk_237,inference
230 | vs_gk_238,inference
231 | vs_gk_239,inference
232 | vs_gk_240,inference
233 | vs_gk_241,inference
234 | vs_gk_242,inference
235 | vs_gk_243,inference
236 | vs_gk_244,inference
237 | vs_gk_245,inference
238 | vs_gk_246,inference
239 | vs_gk_247,inference
240 | vs_gk_248,inference
241 | vs_gk_249,inference
242 | vs_gk_250,inference
243 |
--------------------------------------------------------------------------------
/splits/split_inextremis_budget1.csv:
--------------------------------------------------------------------------------
1 | vs_gk_1,training
2 | vs_gk_2,training
3 | vs_gk_3,training
4 | vs_gk_4,training
5 | vs_gk_5,training
6 | vs_gk_6,training
7 | vs_gk_7,training
8 | vs_gk_8,training
9 | vs_gk_9,training
10 | vs_gk_10,training
11 | vs_gk_11,training
12 | vs_gk_12,training
13 | vs_gk_13,training
14 | vs_gk_14,training
15 | vs_gk_15,training
16 | vs_gk_16,training
17 | vs_gk_17,training
18 | vs_gk_18,training
19 | vs_gk_19,training
20 | vs_gk_20,training
21 | vs_gk_21,training
22 | vs_gk_22,training
23 | vs_gk_23,training
24 | vs_gk_24,training
25 | vs_gk_25,training
26 | vs_gk_26,training
27 | vs_gk_27,training
28 | vs_gk_28,training
29 | vs_gk_29,training
30 | vs_gk_30,training
31 | vs_gk_31,training
32 | vs_gk_32,training
33 | vs_gk_33,training
34 | vs_gk_34,training
35 | vs_gk_35,training
36 | vs_gk_36,training
37 | vs_gk_37,training
38 | vs_gk_38,training
39 | vs_gk_40,training
40 | vs_gk_41,training
41 | vs_gk_42,training
42 | vs_gk_43,training
43 | vs_gk_44,training
44 | vs_gk_45,training
45 | vs_gk_46,training
46 | vs_gk_47,training
47 | vs_gk_48,training
48 | vs_gk_49,training
49 | vs_gk_50,training
50 | vs_gk_51,training
51 | vs_gk_52,training
52 | vs_gk_53,training
53 | vs_gk_54,training
54 | vs_gk_55,training
55 | vs_gk_56,training
56 | vs_gk_57,training
57 | vs_gk_58,training
58 | vs_gk_59,training
59 | vs_gk_60,training
60 | vs_gk_61,training
61 | vs_gk_62,training
62 | vs_gk_63,training
63 | vs_gk_64,training
64 | vs_gk_65,training
65 | vs_gk_66,training
66 | vs_gk_67,training
67 | vs_gk_68,training
68 | vs_gk_69,training
69 | vs_gk_70,training
70 | vs_gk_71,training
71 | vs_gk_72,training
72 | vs_gk_73,training
73 | vs_gk_74,training
74 | vs_gk_75,training
75 | vs_gk_76,training
76 | vs_gk_77,training
77 | vs_gk_78,training
78 | vs_gk_79,training
79 | vs_gk_80,training
80 | vs_gk_81,training
81 | vs_gk_82,training
82 | vs_gk_83,training
83 | vs_gk_84,training
84 | vs_gk_85,training
85 | vs_gk_86,training
86 | vs_gk_87,training
87 | vs_gk_88,training
88 | vs_gk_89,training
89 | vs_gk_90,training
90 | vs_gk_91,training
91 | vs_gk_92,training
92 | vs_gk_93,training
93 | vs_gk_94,training
94 | vs_gk_95,training
95 | vs_gk_96,training
96 | vs_gk_98,training
97 | vs_gk_99,training
98 | vs_gk_100,training
99 | vs_gk_101,training
100 | vs_gk_102,training
101 | vs_gk_103,training
102 | vs_gk_104,training
103 | vs_gk_105,training
104 | vs_gk_106,training
105 | vs_gk_107,training
106 | vs_gk_108,training
107 | vs_gk_109,training
108 | vs_gk_110,training
109 | vs_gk_111,training
110 | vs_gk_112,training
111 | vs_gk_113,training
112 | vs_gk_114,training
113 | vs_gk_115,training
114 | vs_gk_116,training
115 | vs_gk_117,training
116 | vs_gk_118,training
117 | vs_gk_119,training
118 | vs_gk_120,training
119 | vs_gk_121,training
120 | vs_gk_122,training
121 | vs_gk_123,training
122 | vs_gk_124,training
123 | vs_gk_125,training
124 | vs_gk_126,training
125 | vs_gk_127,training
126 | vs_gk_128,training
127 | vs_gk_129,training
128 | vs_gk_131,training
129 | vs_gk_132,training
130 | vs_gk_133,training
131 | vs_gk_134,training
132 | vs_gk_135,training
133 | vs_gk_136,training
134 | vs_gk_137,training
135 | vs_gk_138,training
136 | vs_gk_139,training
137 | vs_gk_140,training
138 | vs_gk_141,training
139 | vs_gk_142,training
140 | vs_gk_143,training
141 | vs_gk_144,training
142 | vs_gk_145,training
143 | vs_gk_146,training
144 | vs_gk_147,training
145 | vs_gk_148,training
146 | vs_gk_149,training
147 | vs_gk_150,training
148 | vs_gk_151,training
149 | vs_gk_152,training
150 | vs_gk_153,training
151 | vs_gk_154,training
152 | vs_gk_155,training
153 | vs_gk_156,training
154 | vs_gk_157,training
155 | vs_gk_158,training
156 | vs_gk_159,training
157 | vs_gk_161,training
158 | vs_gk_162,training
159 | vs_gk_163,training
160 | vs_gk_164,training
161 | vs_gk_165,training
162 | vs_gk_166,training
163 | vs_gk_167,training
164 | vs_gk_169,training
165 | vs_gk_170,training
166 | vs_gk_171,training
167 | vs_gk_172,training
168 | vs_gk_173,training
169 | vs_gk_174,training
170 | vs_gk_175,training
171 | vs_gk_176,training
172 | vs_gk_177,training
173 | vs_gk_178,training
174 | vs_gk_179,training
175 | vs_gk_180,training
176 | vs_gk_181,training
177 | vs_gk_182,validation
178 | vs_gk_183,validation
179 | vs_gk_184,validation
180 | vs_gk_185,validation
181 | vs_gk_186,validation
182 | vs_gk_187,validation
183 | vs_gk_188,validation
184 | vs_gk_189,validation
185 | vs_gk_190,validation
186 | vs_gk_191,validation
187 | vs_gk_192,validation
188 | vs_gk_193,validation
189 | vs_gk_194,validation
190 | vs_gk_195,validation
191 | vs_gk_196,validation
192 | vs_gk_197,validation
193 | vs_gk_198,validation
194 | vs_gk_199,validation
195 | vs_gk_200,validation
196 | vs_gk_201,validation
197 | vs_gk_202,inference
198 | vs_gk_203,inference
199 | vs_gk_204,inference
200 | vs_gk_205,inference
201 | vs_gk_206,inference
202 | vs_gk_207,inference
203 | vs_gk_209,inference
204 | vs_gk_210,inference
205 | vs_gk_211,inference
206 | vs_gk_212,inference
207 | vs_gk_213,inference
208 | vs_gk_214,inference
209 | vs_gk_215,inference
210 | vs_gk_216,inference
211 | vs_gk_217,inference
212 | vs_gk_218,inference
213 | vs_gk_220,inference
214 | vs_gk_221,inference
215 | vs_gk_222,inference
216 | vs_gk_223,inference
217 | vs_gk_224,inference
218 | vs_gk_225,inference
219 | vs_gk_226,inference
220 | vs_gk_228,inference
221 | vs_gk_229,inference
222 | vs_gk_230,inference
223 | vs_gk_231,inference
224 | vs_gk_232,inference
225 | vs_gk_233,inference
226 | vs_gk_234,inference
227 | vs_gk_235,inference
228 | vs_gk_236,inference
229 | vs_gk_237,inference
230 | vs_gk_238,inference
231 | vs_gk_239,inference
232 | vs_gk_240,inference
233 | vs_gk_241,inference
234 | vs_gk_242,inference
235 | vs_gk_243,inference
236 | vs_gk_244,inference
237 | vs_gk_245,inference
238 | vs_gk_246,inference
239 | vs_gk_247,inference
240 | vs_gk_248,inference
241 | vs_gk_249,inference
242 | vs_gk_250,inference
243 |
--------------------------------------------------------------------------------
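
Each split file is a header-less two-column CSV: a subject identifier followed by its phase (training, validation or inference). A minimal sketch of how such a file can be loaded and partitioned, mirroring the pandas logic used in train.py below:

    import pandas as pd

    # Column 0 holds the subject id, column 1 the phase; there is no header row
    df_split = pd.read_csv("splits/split_inextremis_budget1.csv", header=None)
    for phase in ["training", "validation", "inference"]:
        subjects = df_split[df_split[1] == phase][0].tolist()
        print(f"{phase}: {len(subjects)} subjects")
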
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | import argparse
5 | import time
6 | import os
7 | from tqdm import tqdm
8 |
9 | import numpy as np
10 | import pandas as pd
11 |
12 | import torch
13 | from torch import nn
14 |
15 | from monai.inferers import sliding_window_inference
16 | from monai.utils import set_determinism
17 | from monai.data import DataLoader, Dataset
18 | from monai.transforms import (
19 | Compose,
20 | LoadNiftid,
21 | AddChanneld,
22 | SpatialPadd,
23 | NormalizeIntensityd,
24 | RandFlipd,
25 | RandSpatialCropd,
26 | Orientationd,
27 | ToTensord,
28 | )
29 |
30 | from utilities.losses import DC, DC_CE_Focal, PartialLoss
31 | from utilities.utils import (
32 | create_logger,
33 | poly_lr,
34 | infinite_iterable)
35 | from utilities.geodesics import generate_geodesics
36 |
37 | from ScribbleDA.scribbleDALoss import CRFLoss
38 | from network.unet2d5 import UNet2D5
39 |
40 |
41 | # Training and patch-sampling parameters
42 | NB_CLASSES = 2
43 | PHASES = ["training", "validation"]
44 | MAX_EPOCHS = 300
45 |
46 | # Training parameters
47 | weight_decay = 3e-5
48 |
49 | def train(paths_dict, model, transformation, criterion, device, save_path, logger, opt):
50 |
51 | since = time.time()
52 |
53 |     # Datasets combining the file paths with their normalization/augmentation transforms
54 | subjects_train = Dataset(
55 | paths_dict["training"],
56 | transform=transformation["training"])
57 |
58 | subjects_val = Dataset(
59 | paths_dict["validation"],
60 | transform=transformation["validation"])
61 |
62 | # Dataloaders
63 | dataloaders = dict()
64 | dataloaders["training"] = infinite_iterable(
65 | DataLoader(subjects_train, batch_size=opt.batch_size, num_workers=2, shuffle=True)
66 | )
67 | dataloaders["validation"] = infinite_iterable(
68 | DataLoader(subjects_val, batch_size=1, num_workers=2)
69 | )
70 |
71 | nb_batches = {
72 |         "training": 30, # 30 batches per epoch, i.e. roughly one patch per training image
73 | "validation": len(paths_dict["validation"])
74 | }
75 |
76 | # Training parameters are saved
77 | df_path = os.path.join(opt.model_dir,"log.csv")
78 | if os.path.isfile(df_path): # If the training already started
79 | df = pd.read_csv(df_path, index_col=False)
80 | epoch = df.iloc[-1]["epoch"]
81 | best_epoch = df.iloc[-1]["best_epoch"]
82 | best_val = df.iloc[-1]["best_val"]
83 | initial_lr = df.iloc[-1]["lr"]
84 | model.load_state_dict(torch.load(save_path.format("best")))
85 |
86 | else: # If training from scratch
87 |         columns = ["epoch", "best_epoch", "best_val", "lr", "timeit"]
88 | df = pd.DataFrame(columns=columns)
89 | best_val = None
90 | best_epoch = 0
91 | epoch = 0
92 | initial_lr = opt.learning_rate
93 |
94 |
95 | # Optimisation policy mimicking nnUnet training policy
96 | optimizer = torch.optim.SGD(model.parameters(), initial_lr,
97 | weight_decay=weight_decay, momentum=0.99, nesterov=True)
98 |
99 | # CRF Loss initialisation
100 | crf_l = CRFLoss(alpha=opt.alpha, beta=opt.beta, is_da=False, use_norm=False)
101 |
102 | # Training loop
103 | continue_training = True
104 | while continue_training:
105 | epoch+=1
106 | logger.info("-" * 10)
107 |         logger.info("Epoch {}/{}".format(epoch, MAX_EPOCHS))
108 | 
109 | for param_group in optimizer.param_groups:
110 | logger.info("Current learning rate is: {}".format(param_group["lr"]))
111 |
112 | # Each epoch has a training and validation phase
113 | for phase in PHASES:
114 | if phase == "training":
115 | model.train() # Set model to training mode
116 | else:
117 | model.eval() # Set model to evaluate mode
118 |
119 | # Initializing the statistics
120 | running_loss = 0.0
121 | running_loss_reg = 0.0
122 | running_loss_seg = 0.0
123 | epoch_samples = 0
124 | running_time = 0.0
125 |
126 | # Iterate over data
127 | for _ in tqdm(range(nb_batches[phase])):
128 | batch = next(dataloaders[phase])
129 | inputs = batch["img"].to(device) # T2 images
130 | if opt.mode == "extreme_points":
131 | extremes = batch["label"].to(device) # Extreme points
132 | img_gradients = batch["img_gradient"].to(device) # Pre-Computed Sobel map
133 | else:
134 | labels = batch["label"].to(device)
135 |
136 | # zero the parameter gradients
137 | optimizer.zero_grad()
138 |
139 | with torch.set_grad_enabled(phase == "training"):
140 | if phase=="training": # Random patch predictions
141 | outputs = model(inputs)
142 | else: # if validation, Inference on the full image
143 | outputs = sliding_window_inference(
144 | inputs=inputs,
145 | roi_size=opt.spatial_shape,
146 | sw_batch_size=1,
147 | predictor=model,
148 | mode="gaussian",
149 | )
150 |
151 | if opt.mode == "extreme_points": # Generate geodesics
152 | init_time_geodesics = time.time()
153 | geodesics = []
154 | nb_target = outputs.shape[0]
155 | for i in range(nb_target):
156 | geodesics_i = generate_geodesics(
157 | extreme=extremes[i,...],
158 | img_gradient=img_gradients[i,...],
159 | prob=outputs[i,...],
160 | with_prob=opt.with_prob,
161 | with_euclidean=opt.with_euclidean
162 | )
163 | geodesics.append(geodesics_i.to(device))
164 | labels = torch.cat(geodesics,0)
165 | time_geodesics = time.time() - init_time_geodesics
166 | else:
167 | time_geodesics = 0.
168 |
169 | # Segmentation loss
170 | loss_seg = criterion(outputs, labels, phase)
171 |
172 | # CRF regularisation (training only)
173 | if (opt.beta>0 or opt.alpha>0) and phase == "training" and opt.mode == "extreme_points":
174 | reg = opt.weight_crf/np.prod(opt.spatial_shape)*crf_l(inputs, outputs)
175 | loss = loss_seg + reg
176 | else:
177 | reg = 0.0
178 | loss = loss_seg
179 |
180 | if phase == "training":
181 | loss.backward()
182 | optimizer.step()
183 |
184 | # Iteration statistics
185 | epoch_samples += 1
186 | running_loss += loss.item()
187 |                     running_loss_reg += float(reg)
188 |                     running_loss_seg += float(loss_seg) # scalar only: avoids retaining the autograd graph
189 | running_time += time_geodesics
190 |
191 |             # Epoch statistics
192 | epoch_loss = running_loss / epoch_samples
193 | epoch_loss_reg = running_loss_reg / epoch_samples
194 | epoch_loss_seg = running_loss_seg / epoch_samples
195 | if phase == "training":
196 | epoch_time = running_time / epoch_samples
197 |
198 | logger.info("{} Loss Reg: {:.4f}".format(
199 | phase, epoch_loss_reg))
200 | logger.info("{} Loss Seg: {:.4f}".format(
201 | phase, epoch_loss_seg))
202 | if phase == "training":
203 | logger.info("{} Time Geodesics: {:.4f}".format(
204 | phase, epoch_time))
205 |
206 | # Saving best model on the validation set
207 | if phase == "validation":
208 | if best_val is None: # first iteration
209 | best_val = epoch_loss
210 | torch.save(model.state_dict(), save_path.format("best"))
211 |
212 | if epoch_loss <= best_val:
213 | best_val = epoch_loss
214 | best_epoch = epoch
215 | torch.save(model.state_dict(), save_path.format("best"))
216 |
217 | df = df.append(
218 | {"epoch":epoch,
219 | "best_epoch":best_epoch,
220 | "best_val":best_val,
221 | "lr":param_group["lr"],
222 | "timeit":epoch_time},
223 | ignore_index=True)
224 | df.to_csv(df_path, index=False)
225 |
226 | optimizer.param_groups[0]["lr"] = poly_lr(epoch, MAX_EPOCHS, opt.learning_rate, 0.9)
227 |
228 | # Early stopping performed when full annotations are used (training set may be small)
229 | if opt.mode == "full_annotations" and epoch-best_epoch>70:
230 | torch.save(model.state_dict(), save_path.format("final"))
231 | continue_training=False
232 |
233 | if epoch == MAX_EPOCHS:
234 | torch.save(model.state_dict(), save_path.format("final"))
235 | continue_training=False
236 |
237 | time_elapsed = time.time() - since
238 | logger.info("[INFO] Training completed in {:.0f}m {:.0f}s".format(
239 | time_elapsed // 60, time_elapsed % 60))
240 | logger.info(f"[INFO] Best validation epoch is {best_epoch}")
241 |
242 |
243 | def main():
244 | set_determinism(seed=2)
245 |
246 | opt = parsing_data()
247 |
248 | # FOLDERS
249 | fold_dir = opt.model_dir
250 | fold_dir_model = os.path.join(fold_dir,"models")
251 | if not os.path.exists(fold_dir_model):
252 | os.makedirs(fold_dir_model)
253 |     save_path = os.path.join(fold_dir_model, "CP_{}.pth")
254 |
255 | if opt.path_labels is None:
256 | opt.path_labels = opt.path_data
257 |
258 | logger = create_logger(fold_dir)
259 | logger.info("[INFO] Hyperparameters")
260 | logger.info(f"Alpha: {opt.alpha}")
261 | logger.info(f"Beta: {opt.beta}")
262 | logger.info(f"Weight Reg: {opt.weight_crf}")
263 | logger.info(f"Batch size: {opt.batch_size}")
264 | logger.info(f"Spatial shape: {opt.spatial_shape}")
265 | logger.info(f"Initial lr: {opt.learning_rate}")
266 | logger.info(f"Postfix img gradients: {opt.img_gradient_postfix}")
267 | logger.info(f"Postfix labels: {opt.label_postfix}")
268 | logger.info(f"With euclidean: {opt.with_euclidean}")
269 | logger.info(f"With probs: {opt.with_prob}")
270 |
271 | # GPU CHECKING
272 | if torch.cuda.is_available():
273 | logger.info("[INFO] GPU available.")
274 |         device = torch.device("cuda:0")
275 |     else:
276 |         logger.error("[ERROR] No GPU found")
277 |         raise RuntimeError("No GPU found")
278 |
279 | # SPLIT
280 |     assert os.path.isfile(opt.dataset_split), "[ERROR] Invalid split file"
281 |     df_split = pd.read_csv(opt.dataset_split, header=None)
282 | list_file = dict()
283 | for split in PHASES:
284 | list_file[split] = df_split[df_split[1].isin([split])][0].tolist()
285 |
286 |
287 | # CREATING DICT FOR CACHEDATASET
288 | mod_ext = "_T2.nii.gz"
289 | grad_ext = f"_{opt.img_gradient_postfix}.nii.gz"
290 | extreme_ext = f"_{opt.label_postfix}.nii.gz"
291 | paths_dict = {split:[] for split in PHASES}
292 |
293 | for split in PHASES:
294 | for subject in list_file[split]:
295 | subject_data = dict()
296 |
297 | img_path = os.path.join(opt.path_data,subject+mod_ext)
298 | img_grad_path = os.path.join(opt.path_labels,subject+grad_ext)
299 | lab_path = os.path.join(opt.path_labels,subject+extreme_ext)
300 |
301 | if os.path.exists(img_path) and os.path.exists(lab_path):
302 | subject_data["img"] = img_path
303 | subject_data["label"] = lab_path
304 |
305 | if opt.mode == "extreme_points":
306 | if os.path.exists(img_grad_path):
307 | subject_data["img_gradient"] = img_grad_path
308 | paths_dict[split].append(subject_data)
309 | else:
310 | paths_dict[split].append(subject_data)
311 |
312 | logger.info(f"Nb patients in {split} data: {len(paths_dict[split])}")
313 |
314 |
315 | # PREPROCESSING
316 | transforms = dict()
317 | all_keys = ["img", "label"]
318 | if opt.mode == "extreme_points":
319 | all_keys.append("img_gradient")
320 |
321 | transforms_training = (
322 | LoadNiftid(keys=all_keys),
323 | AddChanneld(keys=all_keys),
324 | Orientationd(keys=all_keys, axcodes="RAS"),
325 | NormalizeIntensityd(keys=["img"]),
326 | SpatialPadd(keys=all_keys, spatial_size=opt.spatial_shape),
327 | RandFlipd(keys=all_keys, prob=0.5, spatial_axis=0),
328 | RandSpatialCropd(keys=all_keys, roi_size=opt.spatial_shape, random_center=True, random_size=False),
329 | ToTensord(keys=all_keys),
330 | )
331 | transforms["training"] = Compose(transforms_training)
332 |
333 | transforms_validation = (
334 | LoadNiftid(keys=all_keys),
335 | AddChanneld(keys=all_keys),
336 | Orientationd(keys=all_keys, axcodes="RAS"),
337 | NormalizeIntensityd(keys=["img"]),
338 | SpatialPadd(keys=all_keys, spatial_size=opt.spatial_shape),
339 | ToTensord(keys=all_keys)
340 | )
341 | transforms["validation"] = Compose(transforms_validation)
342 |
343 | # MODEL
344 | logger.info("[INFO] Building model")
345 | norm_op_kwargs = {"eps": 1e-5, "affine": True}
346 | net_nonlin = nn.LeakyReLU
347 | net_nonlin_kwargs = {"negative_slope": 1e-2, "inplace": True}
348 |
349 | model= UNet2D5(input_channels=1,
350 | base_num_features=16,
351 | num_classes=NB_CLASSES,
352 | num_pool=4,
353 | conv_op=nn.Conv3d,
354 | norm_op=nn.InstanceNorm3d,
355 | norm_op_kwargs=norm_op_kwargs,
356 | nonlin=net_nonlin,
357 | nonlin_kwargs=net_nonlin_kwargs).to(device)
358 |
359 |
360 | logger.info("[INFO] Training")
361 | if opt.mode == "full_annotations":
362 | dice = DC(NB_CLASSES)
363 | criterion = lambda pred, grnd, phase: dice(pred, grnd)
364 |
365 | elif opt.mode == "extreme_points" or opt.mode == "geodesics":
366 | dice_ce_focal = DC_CE_Focal(NB_CLASSES)
367 | criterion = PartialLoss(dice_ce_focal)
368 |
369 | train(paths_dict,
370 | model,
371 | transforms,
372 | criterion,
373 | device,
374 | save_path,
375 | logger,
376 | opt)
377 |
378 | def parsing_data():
379 | parser = argparse.ArgumentParser(
380 | description="Script to train the models using extreme points as supervision")
381 |
382 | parser.add_argument("--model_dir",
383 | type=str,
384 | help="Path to the model directory")
385 |
386 | parser.add_argument("--mode",
387 | type=str,
388 | help="Choice of the supervision mode",
389 | choices=["full_annotations", "extreme_points", "geodesics"],
390 | default="extreme_points")
391 |
392 | parser.add_argument("--weight_crf",
393 | type=float,
394 | default=0.1)
395 |
396 | parser.add_argument("--alpha",
397 | type=float,
398 | default=15)
399 |
400 | parser.add_argument("--beta",
401 | type=float,
402 | default=0.05)
403 |
404 | parser.add_argument("--batch_size",
405 | type=int,
406 | default=6,
407 |                         help="Batch size (default: 6)")
408 |
409 | parser.add_argument("--dataset_split",
410 | type=str,
411 | default="splits/split_inextremis_budget1.csv",
412 | help="Path to split file")
413 |
414 | parser.add_argument("--path_data",
415 | type=str,
416 | default="data/VS_MICCAI21/T2/",
417 | help="Path to the T2 scans")
418 |
419 | parser.add_argument("--path_labels",
420 | type=str,
421 | default=None,
422 | help="Path to the extreme points")
423 |
424 | parser.add_argument("--learning_rate",
425 | type=float,
426 | default=1e-2,
427 | help="Initial learning rate")
428 |
429 | parser.add_argument("--label_postfix",
430 | type=str,
431 | default="",
432 |                         help="Postfix of the label files (extreme points or dense labels)")
433 |
434 | parser.add_argument("--img_gradient_postfix",
435 | type=str,
436 | default="",
437 | help="Postfix of the gradient images")
438 |
439 | parser.add_argument("--spatial_shape",
440 | type=int,
441 | nargs="+",
442 | default=(224,224,48),
443 | help="Size of the window patch")
444 |
445 | parser.add_argument("--with_prob",
446 | action="store_true",
447 | help="Add Deep probabilities")
448 |
449 | parser.add_argument("--with_euclidean",
450 | action="store_true",
451 | help="Add Euclidean distance")
452 |
453 | opt = parser.parse_args()
454 |
455 | return opt
456 |
457 | if __name__ == "__main__":
458 | main()
459 |
460 |
461 |
462 |
--------------------------------------------------------------------------------
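
An illustrative invocation of train.py in extreme-point mode; the postfix values are assumptions borrowed from the defaults of utilities/write_geodesics.py and should be adapted to the local file naming:

    python train.py \
        --model_dir models/extreme_points/my_run \
        --mode extreme_points \
        --dataset_split splits/split_inextremis_budget1.csv \
        --label_postfix Extremes_man \
        --img_gradient_postfix Sobel_man \
        --with_prob --with_euclidean

With --mode full_annotations, --label_postfix should instead point to the dense labels; the gradient, probability and Euclidean options only take effect in extreme_points mode.
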
/utilities/figure_fully.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | import matplotlib.pyplot as plt
4 | import numpy as np
5 | import pandas as pd
6 |
7 | score_dice = []
8 | path_model = 'models/full_annotations/fully_{}/results_full.csv'
9 | score_hd = []
10 | for i in [1,3,5,7,9,11,13]:
11 | score_dice.append(100*pd.read_csv(path_model.format(i))['dice'].mean())
12 | score_hd.append(pd.read_csv(path_model.format(i))['hd95'].mean())
13 |
14 | dice_InExtremIS = 100*pd.read_csv('models/extreme_points/manual_gradient_eucl_deep_crf/results_full.csv')['dice'].mean()
15 | hd_InExtremIS = pd.read_csv('models/extreme_points/manual_gradient_eucl_deep_crf/results_full.csv')['hd95'].mean()
16 | times = [2,6,10,14,18,22,26]
17 |
18 | scores = {
19 | 'Dice Score':{'unit':'%', 'scores':[score_dice, dice_InExtremIS]},
20 | '95th-percentile Hausdorff Distance':{'unit':'mm', 'scores':[score_hd, hd_InExtremIS]}
21 | }
22 |
23 | for metric, score in scores.items():
24 | unit = score['unit']
25 | supervised_scores, InExtreMIS_score = score['scores']
26 | fig = plt.figure()
27 | plt.title('Full supervision accuracy given an annotation time budget')
28 | plt.ylabel(f'{metric} ({unit})')
29 | plt.xlabel('Annotation time budget in hours')
30 |
31 | plt.plot(times, supervised_scores, '-ok', label='Fully Supervised')
32 | plt.scatter([2], [InExtreMIS_score], c='blue', label='Extreme points Supervision')
33 | plt.xticks(np.arange(0, 26+1, 2.0))
34 |
35 | if metric == 'Dice Score':
36 | shift = 0.1
37 | sign = 1
38 | else:
39 | shift = 0.01
40 | sign = -1
41 |
42 | diff_same_budget = round(InExtreMIS_score-supervised_scores[0],1)
43 | diff_same_score = times[[n for n,i in enumerate(supervised_scores) if sign*i>sign*InExtreMIS_score][0]] - 2
44 |
45 | plt.annotate(s='', xy=(diff_same_score+2.1,InExtreMIS_score), xytext=(2.1,InExtreMIS_score), arrowprops=dict(arrowstyle='<->', linestyle="--",linewidth=2, color='purple'))
46 | plt.annotate(s='', xy=(2,supervised_scores[0]), xytext=(2,InExtreMIS_score), arrowprops=dict(arrowstyle='<->', linestyle="--",linewidth=2, color='purple'))
47 |
48 | plt.text(diff_same_score/2, InExtreMIS_score+5*shift, f'{diff_same_score} hours', ha='left', va='center',color='purple')
49 | plt.text(1.2, InExtreMIS_score-diff_same_budget/2, f'{abs(diff_same_budget)}{unit} {metric}', ha='left',rotation=90, va='center', color='purple')
50 | plt.legend()
51 | plt.grid()
52 | plt.show()
53 | fig.savefig(f"figs/{metric.replace(' ','')}_comparision.pdf",bbox_inches='tight')
--------------------------------------------------------------------------------
/utilities/focal.py:
--------------------------------------------------------------------------------
1 | ## ADAPTED from https://github.com/Project-MONAI/MONAI/blob/12f267c98eabdcd566dff11a3daf931201f04da4/monai/losses/focal_loss.py ###
2 |
3 |
4 | # Copyright 2020 - 2021 MONAI Consortium
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional, Union
16 |
17 | import torch
18 | import torch.nn.functional as F
19 | from torch.nn.modules.loss import _WeightedLoss
20 |
21 | from monai.utils import LossReduction
22 |
23 | class FocalLoss(_WeightedLoss):
24 | """
25 | Reimplementation of the Focal Loss described in:
26 |
27 | - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
28 | - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy",
29 | Zhu et al., Medical Physics 2018
30 | """
31 |
32 | def __init__(
33 | self,
34 | gamma: float = 2.0,
35 | weight: Optional[torch.Tensor] = None,
36 | reduction: Union[LossReduction, str] = LossReduction.MEAN,
37 | ) -> None:
38 | """
39 | Args:
40 | gamma: value of the exponent gamma in the definition of the Focal loss.
41 | weight: weights to apply to the voxels of each class. If None no weights are applied.
42 |             This corresponds to the weights `\alpha` in the references above.
43 | reduction: {``"none"``, ``"mean"``, ``"sum"``}
44 | Specifies the reduction to apply to the output. Defaults to ``"mean"``.
45 |
46 | - ``"none"``: no reduction will be applied.
47 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
48 | - ``"sum"``: the output will be summed.
49 |
50 | Example:
51 | .. code-block:: python
52 |
53 | import torch
54 | from monai.losses import FocalLoss
55 |
56 | pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
57 | grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64)
58 | fl = FocalLoss()
59 | fl(pred, grnd)
60 |
61 | """
62 | super(FocalLoss, self).__init__(weight=weight, reduction=LossReduction(reduction).value)
63 | self.gamma = gamma
64 |         self.weight: Optional[torch.Tensor] = weight  # keep the user-provided class weights
65 |
66 | def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
67 | """
68 | Args:
69 | logits: the shape should be BCH[WD].
70 | where C (greater than 1) is the number of classes.
71 | Softmax over the logits is integrated in this module for improved numerical stability.
72 | target: the shape should be B1H[WD] or BCH[WD].
73 | If the target's shape is B1H[WD], the target that this loss expects should be a class index
74 | in the range [0, C-1] where C is the number of classes.
75 |
76 | Raises:
77 | ValueError: When ``target`` ndim differs from ``logits``.
78 | ValueError: When ``target`` channel is not 1 and ``target`` shape differs from ``logits``.
79 | ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
80 |
81 | """
82 | i = logits
83 | t = target
84 |
85 | if i.ndimension() != t.ndimension():
86 | raise ValueError(f"logits and target ndim must match, got logits={i.ndimension()} target={t.ndimension()}.")
87 |
88 | if t.shape[1] != 1 and t.shape[1] != i.shape[1]:
89 | raise ValueError(
90 | "target must have one channel or have the same shape as the logits. "
91 | "If it has one channel, it should be a class index in the range [0, C-1] "
92 | f"where C is the number of classes inferred from 'logits': C={i.shape[1]}. "
93 | )
94 | if i.shape[1] == 1:
95 | raise NotImplementedError("Single-channel predictions not supported.")
96 |
97 | # Change the shape of logits and target to
98 | # num_batch x num_class x num_voxels.
99 | if i.dim() > 2:
100 | i = i.view(i.size(0), i.size(1), -1) # N,C,H,W => N,C,H*W
101 | t = t.view(t.size(0), t.size(1), -1) # N,1,H,W => N,1,H*W or N,C,H*W
102 | else: # Compatibility with classification.
103 | i = i.unsqueeze(2) # N,C => N,C,1
104 | t = t.unsqueeze(2) # N,1 => N,1,1 or N,C,1
105 |
106 | # Compute the log proba (more stable numerically than softmax).
107 | logpt = F.log_softmax(i, dim=1) # N,C,H*W
108 | # Keep only log proba values of the ground truth class for each voxel.
109 | if target.shape[1] == 1:
110 | logpt = logpt.gather(1, t.long()) # N,C,H*W => N,1,H*W
111 | logpt = torch.squeeze(logpt, dim=1) # N,1,H*W => N,H*W
112 |
113 | # Get the proba
114 | pt = torch.exp(logpt) # N,H*W or N,C,H*W
115 |
116 | if self.weight is not None:
117 | self.weight = self.weight.to(i)
118 | # Convert the weight to a map in which each voxel
119 | # has the weight associated with the ground-truth label
120 | # associated with this voxel in target.
121 | at = self.weight[None, :, None] # C => 1,C,1
122 | at = at.expand((t.size(0), -1, t.size(2))) # 1,C,1 => N,C,H*W
123 | if target.shape[1] == 1:
124 | at = at.gather(1, t.long()) # selection of the weights => N,1,H*W
125 | at = torch.squeeze(at, dim=1) # N,1,H*W => N,H*W
126 | # Multiply the log proba by their weights.
127 | logpt = logpt * at
128 |
129 | # Compute the loss mini-batch.
130 | weight = torch.pow(-pt + 1.0, self.gamma)
131 | if target.shape[1] == 1:
132 | loss = -weight * logpt # N
133 | else:
134 | loss = -weight * t * logpt # N,C
135 |
136 | if self.reduction == LossReduction.SUM.value:
137 | return loss.sum()
138 | if self.reduction == LossReduction.NONE.value:
139 | return loss
140 | if self.reduction == LossReduction.MEAN.value:
141 | return loss.mean()
142 | raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
--------------------------------------------------------------------------------
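
DC_CE_Focal in utilities/losses.py below instantiates this class with reduction="none" and then averages the per-voxel focal values class by class. A toy sketch of that pattern for a 2-class volume, assuming both classes appear in the target:

    import torch
    from utilities.focal import FocalLoss

    fl = FocalLoss(reduction="none")
    logits = torch.randn(1, 2, 4, 4, 4)            # N, C, H, W, D
    target = torch.randint(0, 2, (1, 1, 4, 4, 4))  # voxel-wise class indices
    score = fl(logits, target).reshape(-1)         # one focal value per voxel
    # Class-balanced average: mean within each ground-truth class, then across classes
    balanced = torch.stack(
        [score[target.reshape(-1) == c].mean() for c in range(2)]
    ).mean()
    print(balanced)
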
/utilities/geodesics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import dijkstra3d
4 |
5 | def normalize(data):
6 | return (data-data.min())/(data.max()-data.min())
7 |
8 | def generate_geodesics(extreme, img_gradient, prob, inside_bb_value=12, with_prob=True, with_euclidean=True):
9 |     """Generate a geodesic pseudo-label joining the pairs of extreme points."""
10 | extreme = extreme.squeeze()
11 | img_gradient = img_gradient.squeeze()
12 | prob = prob.squeeze()
13 |
14 | if (extreme==inside_bb_value).sum()>0: # If the image patch is not only background
15 | # Corners of the bounding boxes
16 | bb_x_min = torch.where(extreme==inside_bb_value)[0].min().item()
17 | bb_x_max = torch.where(extreme==inside_bb_value)[0].max().item() + 1
18 | bb_y_min = torch.where(extreme==inside_bb_value)[1].min().item()
19 | bb_y_max = torch.where(extreme==inside_bb_value)[1].max().item() + 1
20 | bb_z_min = torch.where(extreme==inside_bb_value)[2].min().item()
21 | bb_z_max = torch.where(extreme==inside_bb_value)[2].max().item() + 1
22 |
23 | # Only paths within the non-relaxed bounding box are considered
24 | img_gradient_crop = img_gradient[bb_x_min:bb_x_max,bb_y_min:bb_y_max,bb_z_min:bb_z_max].cpu().numpy()
25 | prob_crop = prob[:,bb_x_min:bb_x_max,bb_y_min:bb_y_max,bb_z_min:bb_z_max].detach()
26 | prob_crop = torch.nn.Softmax(0)(prob_crop)[0,...].cpu().numpy() # Probability of the background
27 |
28 | # Extreme points
29 | ex_x_min = torch.where(extreme==1)
30 | ex_x_max = torch.where(extreme==2)
31 | ex_y_min = torch.where(extreme==3)
32 | ex_y_max = torch.where(extreme==4)
33 | ex_z_min = torch.where(extreme==5)
34 | ex_z_max = torch.where(extreme==6)
35 |
36 | # Identifying the pairs of extreme points to join (Extreme points may miss --> patch based approach)
37 | couples = []
38 | if ex_x_min[0].shape[0]>0 and ex_x_max[0].shape[0]>0: # Extreme points in the x dimension
39 | couples.append([[k[0].item() for k in ex_x_min], [k[0].item() for k in ex_x_max]])
40 |
41 | if ex_y_min[0].shape[0]>0 and ex_y_max[0].shape[0]>0: # Extreme points in the y dimension
42 | couples.append([[k[0].item() for k in ex_y_min], [k[0].item() for k in ex_y_max]])
43 |
44 | if ex_z_min[0].shape[0]>0 and ex_z_max[0].shape[0]>0: # Extreme points in the z dimension
45 | couples.append([[k[0].item() for k in ex_z_min], [k[0].item() for k in ex_z_max]])
46 |
47 | couples_crop = [[[k[0]-bb_x_min, k[1]-bb_y_min,k[2]-bb_z_min] for k in couple] for couple in couples]
48 |
49 |         # Computing the geodesics with the dijkstra3d package
50 | output_crop = inside_bb_value + np.zeros(img_gradient_crop.shape)
51 | for source, target in couples_crop:
52 | weights = img_gradient_crop.copy() # Image gradient term
53 |
54 | if with_prob:
55 | weights+=prob_crop # Deep background probability term
56 |
57 | if with_euclidean: # Normalized distance map to the target
58 | x, y, z = np.ogrid[0:img_gradient_crop.shape[0], 0:img_gradient_crop.shape[1], 0:img_gradient_crop.shape[2]]
59 | distances = np.sqrt((x-target[0])**2+(y-target[1])**2+(z-target[2])**2)
60 | distances = normalize(distances)
61 | weights+=distances
62 |
63 | path = dijkstra3d.dijkstra(weights, source, target, connectivity=26)
64 | for k in path:
65 | x,y,z = k
66 | output_crop[x,y,z] = 1
67 |
68 |
69 | output = torch.zeros(extreme.shape)
70 | output[bb_x_min:bb_x_max,bb_y_min:bb_y_max,bb_z_min:bb_z_max] = torch.from_numpy(output_crop.astype(int))
71 | return output[None,None,...] #Adding batch and channel
72 | else:
73 | # No geodesics
74 | for k in range(1,7):
75 | extreme[extreme==k] = 1
76 | return extreme[None,None,...] #Adding batch and channel
77 |
--------------------------------------------------------------------------------
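
generate_geodesics builds a per-voxel cost volume (image gradient, optionally plus the deep background probability and a normalized Euclidean distance map) and delegates the shortest-path search to dijkstra3d. A self-contained toy sketch of that final step, assuming the dijkstra3d package is installed:

    import numpy as np
    import dijkstra3d

    # Uniform cost volume with an expensive barrier and a single cheap gap in it
    weights = np.ones((10, 10, 10), dtype=np.float32)
    weights[5, :, :] = 100.0
    weights[5, 0, 0] = 1.0
    path = dijkstra3d.dijkstra(weights, (0, 5, 5), (9, 5, 5), connectivity=26)
    # path is an array of (x, y, z) voxel coordinates; it detours through the gap
    print(path.shape)
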
/utilities/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from utilities.focal import FocalLoss
4 |
5 |
6 | class PartialLoss(nn.Module):
7 | def __init__(self, criterion):
8 | super(PartialLoss, self).__init__()
9 |
10 | self.criterion = criterion
11 | self.nb_classes = self.criterion.nb_classes
12 |
13 | def forward(self, outputs, partial_target, phase='training'):
14 | nb_target = outputs.shape[0]
15 | loss_target = 0.0
16 | total = 0
17 |
18 | for i in range(nb_target):
19 | partial_i = partial_target[i,...].reshape(-1)
20 | outputs_i = outputs[i,...].reshape(self.nb_classes, -1).unsqueeze(0)
21 |             outputs_i = outputs_i[:,:,partial_i<self.nb_classes] # Keep only the annotated voxels
22 |             nb_annotated = partial_i[partial_i<self.nb_classes].shape[0]
23 | 
24 |             if nb_annotated>0:
25 |                 outputs_i= outputs_i.reshape(1,self.nb_classes,1,1,nb_annotated) # Reshape to a 5D tensor
26 |                 partial_i = partial_i[partial_i<self.nb_classes].reshape(1,1,1,1,nb_annotated)
27 |                 loss_target += self.criterion(outputs_i, partial_i, phase)
28 |                 total += 1
29 | 
30 |         if total>0:
31 |             return loss_target/total
32 |         else:
33 |             return 0.0
34 |
35 |
36 | class DC(nn.Module):
37 | def __init__(self,nb_classes):
38 | super(DC, self).__init__()
39 |
40 | self.softmax = nn.Softmax(1)
41 | self.nb_classes = nb_classes
42 |
43 | @staticmethod
44 | def onehot(gt,shape):
45 | shp_y = gt.shape
46 | gt = gt.long()
47 | y_onehot = torch.zeros(shape)
48 | y_onehot = y_onehot.cuda()
49 | y_onehot.scatter_(1, gt, 1)
50 | return y_onehot
51 |
52 | def reshape(self,output, target):
53 | batch_size = output.shape[0]
54 |
55 | if not all([i == j for i, j in zip(output.shape, target.shape)]):
56 | target = self.onehot(target, output.shape)
57 |
58 | target = target.permute(0,2,3,4,1)
59 | output = output.permute(0,2,3,4,1)
60 |         # output and target are now channel-last (N, H, W, D, C)
61 | return output, target
62 |
63 |
64 | def dice(self, output, target):
65 | output = self.softmax(output)
66 | if not all([i == j for i, j in zip(output.shape, target.shape)]):
67 | target = self.onehot(target, output.shape)
68 |
69 | sum_axis = list(range(2,len(target.shape)))
70 |
71 |         s = (10e-20) # smoothing term avoiding division by zero
72 | intersect = torch.sum(output * target,sum_axis)
73 | dice = (2 * intersect) / (torch.sum(output,sum_axis) + torch.sum(target,sum_axis) + s)
74 | #dice shape is (batch_size, nb_classes)
75 | return 1.0 - dice.mean()
76 |
77 |
78 | def forward(self, output, target):
79 | result = self.dice(output, target)
80 | return result
81 |
82 |
83 | class DC_CE_Focal(DC):
84 | def __init__(self,nb_classes):
85 | super(DC_CE_Focal, self).__init__(nb_classes)
86 |
87 | self.ce = nn.CrossEntropyLoss(reduction='mean')
88 | self.fl = FocalLoss(reduction="none")
89 |
90 | def focal(self, pred, grnd, phase="training"):
91 | score = self.fl(pred, grnd).reshape(-1)
92 |
93 | if phase=="training": # class-balanced focal loss
94 | output = 0.0
95 | nb_classes = 0
96 | for cl in range(self.nb_classes):
97 | if (grnd==cl).sum().item()>0:
98 | output+=score[grnd.reshape(-1)==cl].mean()
99 | nb_classes+=1
100 |
101 | if nb_classes>0:
102 | return output/nb_classes
103 | else:
104 | return 0.0
105 |
106 |         else: # standard (non-balanced) focal loss
107 | return score.mean()
108 |
109 | def forward(self, output, target, phase="training"):
110 | # Dice term
111 | dc_loss = self.dice(output, target)
112 |
113 | # Focal term
114 | focal_loss = self.focal(output, target, phase)
115 |
116 | # Cross entropy
117 | output = output.permute(0,2,3,4,1).contiguous().view(-1,self.nb_classes)
118 | target = target.view(-1,).long().cuda()
119 | ce_loss = self.ce(output, target)
120 |
121 | result = ce_loss + dc_loss + focal_loss
122 | return result
123 |
124 |
125 |
--------------------------------------------------------------------------------
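
PartialLoss relies on the labelling convention used throughout the repository: annotated voxels carry a class index in [0, nb_classes-1], while any larger value (e.g. the inside_bb_value of 12 written by utilities/geodesics.py) marks a voxel as unannotated. A toy sketch of the masking step, assuming 2 classes:

    import torch

    nb_classes = 2
    partial = torch.tensor([0, 1, 12, 12, 1])  # 12 = unannotated
    outputs = torch.randn(1, nb_classes, 5)    # per-voxel logits
    mask = partial < nb_classes
    annotated_logits = outputs[:, :, mask]     # only the 3 annotated voxels remain
    annotated_target = partial[mask]
    print(annotated_logits.shape, annotated_target)
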
/utilities/scores.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from tqdm import tqdm
4 | from medpy.metric.binary import dc, precision, hd95
5 | import numpy as np
6 | import nibabel as nib
7 | import pandas as pd
8 | from natsort import natsorted
9 |
10 | opt = argparse.ArgumentParser(
11 | description="Computing scores")
12 |
13 | opt.add_argument("--model_dir",
14 | type=str,
15 | help="Path to the model directory")
16 | opt.add_argument("--path_data",
17 | type=str,
18 | default="../data/VS_MICCAI21/T2/",
19 | help="Path to the labels")
20 |
21 | opt = opt.parse_args()
22 |
23 | name_folder = os.path.basename(os.path.normpath(opt.model_dir))
24 |
25 | df_split = pd.read_csv('splits/split_inextremis_budget1.csv', header=None)
26 | list_patient = natsorted(df_split[df_split[1].isin(['inference'])][0].tolist())
27 |
28 | list_dice = []
29 | list_hd = []
30 | list_precision = []
31 |
32 | df_scores= {'name':[],'dice':[],'hd95':[],'precision':[]}
33 | for patient in tqdm(list_patient):
34 | path_gt = os.path.join(opt.path_data, patient+"_Label.nii.gz")
35 | path_pred = os.path.join(opt.model_dir,'output_pred',f"{patient}_T2",f"{patient}_T2_seg.nii.gz")
36 | gt = nib.funcs.as_closest_canonical(nib.load(path_gt)).get_fdata().squeeze()
37 | pred = nib.funcs.as_closest_canonical(nib.load(path_pred)).get_fdata().squeeze()
38 | affine = nib.funcs.as_closest_canonical(nib.load(path_gt)).affine
39 |
40 | voxel_spacing = [abs(affine[k,k]) for k in range(3)]
41 | dice_score = dc(pred, gt)
42 |     hd_score = 0.0 # default value when the prediction is empty
43 |     if np.sum(pred)>0:
44 |         hd_score = hd95(pred, gt, voxelspacing=voxel_spacing)
45 | precision_score = precision(pred, gt)
46 |
47 | list_dice.append(100*dice_score)
48 | list_hd.append(hd_score)
49 | list_precision.append(100*precision_score)
50 |
51 | df_scores['name'].append(patient)
52 | df_scores['dice'].append(dice_score)
53 | df_scores['hd95'].append(hd_score)
54 | df_scores['precision'].append(precision_score)
55 |
56 | df_scores = pd.DataFrame(df_scores)
57 | df_scores.to_csv(os.path.join(opt.model_dir, "results_full.csv"))
58 |
59 |
60 | mean_dice = np.round(np.mean(list_dice),1)
61 | std_dice = np.round(np.std(list_dice),1)
62 | mean_hd = np.round(np.mean(list_hd),1)
63 | std_hd = np.round(np.std(list_hd),1)
64 | mean_precision = np.round(np.mean(list_precision),1)
65 | std_precision = np.round(np.std(list_precision),1)
66 |
67 | print(name_folder)
68 | print(f"{mean_dice} ({std_dice}) & {mean_hd} ({std_hd}) & {mean_precision} ({std_precision}) \\\\")
69 |
70 |
--------------------------------------------------------------------------------
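
dc, precision and hd95 come from medpy.metric.binary: the first two operate on binary volumes directly, while hd95 also takes the physical voxel spacing so the distance is reported in millimetres. A minimal sketch on synthetic volumes:

    import numpy as np
    from medpy.metric.binary import dc, precision, hd95

    pred = np.zeros((32, 32, 32)); pred[10:20, 10:20, 10:20] = 1
    gt = np.zeros((32, 32, 32)); gt[12:22, 10:20, 10:20] = 1
    print(dc(pred, gt))                                  # overlap score in [0, 1]
    print(precision(pred, gt))
    print(hd95(pred, gt, voxelspacing=(1.0, 1.0, 1.5)))  # in mm
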
/utilities/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 |
5 | def create_logger(folder):
6 | """Create a logger to save logs."""
7 | compt = 0
8 | while os.path.exists(os.path.join(folder,f"logs_{compt}.txt")):
9 | compt+=1
10 | logname = os.path.join(folder,f"logs_{compt}.txt")
11 |
12 | logger = logging.getLogger()
13 | fileHandler = logging.FileHandler(logname, mode="w")
14 | consoleHandler = logging.StreamHandler()
15 | logger.addHandler(fileHandler)
16 | logger.addHandler(consoleHandler)
17 | formatter = logging.Formatter("%(message)s")
18 | fileHandler.setFormatter(formatter)
19 | consoleHandler.setFormatter(formatter)
20 | logger.setLevel(logging.INFO)
21 | return logger
22 |
23 |
24 | def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
25 | """Learning rate policy used in nnUNet."""
26 | return initial_lr * (1 - epoch / max_epochs)**exponent
27 |
28 |
29 | def infinite_iterable(i):
30 | while True:
31 | yield from i
32 |
33 |
34 |
--------------------------------------------------------------------------------
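
With the values used in train.py (initial_lr=1e-2, max_epochs=300, exponent 0.9), poly_lr decays smoothly to zero; at epoch 150, for instance, it yields 1e-2 * (1 - 150/300)^0.9 ≈ 5.4e-3. A quick check:

    from utilities.utils import poly_lr

    for epoch in (0, 150, 299):
        print(epoch, poly_lr(epoch, 300, 1e-2))
    # 0 -> 0.01, 150 -> ~0.0054, 299 -> ~6e-05
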
/utilities/write_geodesics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 | import os
4 | import argparse
5 | import nibabel as nib
6 | import pandas as pd
7 | from geodesics import generate_geodesics
8 |
9 | PHASES = ["training", "validation"]
10 |
11 | def main():
12 | opt = parsing_data()
13 |
14 | # FOLDERS
15 | fold_dir = opt.output_folder
16 | if not os.path.exists(fold_dir):
17 | os.makedirs(fold_dir)
18 |
19 | # SPLIT
20 |     assert os.path.isfile(opt.dataset_split), "[ERROR] Invalid split file"
21 |     df_split = pd.read_csv(opt.dataset_split, header=None)
22 | list_file = dict()
23 | for split in PHASES:
24 | list_file[split] = df_split[df_split[1].isin([split])][0].tolist()
25 |
26 | mod_ext = "_T2.nii.gz"
27 | grad_ext = f"_{opt.img_gradient_postfix}.nii.gz"
28 | extreme_ext = f"_{opt.label_postfix}.nii.gz"
29 | paths_dict = {split:[] for split in PHASES}
30 |
31 | print(f"Using the Euclidean distance: {opt.with_euclidean}")
32 | for split in PHASES:
33 | score = []
34 | for subject in tqdm(list_file[split]):
35 | subject_data = dict()
36 |
37 | img_path = os.path.join(opt.path_data,subject+mod_ext)
38 | img_grad_path = os.path.join(opt.path_extremes,subject+grad_ext)
39 | lab_path = os.path.join(opt.path_extremes,subject+extreme_ext)
40 | output_path = os.path.join(opt.output_folder,subject+'_PartLabel.nii.gz')
41 |
42 | if os.path.exists(img_path) and os.path.exists(lab_path) and os.path.exists(img_grad_path):
43 | extreme = nib.load(lab_path)
44 | affine = extreme.affine
45 | extreme_data = torch.from_numpy(extreme.get_fdata())
46 |
47 | grad_data = torch.from_numpy(nib.load(img_grad_path).get_fdata())
48 |
49 | geodesics = generate_geodesics(
50 | extreme=extreme_data,
51 | img_gradient=grad_data,
52 | prob=None,
53 | with_prob=False,
54 | with_euclidean=opt.with_euclidean).numpy().squeeze()
55 |
56 | nib.Nifti1Image(geodesics,affine).to_filename(output_path)
57 |
58 |
59 | def parsing_data():
60 | parser = argparse.ArgumentParser(
61 | description="Script to generate (non-deep) geodesics using extreme points")
62 |
63 | parser.add_argument("--output_folder",
64 | type=str,
65 | default="geodesics_folder",
66 |                         help="Path to the output folder for the geodesic labels")
67 |
68 | parser.add_argument("--dataset_split",
69 | type=str,
70 | default="splits/split_inextremis_budget1.csv",
71 | help="Path to split file")
72 |
73 | parser.add_argument("--path_data",
74 | type=str,
75 | default="../data/VS_MICCAI21/T2/",
76 | help="Path to the T2 scans")
77 |
78 | parser.add_argument("--path_extremes",
79 | type=str,
80 | default="../data/VS_MICCAI21/extremes_manual/",
81 | help="Path to the extreme points")
82 |
83 | parser.add_argument("--label_postfix",
84 | type=str,
85 | default="Extremes_man",
86 |                         help="Postfix of the extreme-point label files")
87 |
88 | parser.add_argument("--img_gradient_postfix",
89 | type=str,
90 | default="Sobel_man",
91 | help="Postfix of the gradient images")
92 |
93 | parser.add_argument("--with_euclidean",
94 | action="store_true",
95 | help="Add Euclidean distance")
96 |
97 | opt = parser.parse_args()
98 |
99 | return opt
100 |
101 | if __name__ == "__main__":
102 | main()
103 |
--------------------------------------------------------------------------------
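
An illustrative run of the geodesic pre-computation; since the script imports geodesics directly and its default data paths start with ../, it is assumed to be launched from inside utilities/, with the split path adjusted accordingly:

    cd utilities
    python write_geodesics.py \
        --output_folder ../geodesics_folder \
        --dataset_split ../splits/split_inextremis_budget1.csv \
        --with_euclidean

The resulting *_PartLabel.nii.gz files can then be used by train.py in geodesics mode via --path_labels and --label_postfix PartLabel.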