├── radiogenomics
│   ├── data
│   │   ├── radgen
│   │   │   └── processed
│   │   │       ├── labels.csv
│   │   │       └── lungs_roi
│   │   │           ├── R01-003
│   │   │           │   ├── R01-003_ct.pt
│   │   │           │   ├── R01-003_seg.pt
│   │   │           │   ├── R01-003_bbx.pt
│   │   │           │   ├── R01-003_lda_feat.npy
│   │   │           │   ├── R01-003_raw_feat.npy
│   │   │           │   └── R01-003_seg.npy
│   │   │           └── R01-004
│   │   │               ├── R01-004_ct.pt
│   │   │               ├── R01-004_seg.pt
│   │   │               ├── R01-004_bbx.pt
│   │   │               ├── R01-004_lda_feat.npy
│   │   │               ├── R01-004_raw_feat.npy
│   │   │               └── R01-004_seg.npy
│   │   ├── msd
│   │   │   └── processed
│   │   │       └── lungs_roi
│   │   │           └── lung_001
│   │   │               ├── lung_001_bbx.pt
│   │   │               ├── lung_001_ct.pt
│   │   │               └── lung_001_seg.pt
│   │   └── rad
│   │       └── processed
│   │           └── lungs_roi
│   │               └── LUNG1-001
│   │                   ├── LUNG1-001_ct.pt
│   │                   ├── LUNG1-001_bbx.pt
│   │                   └── LUNG1-001_seg.pt
│   ├── classifiers
│   │   ├── dt_model.sav
│   │   ├── qda_model.sav
│   │   ├── rf_model.sav
│   │   ├── svc_model.sav
│   │   ├── LDA_dim_red.sav
│   │   └── PCA_dim_red.sav
│   ├── weights
│   │   └── radgen_finetune.pt
│   ├── requirements.txt
│   ├── experiments
│   │   ├── inf_class.py
│   │   ├── utils
│   │   │   ├── visualisation.py
│   │   │   └── data_init_inf.py
│   │   └── inf_seg.py
│   ├── main.py
│   ├── models
│   │   ├── utils
│   │   │   └── convlstm.py
│   │   ├── ra_seg.py
│   │   └── unet.py
│   └── utils
│       └── preprocess.py
├── .gitattributes
├── figures
│   ├── CT_processing.png
│   ├── arch_outline.PNG
│   ├── radgen_pipeline.png
│   └── final_model_isbi.png
├── LICENSE
├── .gitignore
└── README.md

--------------------------------------------------------------------------------
/radiogenomics/data/radgen/processed/labels.csv:
--------------------------------------------------------------------------------
1 | case,label
2 | R01-003,1
3 | R01-004,0
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.pt filter=lfs diff=lfs merge=lfs -text
2 | *.npy filter=lfs diff=lfs merge=lfs -text
--------------------------------------------------------------------------------
/figures/CT_processing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/figures/CT_processing.png
--------------------------------------------------------------------------------
/figures/arch_outline.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/figures/arch_outline.PNG
--------------------------------------------------------------------------------
/figures/radgen_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/figures/radgen_pipeline.png
--------------------------------------------------------------------------------
/figures/final_model_isbi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/figures/final_model_isbi.png
--------------------------------------------------------------------------------
/radiogenomics/classifiers/dt_model.sav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/radiogenomics/classifiers/dt_model.sav
--------------------------------------------------------------------------------
/radiogenomics/classifiers/qda_model.sav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/radiogenomics/classifiers/qda_model.sav
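A minimal sketch of loading one processed case and its label from the layout above. The (1, 256, 256, 256) shape follows the preprocessing described in the README; depending on the preprocessing step, the saved objects may deserialize as tensors or NumPy arrays:

```
import pandas as pd
import torch

root = "radiogenomics/data/radgen/processed"
case = "R01-003"

labels = pd.read_csv(f"{root}/labels.csv")                  # columns: case, label (EGFR mutation status)
ct = torch.load(f"{root}/lungs_roi/{case}/{case}_ct.pt")    # normalized CT volume, (1, 256, 256, 256)
seg = torch.load(f"{root}/lungs_roi/{case}/{case}_seg.pt")  # tumor segmentation mask, same shape
bbx = torch.load(f"{root}/lungs_roi/{case}/{case}_bbx.pt")  # 3D tumor bounding-box mask

print(labels.head(), ct.shape)
```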
-------------------------------------------------------------------------------- /radiogenomics/classifiers/rf_model.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/radiogenomics/classifiers/rf_model.sav -------------------------------------------------------------------------------- /radiogenomics/classifiers/svc_model.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/radiogenomics/classifiers/svc_model.sav -------------------------------------------------------------------------------- /radiogenomics/classifiers/LDA_dim_red.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/radiogenomics/classifiers/LDA_dim_red.sav -------------------------------------------------------------------------------- /radiogenomics/classifiers/PCA_dim_red.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Gollini/NSCLC_Radiogenomics/HEAD/radiogenomics/classifiers/PCA_dim_red.sav -------------------------------------------------------------------------------- /radiogenomics/weights/radgen_finetune.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bdc6b8e42667c252b90e8fbd53806a853a167b442caa693f2e9c7621b363ea63 3 | size 282424762 4 | -------------------------------------------------------------------------------- /radiogenomics/data/msd/processed/lungs_roi/lung_001/lung_001_bbx.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:377c554051850f36dc7db9f4a7fe4ee4711b93a3144a6915faa8e8cadff67184 3 | size 134226287 4 | -------------------------------------------------------------------------------- /radiogenomics/data/msd/processed/lungs_roi/lung_001/lung_001_ct.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1f91df76b1f7db5c6c6518dd4d6715ab02979b391c3e5ddab8bc0408a81c1d93 3 | size 80830767 4 | -------------------------------------------------------------------------------- /radiogenomics/data/msd/processed/lungs_roi/lung_001/lung_001_seg.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:48e98fe812f4072a587233b6371b70490bf3527269e73c522f9aa13138eed3d4 3 | size 67111471 4 | -------------------------------------------------------------------------------- /radiogenomics/data/rad/processed/lungs_roi/LUNG1-001/LUNG1-001_ct.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e76168d56f51b6976af76f28e4b81f75f6b3f6d58b14f66e3300636d41b18fb7 3 | size 82316655 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-003/R01-003_ct.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7114daea17e944892c173d21e1df9ed2e12591bc62310dd9d1e510f02ae21157 3 | size 81786671 4 | 
-------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-003/R01-003_seg.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:062abcd8cb34ff2a60126dc4eeed9b36729dee0fd62e76c08dbe287a8c3e3743 3 | size 67110703 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-004/R01-004_ct.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:aa7fdf643e062430e653b5575fe819655f473ae53d68d5cfc9c5432bca693c9d 3 | size 82689647 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-004/R01-004_seg.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:514849697f60de5ce81e4a25313e437d8b0612db708eadcb016f4bb4c9a5bd46 3 | size 67110639 4 | -------------------------------------------------------------------------------- /radiogenomics/data/rad/processed/lungs_roi/LUNG1-001/LUNG1-001_bbx.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:794dee5755724dc604d6c7e2ef55e81c78ad4bd38af82e321649089564ad0474 3 | size 134765935 4 | -------------------------------------------------------------------------------- /radiogenomics/data/rad/processed/lungs_roi/LUNG1-001/LUNG1-001_seg.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c7e5f1f01757c025513547665a0e23984bba343967555df902b3da527b408463 3 | size 67272815 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-003/R01-003_bbx.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:666c8cd149d27f6a8bf03a8d6aff66acc65707421732ad5cacec849abba5c4ee 3 | size 134221487 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-003/R01-003_lda_feat.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4d1ecb2f79a415ff9e0045a01d5dd8972dadb3b1c1feff463026fac53bfa0852 3 | size 136 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-003/R01-003_raw_feat.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b82eeae70f6b8edacfff8d39b660338bca80143df121e759a9d71a14a98ea2a9 3 | size 32896 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-003/R01-003_seg.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:21d48b8523121bf4522900e6c6ec217171366dae4ec802f27f0bacc294b9f18a 3 | size 134217856 4 | -------------------------------------------------------------------------------- 
/radiogenomics/data/radgen/processed/lungs_roi/R01-004/R01-004_bbx.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2d7ceb8d111ca0d35620209dcc44ad67f9ebcb0f779c0fc067d326df384fd6aa 3 | size 134223791 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-004/R01-004_lda_feat.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bdb0f39b2c1931c9fb5559dbef96c15170965c6e83878888cdabaccbf76da601 3 | size 136 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-004/R01-004_raw_feat.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d52f7ebd035ca59fa80cd62b7482605b48804d180e057eb7cec014b8b42a33ad 3 | size 32896 4 | -------------------------------------------------------------------------------- /radiogenomics/data/radgen/processed/lungs_roi/R01-004/R01-004_seg.npy: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b9a6638c955b29900d8da2bc79e9c4273f3415788f3915e91469d14ac2e10259 3 | size 134217856 4 | -------------------------------------------------------------------------------- /radiogenomics/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0+cu102 2 | torchvision==0.12.0+cu102 3 | # package location 4 | --find-links https://download.pytorch.org/whl/torch_stable.html 5 | monai==1.0.1 6 | numpy==1.23.4 7 | pandas==1.5.1 8 | SimpleITK==2.2.0 9 | pydicom==2.3.0 10 | nibabel==4.0.2 11 | torchio==0.18.85 12 | git+https://github.com/JoHof/lungmask 13 | tensorboard==2.11.0 14 | setuptools==65.5.0 15 | scikit-learn==1.1.3 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Gollini 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/radiogenomics/experiments/inf_class.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Ivo Gollini Navarrete
3 | Date: 14/nov/2022
4 | Institution: MBZUAI
5 | """
6 | 
7 | # Imports
8 | import os
9 | import numpy as np
10 | import pickle
11 | 
12 | MODELS = [
13 |     "qda",
14 |     "dt",
15 |     "rf",
16 |     "svc"
17 | ]
18 | 
19 | class Inf_class:
20 |     """Class to perform inference on a trained classifier."""
21 |     def __init__(self, data_path, classifiers_path, model_class):
22 |         self.data_path = data_path
23 |         self.classifiers_path = classifiers_path
24 |         self.model_class = model_class
25 | 
26 |         # Initialize model
27 |         model_path = os.path.join(self.classifiers_path, self.model_class + '_model.sav')
28 |         try:
29 |             self.loaded_model = pickle.load(open(model_path, 'rb'))
30 |             print("Model {} loaded".format(model_class))
31 |             # print(self.loaded_model.get_params())
32 |         except FileNotFoundError:
33 |             print("Model {} not found".format(model_class))
34 |             print(MODELS)
35 |             exit()
36 | 
37 |         # Initialize data (sorted so predictions follow a deterministic case order)
38 |         X = []
39 | 
40 |         for case in sorted(os.listdir(data_path)):
41 |             feat_path = os.path.join(data_path, case, case + "_lda_feat.npy")
42 |             X.append(np.load(feat_path))
43 | 
44 |         self.X = np.array(X)
45 |         print("Data loaded with {} cases".format(len(self.X)))
46 | 
47 |     def run(self):
48 |         # Predict
49 |         result = self.loaded_model.predict(self.X)
50 |         print(result)
51 | 
--------------------------------------------------------------------------------
/radiogenomics/experiments/utils/visualisation.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Ibrahim Almakky
3 | Date: 01/04/2021
4 | Institution: MBZUAI
5 | """
6 | 
7 | # Copied from "https://gist.github.com/zachguo/10296432"
8 | def print_cm(cm, labels, hide_zeroes=False, hide_diagonal=False, hide_threshold=None):
9 |     """pretty print for confusion matrices"""
10 |     columnwidth = max([len(str(x)) for x in labels] + [5])  # 5 is value length
11 |     empty_cell = " " * columnwidth
12 | 
13 |     # Begin CHANGES
14 |     fst_empty_cell = (columnwidth - 3) // 2 * " " + "t/p" + (columnwidth - 3) // 2 * " "
15 | 
16 |     if len(fst_empty_cell) < len(empty_cell):
17 |         fst_empty_cell = " " * (len(empty_cell) - len(fst_empty_cell)) + fst_empty_cell
18 |     # Print header
19 |     cm_str = " " + fst_empty_cell + " "
20 |     # End CHANGES
21 | 
22 |     for label in labels:
23 |         cm_str += "%{0}s".format(columnwidth) % label + " "
24 | 
25 |     # print()
26 |     cm_str += "\n"
27 |     # Print rows
28 |     for i, label1 in enumerate(labels):
29 |         # print(" %{0}s".format(columnwidth) % label1, end=" ")
30 |         cm_str += " %{0}s".format(columnwidth) % label1 + " "
31 |         for j in range(len(labels)):
32 |             cell = "%{0}.1f".format(columnwidth) % cm[i, j]
33 |             if hide_zeroes:
34 |                 cell = cell if float(cm[i, j]) != 0 else empty_cell
35 |             if hide_diagonal:
36 |                 cell = cell if i != j else empty_cell
37 |             if hide_threshold:
38 |                 cell = cell if cm[i, j] > hide_threshold else empty_cell
39 |             # print(cell, end=" ")
40 |             cm_str += cell + " "
41 |         # print()
42 |         cm_str += "\n"
43 |     return cm_str
44 | 
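A small usage sketch of `print_cm`; the matrix values and class labels below are made up for illustration, and the import path assumes the `radiogenomics/` directory is the working directory:

```
import numpy as np
from experiments.utils.visualisation import print_cm

# Hypothetical 2x2 confusion matrix for EGFR status (rows: true, columns: predicted).
cm = np.array([[12.0, 3.0],
               [2.0, 14.0]])
print(print_cm(cm, labels=["wildtype", "mutated"]))
```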
--------------------------------------------------------------------------------
/radiogenomics/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Ivo Gollini Navarrete
3 | Date: 21/august/2022
4 | Institution: MBZUAI
5 | """
6 | 
7 | # IMPORTS
8 | import sys
9 | import argparse
10 | import os
11 | from utils import preprocess
12 | from experiments.inf_seg import Inference
13 | from experiments.inf_class import Inf_class
14 | 
15 | 
16 | def path(string):
17 |     if os.path.exists(string):
18 |         return string
19 |     else:
20 |         sys.exit(f'File not found: {string}')
21 | 
22 | def main():
23 |     parser = argparse.ArgumentParser()
24 |     parser.add_argument('process', metavar='process', type=str, help='Process to be performed (preprocess, inf_seg, inf_class)')
25 |     parser.add_argument('input', metavar='input', type=path, help='Path to the input dataset')
26 |     parser.add_argument('output', metavar='output', type=str, help='Path to the preprocessed data output / inference weights / inference classifiers')
27 |     parser.add_argument('dataset', metavar='dataset', type=str, help='Select the dataset / model to be used')
28 |     parser.add_argument(
29 |         "-D",
30 |         "--debug",
31 |         default=False,
32 |         action="store_true",
33 |         help="""Flag to set the experiment to debug mode.""")
34 | 
35 |     args = parser.parse_args()
36 | 
37 |     if args.process == "preprocess":
38 |         data_preprocess = preprocess.Preprocess(args.input, args.output, args.dataset)
39 |         if args.dataset == "radgen": data_preprocess.radiogenomics()
40 |         elif args.dataset == "rad": data_preprocess.radiomics()
41 |         elif args.dataset == "msd": data_preprocess.msd()
42 | 
43 |     elif args.process == "inf_seg":
44 |         inference = Inference(args.input, args.output, args.dataset)
45 |         inference.run()
46 | 
47 |     elif args.process == "inf_class":
48 |         inf_class = Inf_class(args.input, args.output, args.dataset)
49 |         inf_class.run()
50 | 
51 | if __name__ == "__main__":
52 |     main()
--------------------------------------------------------------------------------
/radiogenomics/experiments/utils/data_init_inf.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Ivo Gollini Navarrete
3 | Date: 14/oct/2022
4 | Institution: MBZUAI
5 | """
6 | 
7 | # Imports
8 | import os
9 | import json
10 | import numpy as np
11 | 
12 | import torch
13 | from torch.utils.data.dataset import Dataset
14 | 
15 | class DatasetInit(Dataset):
16 |     """
17 |     Dataset class for the msd dataset.
18 | """ 19 | NAME = "msd" 20 | IMG_EXT = ".pt" 21 | 22 | # The modes supported by this dataset loader class 23 | SEGMENTATION_MODES = ["vanilla", "a2t"] 24 | 25 | SUBSETS = ["test"] 26 | 27 | def __init__( 28 | self, 29 | path, 30 | subset="test", 31 | mode="vanilla", 32 | channels=1, 33 | ) -> None: 34 | 35 | if subset not in self.SUBSETS: 36 | raise ValueError("""Specified subset for dataset is not recognized.""") 37 | self.subset = subset 38 | 39 | # Check that the modes are supported 40 | if ( 41 | mode not in self.SEGMENTATION_MODES 42 | ): 43 | raise ValueError("Unrecognised modes were selected for the Radiogenomics dataset.") 44 | 45 | self.mode = mode 46 | self.channels = channels 47 | self.path = path 48 | self.data = os.listdir(path) 49 | print("Dataset loaded with {} samples".format(len(self.data))) 50 | 51 | def __len__(self): 52 | return len(self.data) 53 | 54 | def __getitem__(self, idx): 55 | patient = self.data[idx] 56 | img_path = os.path.join(self.path, patient, patient+'_ct.pt') 57 | # target_path = os.path.join(self.path, patient, patient+'_seg.pt') 58 | 59 | img = torch.load(img_path) 60 | # target = torch.load(target_path).astype(np.float32) 61 | 62 | if self.mode == "vanilla": 63 | # return img , target 64 | return img, patient 65 | 66 | elif self.mode == "a2t": 67 | organ_path = os.path.join(self.path, patient, patient+'_bbx.pt') 68 | organ = torch.load(organ_path).astype(np.float32) 69 | # return img, organ , target 70 | return img, organ, patient 71 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | playground_utils 8 | radiogenomics/data/msd/Task06_Lung 9 | radiogenomics/data/rad/NSCLC-Radiomics 10 | radiogenomics/data/radgen/NSCLC-Radiogenomics 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 | 
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 | 
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 | 
107 | # SageMath parsed files
108 | *.sage.py
109 | 
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 | 
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 | 
123 | # Rope project settings
124 | .ropeproject
125 | 
126 | # mkdocs documentation
127 | /site
128 | 
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 | 
134 | # Pyre type checker
135 | .pyre/
136 | 
--------------------------------------------------------------------------------
/radiogenomics/experiments/inf_seg.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Ivo Gollini Navarrete
3 | Date: 14/nov/2022
4 | Institution: MBZUAI
5 | """
6 | 
7 | # Imports
8 | import os
9 | import numpy as np
10 | 
11 | import torch
12 | import torch.nn as nn
13 | from torch.utils.data import DataLoader
14 | from experiments.utils import data_init_inf
15 | 
16 | from models.unet import UNet
17 | from models.ra_seg import RA_Seg
18 | 
19 | from monai.losses import DiceLoss, DiceCELoss, FocalLoss
20 | from monai.metrics import DiceMetric
21 | 
22 | 
23 | MODELS = {
24 |     "UNet": UNet,
25 |     "RA_Seg": RA_Seg
26 | }
27 | 
28 | CRITERIONS = {
29 |     "CE": nn.CrossEntropyLoss,
30 |     "BCE": nn.BCELoss,
31 |     "BCEL": nn.BCEWithLogitsLoss,
32 |     "FOCAL": FocalLoss,
33 |     "DICE": DiceLoss,
34 |     "DICE_CE": DiceCELoss,
35 | }
36 | 
37 | METRIC = {
38 |     "DICE": DiceMetric,
39 | }
40 | 
41 | class Inference:
42 |     """Class to perform inference on a trained segmentation model."""
43 |     def __init__(self, data_path, weights_path, model_class):
44 |         self.data_path = data_path
45 |         self.weights_path = weights_path
46 |         self.model_class = model_class
47 | 
48 |         # Set device (GPU if available, otherwise CPU)
49 |         self.device = torch.device(
50 |             "cuda:" + str(torch.cuda.current_device())
51 |             if torch.cuda.is_available()
52 |             else "cpu"
53 |         )
54 |         print("Running on:", self.device)
55 | 
56 |         # Initialize model
57 |         try:
58 |             self.model = MODELS[model_class](
59 |                 spatial_dims=3,
60 |                 in_channels=1,
61 |                 out_channels=1,
62 |                 channels=[64, 128, 256, 512, 512],
63 |                 strides=[2, 2, 2, 2],
64 |             ).to(self.device)
65 |             print("Model {} loaded".format(model_class))
66 |         except KeyError:
67 |             print("Model {} not found".format(model_class))
68 |             print(sorted(MODELS))
69 |             exit()
70 | 
71 |         # Load weights
72 |         self.model.load_state_dict(torch.load(self.weights_path))
73 |         print("Weights loaded from {}".format(self.weights_path))
74 | 
75 |         # Initialize dataset
76 |         self.testset = data_init_inf.DatasetInit(
77 |             path=self.data_path,
78 |             subset="test",
79 |             channels=[64, 128, 256, 512, 512],
80 |             mode="a2t",
81 |         )
82 | 
83 |         self.test_loader = DataLoader(
84 |             self.testset,
85 |             batch_size=1,
86 |             num_workers=4,
87 |             shuffle=False
88 |         )
89 | 
90 |         self.criterion = CRITERIONS["DICE_CE"](to_onehot_y=True)
91 |         self.metric = METRIC["DICE"](include_background=False, reduction="mean")
92 | 
93 |     def post_process(self, img):
94 |         img[img >= 0.5] = 1
95 |         img[img < 0.5] = 0
96 |         img = img.to(torch.int64)
97 |         return img
98 | 
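    # Note: with channels=[64, 128, 256, 512, 512] and 256^3 inputs, the `hl_feat`
    # passed to feat_preprocess() below is the decoder-2 output (64 channels at
    # 128x128x128), so the mean over H and W followed by flatten yields a
    # 64 * 128 = 8192-dimensional raw feature vector, which matches the ~32 KB
    # *_raw_feat.npy files tracked in this repository.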
99 |     def feat_preprocess(self, features, patient):
100 |         vector = features.mean(dim=(-2,-1)).flatten(start_dim=1)
101 |         vector = vector.squeeze().cpu().detach().numpy()
102 |         print(vector.shape)
103 |         np.save(os.path.join(self.data_path, patient, "{}_raw_feat.npy".format(patient)), vector)
104 |         print('Raw high-level features extracted.')
105 |         return
106 | 
107 |     def save_seg(self, seg, patient):
108 |         seg = seg.squeeze().cpu().detach().numpy()
109 |         np.save(os.path.join(self.data_path, patient, "{}_seg.npy".format(patient)), seg)
110 |         return
111 | 
112 |     def run(self):
113 |         self.model.eval()
114 | 
115 |         print("Starting inference...")
116 |         with torch.no_grad():
117 |             for batch_num, data in enumerate(self.test_loader):
118 |                 inp = data[0].to(self.device)
119 | 
120 |                 if self.model_class == "RA_Seg":
121 |                     organ = data[1].to(self.device)
122 |                     test_output, hl_feat = self.model(inp, organ)  # Tumor segmentation and high-level features from decoder 2.
123 |                     test_output = self.post_process(test_output)  # Binarize
124 |                     self.save_seg(test_output, data[2][0])  # Save segmentation
125 |                     self.feat_preprocess(hl_feat, data[2][0])  # Extract features
126 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # A Radiogenomics Pipeline for Lung Nodules Segmentation and Prediction of EGFR Mutation Status from CT Scans.
2 | 
3 | This is the official repository for the **Preprint** "A Radiogenomics Pipeline for Lung Nodules Segmentation and Prediction of EGFR Mutation Status from CT Scans".
4 | 
5 | **Link:** [ArXiv](https://arxiv.org/abs/2211.06620)
6 | 
7 | **Abstract**
8 | 
9 | Lung cancer is a leading cause of death worldwide. Early-stage detection of lung cancer is essential for a more favorable prognosis. Radiogenomics is an emerging discipline that combines medical imaging and genomics features for modeling patient outcomes non-invasively. This study presents a radiogenomics pipeline that has: 1) a novel mixed architecture (RA-Seg) to segment lung cancer through attention and recurrent blocks; and 2) deep feature classifiers to distinguish Epidermal Growth Factor Receptor (EGFR) mutation status. We evaluate the proposed algorithm on multiple public datasets to assess its generalizability and robustness. We demonstrate how the proposed segmentation and classification methods outperform existing baseline and SOTA approaches (73.54 Dice and 93 F1 scores).
10 | 
11 | ## Dependencies
12 | In addition to the Python packages, [lungmask](https://github.com/JoHof/lungmask) is required.
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 | 
17 | ## Preprocessing
18 | The *preprocess* argument prepares the data as distributed by their sources: the scans are normalized and stored as tensors for training or inference.
19 | Preprocessing is available for the public datasets [MSD](http://medicaldecathlon.com/) Task06_Lung, [NSCLC Radiomics](https://wiki.cancerimagingarchive.net/display/Public/NSCLC-Radiomics) (RAD), and [NSCLC Radiogenomics](https://wiki.cancerimagingarchive.net/display/Public/NSCLC+Radiogenomics) (RADGEN).
20 | 
21 | For other datasets, the user must create the corresponding function, making sure the first axis contains the slices, the second axis goes from chest to back, and the third axis from right to left.
22 | 
23 | **Preprocessing** (a code sketch of these steps follows the list):
24 | * Voxel intensities clipped to the range [-200, 250].
25 | * Min-max normalization to the range [0, 1].
26 | * Resampling to anisotropic resolution 1x1x1.5 mm^3.
27 | * Lungs Volume of Interest (VOI) crop utilizing [U-Net(R231)](https://github.com/JoHof/lungmask) pretrained weights.
28 | * Resize to 256×256×256 using B-spline interpolation for the image and nearest-neighbor interpolation for the segmentation.
29 | * Generation of 3D bounding boxes from the tumor segmentation mask.
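These operations map onto `Preprocess.normalize` in `radiogenomics/utils/preprocess.py`; a condensed sketch of its CT branch (the function name here is illustrative, and the label branch swaps in nearest-neighbor interpolation and skips the clipping):

```
import torchio.transforms as tt

def normalize_ct(image, space=(1, 1, 1.5), out_shape=(256, 256, 256)):
    # Resample to anisotropic 1 x 1 x 1.5 mm^3 spacing
    image = tt.Resample(space, image_interpolation='bspline')(image)
    # Resize to a fixed 256^3 volume
    image = tt.Resize(out_shape, image_interpolation='bspline')(image)
    # Clip intensities to [-200, 250], then min-max normalize to [0, 1]
    image = tt.Clamp(out_min=-200, out_max=250)(image)
    image = tt.RescaleIntensity(out_min_max=(0, 1))(image)
    return image
```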
30 | 
31 | ```
32 | cd radiogenomics
33 | 
34 | python main.py preprocess RAW_DATA_PATH PREPROCESSED_DATA_OUTPUT DATASET
35 | python main.py preprocess ./data/msd/Task06_Lung ./data/msd/processed msd
36 | ```
37 | 
38 | Illustrated example of data preprocessing:
39 | ![alt text](figures/CT_processing.png "Preprocessing example")
40 | 
41 | ## Segmentation. Lung tumor segmentation.
42 | The "Recurrence-Attention Segmentor" (RA-Seg) architecture combines the core concepts of three models:
43 | * [Volumetric U-Net](https://link.springer.com/chapter/10.1007/978-3-319-46723-8_49)
44 | * [Recurrent 3D-DenseUNet](https://link.springer.com/chapter/10.1007/978-3-030-62469-9_4)
45 | * [Organ-to-lesion (O2L) module](https://arxiv.org/abs/2010.12219)
46 | 
47 | The proposed RA-Seg generates a tumor mask with a U-Net-like structure, introducing an attention block to recalibrate the skip connections and a recurrent block to capture inter-slice dependencies of the CT.
48 | 
49 | **The architecture of the proposed model:**
50 | ![alt text](figures/radgen_pipeline.png "Radiogenomics pipeline")
51 | 
52 | **Architecture outline and details:**
53 | ![alt text](figures/arch_outline.PNG "Architecture outline")
54 | 
55 | ## Segmentation Inference. Lung tumor segmentation mask.
56 | Segmentation results:
57 | * Trained on MSD. Average DSC 67.24 (5-fold CV) and 68.65 on the RADGEN test set.
58 | * Trained on RAD. Average DSC 72.36 (5-fold CV) and 66.11 on the RADGEN test set.
59 | * Trained on MSD+RAD. Average DSC 71.42 (5-fold CV) and 73.54 on the RADGEN test set.
60 | * Trained on RADGEN with pretrained weights (MSD+RAD). Average DSC 75.26 (5-fold CV).
61 | 
62 | Available pretrained weights:
63 | * Trained on RADGEN with pretrained weights (MSD+RAD).
64 | 
65 | 
66 | The inference command saves the tumor mask for the preprocessed data available in DATA_PATH and extracts the raw high-level features (no dimensionality reduction). File structure:
67 | ```
68 | DATA PATH
69 | |____ CASE1
70 |       required:
71 | |____ |____ INPUT.pt
72 | |____ |____ TUMOR_BBX.pt
73 |       generated:
74 | |____ |____ TUMOR_SEG.npy
75 | |____ |____ RAW_FEATS.npy
76 | |
77 | |____ CASE2
78 | |____ |____ INPUT.pt
79 | |____ |____ TUMOR_BBX.pt
80 | |____ |____ TUMOR_SEG.npy
81 | |____ |____ RAW_FEATS.npy
82 | |____ CASE3
83 | .
84 | .
85 | .
86 | ```
87 | 
88 | Command example:
89 | 
90 | ```
91 | cd radiogenomics  # Make sure you are in the right directory.
92 | 
93 | python main.py inf_seg DATA_PATH WEIGHTS_PATH MODEL
94 | python main.py inf_seg ./data/radgen/processed/lungs_roi ./weights/radgen_finetune.pt RA_Seg
95 | ```
96 | ![alt text](figures/final_model_isbi.png "Radiogenomics prediction example")
97 | 
98 | ## Classification. EGFR mutation status classification.
99 | High-level deep features are extracted from the *Decoder 2* output. They then undergo preprocessing (mean and flatten operations) followed by LDA dimensionality reduction, as sketched below.
100 | 
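A sketch of the path from saved features to a prediction, following `experiments/inf_class.py`. It is assumed here that `classifiers/LDA_dim_red.sav` is the fitted reducer that produces the stored `*_lda_feat.npy` vectors, and that the commands run from the `radiogenomics/` directory:

```
import pickle
import numpy as np

case_dir = "data/radgen/processed/lungs_roi/R01-003"
raw = np.load(f"{case_dir}/R01-003_raw_feat.npy")      # mean+flatten of the Decoder 2 output

with open("classifiers/LDA_dim_red.sav", "rb") as f:   # fitted LDA dimensionality reducer
    lda = pickle.load(f)
with open("classifiers/rf_model.sav", "rb") as f:      # e.g. the Random Forest classifier
    clf = pickle.load(f)

feat = lda.transform(raw.reshape(1, -1))               # reduced feature vector
print(clf.predict(feat))                               # 0 = Wildtype, 1 = Mutated
```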
101 | Classification results:
102 | * Quadratic Discriminant Analysis (QDA). Average ROC-AUC 0.90 (5-fold CV)
103 | * Decision Tree (DT). Average ROC-AUC 0.91 (5-fold CV)
104 | * Random Forest (RF). Average ROC-AUC 0.93 (5-fold CV)
105 | * C-Support Vector Classification (SVC). Average ROC-AUC 0.83 (5-fold CV)
106 | 
107 | Classifiers available as SAV files:
108 | * QDA
109 | * DT
110 | * RF
111 | * SVC
112 | 
113 | The inference command returns the predicted EGFR mutation status: the negative class corresponds to "Wildtype" and the positive class to "Mutated".
114 | ```
115 | cd radiogenomics  # Make sure you are in the right directory.
116 | 
117 | python main.py inf_class DATA_PATH SAVED_MODELS_PATH MODEL
118 | python main.py inf_class ./data/radgen/processed/lungs_roi ./classifiers qda
119 | ```
120 | 
--------------------------------------------------------------------------------
/radiogenomics/models/utils/convlstm.py:
--------------------------------------------------------------------------------
1 | """
2 | Author: Ivo Gollini Navarrete
3 | Date: 17/sep/2022
4 | Institution: MBZUAI
5 | Copied from: https://github.com/ndrplz/ConvLSTM_pytorch/blob/master/convlstm.py
6 | """
7 | 
8 | import torch.nn as nn
9 | import torch
10 | 
11 | class ConvLSTMCell(nn.Module):
12 | 
13 |     def __init__(self, input_dim, hidden_dim, kernel_size, bias):
14 |         """
15 |         Initialize ConvLSTM cell.
16 |         Parameters
17 |         ----------
18 |         input_dim: int
19 |             Number of channels of input tensor.
20 |         hidden_dim: int
21 |             Number of channels of hidden state.
22 |         kernel_size: (int, int)
23 |             Size of the convolutional kernel.
24 |         bias: bool
25 |             Whether or not to add the bias.
26 |         """
27 | 
28 |         super(ConvLSTMCell, self).__init__()
29 | 
30 |         self.input_dim = input_dim
31 |         self.hidden_dim = hidden_dim
32 | 
33 |         self.kernel_size = kernel_size
34 |         self.padding = kernel_size[0] // 2, kernel_size[1] // 2
35 |         self.bias = bias
36 | 
37 |         self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
38 |                               out_channels=4 * self.hidden_dim,
39 |                               kernel_size=self.kernel_size,
40 |                               padding=self.padding,
41 |                               bias=self.bias)
42 | 
43 |     def forward(self, input_tensor, cur_state):
44 |         h_cur, c_cur = cur_state
45 | 
46 |         combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
47 | 
48 |         combined_conv = self.conv(combined)
49 |         cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
50 |         i = torch.sigmoid(cc_i)
51 |         f = torch.sigmoid(cc_f)
52 |         o = torch.sigmoid(cc_o)
53 |         g = torch.tanh(cc_g)
54 | 
55 |         c_next = f * c_cur + i * g
56 |         h_next = o * torch.tanh(c_next)
57 | 
58 |         return h_next, c_next
59 | 
60 |     def init_hidden(self, batch_size, image_size):
61 |         height, width = image_size
62 |         return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
63 |                 torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
64 | 
65 | 
66 | class ConvLSTM(nn.Module):
67 | 
68 |     """
69 |     Parameters:
70 |         input_dim: Number of channels in input
71 |         hidden_dim: Number of hidden channels
72 |         kernel_size: Size of kernel in convolutions
73 |         num_layers: Number of LSTM layers stacked on each other
74 |         batch_first: Whether or not dimension 0 is the batch or not
75 |         bias: Bias or no bias in Convolution
76 |         return_all_layers: Return the list of computations for all layers
77 |         Note: Will do same padding.
78 |     Input:
79 |         A tensor of size B, T, C, H, W or T, B, C, H, W
80 |     Output:
81 |         A tuple of two lists of length num_layers (or length 1 if return_all_layers is False).
82 | 0 - layer_output_list is the list of lists of length T of each output 83 | 1 - last_state_list is the list of last states 84 | each element of the list is a tuple (h, c) for hidden state and memory 85 | Example: 86 | >> x = torch.rand((32, 10, 64, 128, 128)) 87 | >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False) 88 | >> _, last_states = convlstm(x) 89 | >> h = last_states[0][0] # 0 for layer index, 0 for h index 90 | """ 91 | 92 | def __init__(self, input_dim, hidden_dim, kernel_size, num_layers, 93 | batch_first=False, bias=True, return_all_layers=False): 94 | super(ConvLSTM, self).__init__() 95 | 96 | self._check_kernel_size_consistency(kernel_size) 97 | 98 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 99 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 100 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 101 | if not len(kernel_size) == len(hidden_dim) == num_layers: 102 | raise ValueError('Inconsistent list length.') 103 | 104 | self.input_dim = input_dim 105 | self.hidden_dim = hidden_dim 106 | self.kernel_size = kernel_size 107 | self.num_layers = num_layers 108 | self.batch_first = batch_first 109 | self.bias = bias 110 | self.return_all_layers = return_all_layers 111 | 112 | cell_list = [] 113 | for i in range(0, self.num_layers): 114 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1] 115 | 116 | cell_list.append(ConvLSTMCell(input_dim=cur_input_dim, 117 | hidden_dim=self.hidden_dim[i], 118 | kernel_size=self.kernel_size[i], 119 | bias=self.bias)) 120 | 121 | self.cell_list = nn.ModuleList(cell_list) 122 | 123 | def forward(self, input_tensor, hidden_state=None): 124 | """ 125 | Parameters 126 | ---------- 127 | input_tensor: todo 128 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 129 | hidden_state: todo 130 | None. todo implement stateful 131 | Returns 132 | ------- 133 | last_state_list, layer_output 134 | """ 135 | if not self.batch_first: 136 | # (t, b, c, h, w) -> (b, t, c, h, w) 137 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 138 | 139 | b, _, _, h, w = input_tensor.size() 140 | 141 | # Implement stateful ConvLSTM 142 | if hidden_state is not None: 143 | raise NotImplementedError() 144 | else: 145 | # Since the init is done in forward. 
Can send image size here 146 | hidden_state = self._init_hidden(batch_size=b, 147 | image_size=(h, w)) 148 | 149 | layer_output_list = [] 150 | last_state_list = [] 151 | 152 | seq_len = input_tensor.size(1) 153 | cur_layer_input = input_tensor 154 | 155 | for layer_idx in range(self.num_layers): 156 | 157 | h, c = hidden_state[layer_idx] 158 | output_inner = [] 159 | for t in range(seq_len): 160 | h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], 161 | cur_state=[h, c]) 162 | output_inner.append(h) 163 | 164 | layer_output = torch.stack(output_inner, dim=1) 165 | cur_layer_input = layer_output 166 | 167 | layer_output_list.append(layer_output) 168 | last_state_list.append([h, c]) 169 | 170 | if not self.return_all_layers: 171 | layer_output_list = layer_output_list[-1:] 172 | last_state_list = last_state_list[-1:] 173 | 174 | return layer_output_list, last_state_list 175 | 176 | def _init_hidden(self, batch_size, image_size): 177 | init_states = [] 178 | for i in range(self.num_layers): 179 | init_states.append(self.cell_list[i].init_hidden(batch_size, image_size)) 180 | return init_states 181 | 182 | @staticmethod 183 | def _check_kernel_size_consistency(kernel_size): 184 | if not (isinstance(kernel_size, tuple) or 185 | (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): 186 | raise ValueError('`kernel_size` must be tuple or list of tuples') 187 | 188 | @staticmethod 189 | def _extend_for_multilayer(param, num_layers): 190 | if not isinstance(param, list): 191 | param = [param] * num_layers 192 | return param -------------------------------------------------------------------------------- /radiogenomics/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Ivo Gollini Navarrete 3 | Date: 21/august/2022 4 | Institution: MBZUAI 5 | """ 6 | 7 | # IMPORTS 8 | import os 9 | import numpy as np 10 | # import pandas as pd 11 | import SimpleITK as sitk 12 | import pydicom as dicom 13 | import nibabel as nib 14 | 15 | import torch 16 | import torchio.transforms as tt 17 | from torchvision.ops import masks_to_boxes 18 | from lungmask import mask 19 | 20 | # from tqdm import tqdm 21 | 22 | class Preprocess: 23 | """ 24 | Prepare data for experiments. 
25 | """ 26 | def __init__(self, datapath, dataoutput, dataset): 27 | self.datapath = datapath 28 | self.dataoutput = dataoutput 29 | self.dataset = dataset 30 | 31 | if torch.cuda.is_available(): 32 | self.device = torch.device("cuda:0") 33 | print('GPU Available') 34 | else: 35 | self.device = torch.device("cpu") 36 | print("GPU Not Available ") 37 | 38 | def msd(self): 39 | print('Preprocessing MSD dataset...') 40 | 41 | print(self.datapath) 42 | ct_list = sorted(os.listdir(os.path.join(self.datapath, 'imagesTr'))) 43 | seg_list = sorted(os.listdir(os.path.join(self.datapath, 'labelsTr'))) 44 | 45 | for pat, label in zip(ct_list, seg_list): 46 | patient = pat.split('.')[0] 47 | print(patient) 48 | 49 | ct = nib.load(os.path.join(self.datapath, 'imagesTr', pat)) 50 | ct = ct.get_fdata() 51 | ct = np.transpose(ct, (2,1,0)) 52 | ct = np.flip(ct, axis=1).copy() 53 | 54 | seg = nib.load(os.path.join(self.datapath, 'labelsTr', label)) 55 | seg = seg.get_fdata() 56 | seg = np.transpose(seg, (2,1,0)) 57 | seg = np.flip(seg, axis=1).copy() 58 | 59 | preprocessed_ct = self.normalize(ct[None, :, :, :]) 60 | print('CT Normalized') 61 | self.save_img(patient, preprocessed_ct) 62 | 63 | preprocessed_seg = self.normalize(seg[None, :, :, :], is_label=True) 64 | print('Seg Normalized') 65 | self.save_img(patient, preprocessed_seg, is_Label=True) 66 | 67 | self.extract_lungs(patient, ct, seg[None, :, :, :]) 68 | 69 | self.tumor_bbx(self.dataoutput+'/lungs_roi') 70 | return 71 | 72 | def radiomics(self): 73 | 74 | print('Preprocessing Radiomics dataset...') 75 | 76 | for root, directories, files in os.walk(self.datapath, topdown=True): 77 | for patient in sorted(directories): 78 | print(patient) 79 | study = os.listdir(os.path.join(root, patient))[0] 80 | elements = sorted(os.listdir(os.path.join(root, patient, study))) 81 | print(elements) 82 | for element in elements: 83 | if not "Segmentation" in element: 84 | dcms_path = os.path.join(root, patient, study, element) 85 | if len(os.listdir(dcms_path)) == 1: continue 86 | else: 87 | patient_ct = self.read_ct(dcms_path) 88 | patient_ct = sitk.GetArrayFromImage(patient_ct) 89 | preprocessed_ct = self.normalize(patient_ct[None, :, :, :]) 90 | self.save_img(patient, preprocessed_ct) 91 | 92 | else: 93 | seg_path = os.path.join(root, patient, study, element, '1-1.dcm') 94 | seg_data = dicom.read_file(seg_path) 95 | seg_array = seg_data.pixel_array 96 | seg_num = len(seg_data.SegmentSequence) 97 | print("Seg shape: {}, with {} segmentations".format(seg_array.shape, seg_num)) 98 | 99 | seg_idx = 0 100 | for i in range(seg_num): 101 | label_idx = seg_data.SegmentSequence[i].SegmentLabel 102 | if label_idx == "Neoplasm, Primary": 103 | print("Primary Neoplasm in slice {}".format(i)) 104 | seg_idx = i 105 | break 106 | 107 | dim0 = int(seg_array.shape[0]/seg_num) 108 | seg_tensor = torch.reshape(torch.from_numpy(seg_array), (seg_num, dim0, 512, 512)) 109 | patient_seg = seg_tensor[seg_idx] 110 | patient_seg = patient_seg.expand(1, patient_seg.shape[0], patient_seg.shape[1], patient_seg.shape[2]).numpy() 111 | 112 | preprocessed_seg = self.normalize(patient_seg, is_label=True) 113 | self.save_img(patient, preprocessed_seg, is_Label=True) 114 | 115 | self.extract_lungs(patient, patient_ct, patient_seg) 116 | break 117 | 118 | self.tumor_bbx(self.dataoutput+'/lungs_roi') 119 | 120 | def radiogenomics(self): 121 | 122 | print('Preprocessing Radiogenomics dataset...') 123 | 124 | patients_list = sorted(os.listdir(self.datapath)) 125 | for patient in patients_list: 126 | 
print(patient) 127 | 128 | patient_path = os.path.join(self.datapath, patient, 'CT') 129 | patient_ct = self.read_ct(patient_path) 130 | patient_ct = sitk.GetArrayFromImage(patient_ct) 131 | preprocessed_ct = self.normalize(patient_ct[None, :, :, :]) 132 | self.save_img(patient, preprocessed_ct) 133 | 134 | if os.path.exists(os.path.join(self.datapath, patient, 'seg')): 135 | patient_seg = self.read_ct(os.path.join(self.datapath, patient, 'seg')) 136 | patient_seg = sitk.GetArrayFromImage(patient_seg) 137 | 138 | preprocessed_seg = self.normalize(patient_seg, is_label=True) 139 | self.save_img(patient, preprocessed_seg, is_Label=True) 140 | 141 | self.extract_lungs(patient, patient_ct, patient_seg) 142 | self.extract_tumor(patient, patient_ct, torch.tensor(patient_seg[0])) 143 | 144 | else: 145 | self.extract_lungs(patient, patient_ct) 146 | continue 147 | self.tumor_bbx(self.dataoutput+'/lungs_roi') 148 | 149 | def extract_tumor(self, patient, ct, seg=None): 150 | tcr = self.roi_coord(seg, roi='tumor') # tcr = tumor_roi_coord 151 | tumor_ct = ct.copy() 152 | tumor_ct = self.crop_coord(tcr, tumor_ct) 153 | tumor_ct = self.normalize(tumor_ct, out_shape=(64, 64, 64)) 154 | self.save_img(patient, tumor_ct, mode='tumor_roi') 155 | 156 | def extract_lungs(self, patient, ct, seg=None): 157 | 158 | model = mask.get_model('unet','R231').to(self.device) 159 | extracted = mask.apply(ct, model) 160 | extracted = torch.tensor(extracted) 161 | 162 | lungs_mask = extracted.clone() 163 | lcr = self.roi_coord(lungs_mask, roi='lungs') # lcr = lung_roi_coord 164 | lungs_ct = ct.copy() 165 | lungs_ct = self.crop_coord(lcr, lungs_ct) 166 | lungs_ct = self.normalize(lungs_ct) 167 | self.save_img(patient, lungs_ct, mode='lungs_roi') 168 | 169 | if seg is not None: 170 | lungs_seg = seg.copy() 171 | lungs_seg = self.crop_coord(lcr, lungs_seg, is_label=True) 172 | lungs_seg = self.normalize(lungs_seg, is_label=True) 173 | self.save_img(patient, lungs_seg, mode='lungs_roi', is_Label=True) 174 | 175 | def crop_coord(self, coord, image, is_label=False): 176 | if is_label: 177 | image = image[:, coord[4]:coord[5], coord[2]:coord[3], coord[0]:coord[1]] 178 | else: 179 | image = image[coord[4]:coord[5], coord[2]:coord[3], coord[0]:coord[1]] 180 | image = image[None,:,:,:] 181 | 182 | return image 183 | 184 | def roi_coord(self, mask, roi='lungs'): 185 | frame_list = [] 186 | x_min, y_min, x_max, y_max = np.inf, np.inf, -np.inf, -np.inf 187 | 188 | if roi == 'lungs': 189 | mask[mask == 2] = 1 190 | 191 | for i in range(len(mask)): 192 | if mask[i].max() > 0: 193 | frame_list.append(i) 194 | 195 | ct_slice = mask[i] 196 | ct_slice = ct_slice[None, :, :] 197 | 198 | bbx = masks_to_boxes(ct_slice) 199 | bbx = bbx[0].detach().tolist() 200 | 201 | if bbx[0] < x_min: x_min = int(bbx[0]) 202 | if bbx[1] < y_min: y_min = int(bbx[1]) 203 | if bbx[2] > x_max: x_max = int(bbx[2]) 204 | if bbx[3] > y_max: y_max = int(bbx[3]) 205 | 206 | z_min = frame_list[0] 207 | z_max = frame_list[-1] 208 | return [x_min, x_max, y_min, y_max, z_min, z_max] 209 | 210 | def read_ct(self, path): 211 | reader = sitk.ImageSeriesReader() 212 | dcm_names = reader.GetGDCMSeriesFileNames(path) 213 | reader.SetFileNames(dcm_names) 214 | image = reader.Execute() 215 | return image 216 | 217 | def save_img(self, patient, image, mode='full_ct', roi=None, is_Label=False): 218 | if not os.path.exists(os.path.join(self.dataoutput, mode, patient)): 219 | os.makedirs(os.path.join(self.dataoutput, mode, patient)) 220 | if roi is None: 221 | if is_Label: 222 | 
torch.save(image, os.path.join(self.dataoutput, mode, patient, patient + '_seg.pt')) 223 | else: 224 | torch.save(image, os.path.join(self.dataoutput, mode, patient, patient + '_ct.pt')) 225 | 226 | else: 227 | if is_Label: 228 | torch.save(image, os.path.join(self.dataoutput, mode, patient, patient + '_' + roi + '_seg.pt')) 229 | else: 230 | torch.save(image, os.path.join(self.dataoutput, mode, patient, patient + '_' + roi + '_ct.pt')) 231 | 232 | def normalize(self, image, space=(1,1,1.5), out_shape=(256, 256, 256), is_label=False): 233 | if is_label: 234 | # image = tt.Resample(space, image_interpolation='nearest')(image) 235 | image = tt.Resize(out_shape, image_interpolation='nearest')(image) 236 | image = tt.RescaleIntensity(out_min_max=(0,1))(image) 237 | else: 238 | image = tt.Resample(space, image_interpolation='bspline')(image) 239 | image = tt.Resize(out_shape, image_interpolation='bspline')(image) 240 | image = tt.Clamp(out_min= -200, out_max=250)(image) 241 | image = tt.RescaleIntensity(out_min_max=(0,1))(image) 242 | return image 243 | 244 | def tumor_bbx(self, outpath): 245 | print('Generating tumor bbx', self.dataset) 246 | print(outpath) 247 | 248 | patients_list = sorted(os.listdir(outpath)) 249 | for pat in patients_list: 250 | print(pat) 251 | element_list = sorted(os.listdir(os.path.join(outpath, pat))) 252 | for element in element_list: 253 | if "_seg" not in element: continue 254 | patient_seg = torch.load(os.path.join(outpath, pat, element)) 255 | tumor_cord = self.roi_coord(torch.tensor(patient_seg[0])) # x_min, x_max, y_min, y_max, z_min, z_max 256 | 257 | mask = np.zeros(patient_seg.shape[1:]) 258 | mask[tumor_cord[4]:tumor_cord[5]+1, tumor_cord[2]:tumor_cord[3]+1, tumor_cord[0]:tumor_cord[1]+1] = 1 259 | torch.save(mask[None,:,:,:], os.path.join(outpath, pat, pat + '_bbx.pt')) -------------------------------------------------------------------------------- /radiogenomics/models/ra_seg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Ivo Gollini Navarrete 3 | Date: 12/sep/2022 4 | Institution: MBZUAI 5 | 6 | "Attention to recurrence" (A2R) is a method to improve the performance introducing 7 | attention module to the skip connection to help the model focus on the ROI and 8 | recurrent module at the bottom of the UNET architecture to capture interslice continuity. 9 | This method is based on: 10 | - "A Teacher-Student Framework for Semi-supervised Medical Image Segmentation From Mixed Supervision". 11 | - "Lung Cancer Tumor Region Segmentation Using Recurrent 3D-DenseUNet". 12 | 13 | """ 14 | 15 | import warnings 16 | from typing import Optional, Sequence, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | from monai.networks.blocks.convolutions import Convolution, ResidualUnit 23 | from monai.networks.layers.factories import Act, Norm 24 | from monai.networks.layers.simplelayers import SkipConnection 25 | from monai.utils import alias, deprecated_arg, export 26 | from models.utils.convlstm import ConvLSTM 27 | 28 | 29 | __all__ = ["ra_seg", "RA_Seg"] 30 | 31 | 32 | @export("monai.networks.nets") 33 | @alias("Unet") 34 | class RA_Seg(nn.Module): 35 | @deprecated_arg( 36 | name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead." 
37 | ) 38 | def __init__( 39 | self, 40 | spatial_dims: int, 41 | in_channels: int, 42 | out_channels: int, 43 | channels: Sequence[int], 44 | strides: Sequence[int], 45 | kernel_size: Union[Sequence[int], int] = 3, 46 | up_kernel_size: Union[Sequence[int], int] = 3, 47 | num_res_units: int = 0, 48 | act: Union[Tuple, str] = Act.PRELU, 49 | act2: Union[Tuple, str] = Act.SIGMOID, 50 | norm: Union[Tuple, str] = Norm.INSTANCE, 51 | norm2: Union[Tuple, str] = Norm.BATCH, 52 | dropout: float = 0.1, 53 | bias: bool = True, 54 | adn_ordering: str = "NDA", 55 | dimensions: Optional[int] = None, 56 | ) -> None: 57 | 58 | super().__init__() 59 | 60 | if len(channels) < 2: 61 | raise ValueError("the length of `channels` should be no less than 2.") 62 | delta = len(strides) - (len(channels) - 1) 63 | if delta < 0: 64 | raise ValueError("the length of `strides` should equal to `len(channels) - 1`.") 65 | if delta > 0: 66 | warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.") 67 | if dimensions is not None: 68 | spatial_dims = dimensions 69 | if isinstance(kernel_size, Sequence): 70 | if len(kernel_size) != spatial_dims: 71 | raise ValueError("the length of `kernel_size` should equal to `dimensions`.") 72 | if isinstance(up_kernel_size, Sequence): 73 | if len(up_kernel_size) != spatial_dims: 74 | raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.") 75 | 76 | self.dimensions = spatial_dims 77 | self.in_channels = in_channels 78 | self.out_channels = out_channels 79 | self.channels = channels 80 | self.strides = strides 81 | self.kernel_size = kernel_size 82 | self.up_kernel_size = up_kernel_size 83 | self.num_res_units = num_res_units 84 | self.act = act 85 | self.act2 = act2 86 | self.norm = norm 87 | self.norm2 = norm2 88 | self.dropout = dropout 89 | self.bias = bias 90 | self.adn_ordering = adn_ordering 91 | 92 | self.encoder1 = self._get_down_layer(self.in_channels, self.channels[0], self.strides[0], is_top=True) # (1, 64, 2) 93 | self.a2o_encoder1 = self._a2o_conv(self.in_channels, self.channels[0], 1) 94 | 95 | self.encoder2 = self._get_down_layer(self.channels[0], self.channels[1], self.strides[1], is_top=False) # (64, 128, 2) 96 | self.a2o_encoder2 = self._a2o_conv(self.in_channels, self.channels[1], 1) 97 | 98 | self.encoder3 = self._get_down_layer(self.channels[1], self.channels[2], self.strides[2], is_top=False) # (128, 256, 2) 99 | self.a2o_encoder3 = self._a2o_conv(self.in_channels, self.channels[2], 1) 100 | 101 | self.encoder4 = self._get_down_layer(self.channels[2], self.channels[3], self.strides[3], is_top=False) # (256, 512, 2) 102 | self.a2o_encoder4 = self._a2o_conv(self.in_channels, self.channels[3], 1) 103 | 104 | self.bottom = self._get_bottom_layer(self.channels[3], self.channels[4]) # (512, 1024), stride = 1 105 | 106 | self.decoder4 = self._get_up_layer((self.channels[4]+self.channels[3]), self.channels[2], self.strides[3], is_top=False) # (1024+512, 256, 2) 107 | 108 | self.decoder3 = self._get_up_layer(self.channels[2]*2, self.channels[1], self.strides[2], is_top=False) # (512, 128, 2) 109 | 110 | self.decoder2 = self._get_up_layer(self.channels[1]*2, self.channels[0], self.strides[1], is_top=False) # (256, 64, 2) 111 | 112 | self.decoder1 = self._get_up_layer(self.channels[0]*2, self.out_channels, self.strides[0], is_top=True) # (128, 1, 2) -> output 113 | self.activation = nn.Sigmoid() 114 | 115 | def _a2o_conv(self, in_channels: int, out_channels: int, strides: int) -> nn.Module: 116 | mod: 
nn.Module 117 | mod = Convolution( 118 | self.dimensions, 119 | in_channels, 120 | out_channels, 121 | strides=strides, 122 | kernel_size=self.kernel_size, 123 | act=self.act2, 124 | norm=self.norm2, 125 | dropout = None, 126 | bias=self.bias, 127 | adn_ordering=self.adn_ordering, 128 | ) 129 | return mod 130 | 131 | def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: 132 | """ 133 | Returns the encoding (down) part of a layer of the network. This typically will downsample data at some point 134 | in its structure. Its output is used as input to the next layer down and is concatenated with output from the 135 | next layer to form the input for the decode (up) part of the layer. 136 | 137 | Args: 138 | in_channels: number of input channels. 139 | out_channels: number of output channels. 140 | strides: convolution stride. 141 | is_top: True if this is the top block. 142 | """ 143 | mod: nn.Module 144 | if self.num_res_units > 0: 145 | 146 | mod = ResidualUnit( 147 | self.dimensions, 148 | in_channels, 149 | out_channels, 150 | strides=strides, 151 | kernel_size=self.kernel_size, 152 | subunits=self.num_res_units, 153 | act=self.act, 154 | norm=self.norm, 155 | dropout=self.dropout, 156 | bias=self.bias, 157 | adn_ordering=self.adn_ordering, 158 | ) 159 | return mod 160 | mod = Convolution( 161 | self.dimensions, 162 | in_channels, 163 | out_channels, 164 | strides=strides, 165 | kernel_size=self.kernel_size, 166 | act=self.act, 167 | norm=self.norm, 168 | dropout=self.dropout, 169 | bias=self.bias, 170 | adn_ordering=self.adn_ordering, 171 | ) 172 | return mod 173 | 174 | def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: 175 | """ 176 | Returns the bottom or bottleneck layer at the bottom of the Recurrent network linking encode to decode halves. 177 | 178 | Args: 179 | in_channels: number of input channels. 180 | out_channels: number of output channels. 181 | """ 182 | mod: nn.Module 183 | 184 | mod = ConvLSTM( 185 | input_dim=in_channels, 186 | hidden_dim=[out_channels, out_channels, out_channels], 187 | kernel_size=(3, 3), 188 | num_layers=3, 189 | batch_first=True, 190 | bias=True, 191 | return_all_layers=False, 192 | ) 193 | return mod 194 | 195 | def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: 196 | """ 197 | Returns the decoding (up) part of a layer of the network. This typically will upsample data at some point 198 | in its structure. Its output is used as input to the next layer up. 199 | 200 | Args: 201 | in_channels: number of input channels. 202 | out_channels: number of output channels. 203 | strides: convolution stride. 204 | is_top: True if this is the top block. 
205 | """ 206 | conv: Union[Convolution, nn.Sequential] 207 | 208 | conv = Convolution( 209 | self.dimensions, 210 | in_channels, 211 | out_channels, 212 | strides=strides, 213 | kernel_size=self.up_kernel_size, 214 | act=self.act, 215 | norm=self.norm, 216 | dropout=self.dropout, 217 | bias=self.bias, 218 | conv_only=is_top and self.num_res_units == 0, 219 | is_transposed=True, 220 | adn_ordering=self.adn_ordering, 221 | ) 222 | 223 | if self.num_res_units > 0: 224 | ru = ResidualUnit( 225 | self.dimensions, 226 | out_channels, 227 | out_channels, 228 | strides=1, 229 | kernel_size=self.kernel_size, 230 | subunits=1, 231 | act=self.act, 232 | norm=self.norm, 233 | dropout=self.dropout, 234 | bias=self.bias, 235 | last_conv_only=is_top, 236 | adn_ordering=self.adn_ordering, 237 | ) 238 | conv = nn.Sequential(conv, ru) 239 | 240 | return conv 241 | 242 | def forward(self, x: torch.Tensor, organ: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 243 | enc1 = self.encoder1(x) 244 | skip1 = torch.add(enc1, torch.mul( 245 | enc1, 246 | self.a2o_encoder1(F.interpolate(organ, scale_factor=0.5, mode='trilinear')) 247 | )) 248 | 249 | enc2 = self.encoder2(enc1) 250 | skip2 = torch.add(enc2, torch.mul( 251 | enc2, 252 | self.a2o_encoder2(F.interpolate(organ, scale_factor=0.25, mode='trilinear')) 253 | )) 254 | 255 | enc3 = self.encoder3(enc2) 256 | skip3 = torch.add(enc3, torch.mul( 257 | enc3, 258 | self.a2o_encoder3(F.interpolate(organ, scale_factor=0.125, mode='trilinear')) 259 | )) 260 | 261 | enc4 = self.encoder4(enc3) 262 | skip4 = torch.add(enc4, torch.mul( 263 | enc4, 264 | self.a2o_encoder4(F.interpolate(organ, scale_factor=0.0625, mode='trilinear')) 265 | )) 266 | 267 | # Permute from (batch, features, depth, h, w) to (batch, depth, features, h, w), 268 | # i.e. (B, time-steps, channels, H, W) as expected by ConvLSTM with batch_first=True. 269 | # ConvLSTM returns (layer_output_list, last_state_list); with return_all_layers=False, 270 | # [0][0] is the full output sequence of the final layer. 271 | bottom = self.bottom(enc4.permute(0, 2, 1, 3, 4))[0][0] 272 | 273 | dec4 = self.decoder4(torch.cat((bottom.permute(0, 2, 1, 3, 4), skip4), dim=1)) 274 | 275 | dec3 = self.decoder3(torch.cat((dec4, skip3), dim=1)) 276 | 277 | dec2 = self.decoder2(torch.cat((dec3, skip2), dim=1)) 278 | 279 | dec1 = self.decoder1(torch.cat((dec2, skip1), dim=1)) 280 | 281 | return self.activation(dec1), dec2 282 | 283 | ra_seg = RA_Seg 284 | -------------------------------------------------------------------------------- /radiogenomics/models/unet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License.
11 | 12 | import warnings 13 | from typing import Optional, Sequence, Tuple, Union 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from monai.networks.blocks.convolutions import Convolution, ResidualUnit 19 | from monai.networks.layers.factories import Act, Norm 20 | from monai.networks.layers.simplelayers import SkipConnection 21 | from monai.utils import alias, deprecated_arg, export 22 | 23 | __all__ = ["UNet", "Unet"] 24 | 25 | 26 | @export("monai.networks.nets") 27 | @alias("Unet") 28 | class UNet(nn.Module): 29 | """ 30 | Enhanced version of UNet which has residual units implemented with the ResidualUnit class. 31 | The residual part uses a convolution to change the input dimensions to match the output dimensions 32 | if this is necessary but will use nn.Identity if not. 33 | Refer to: https://link.springer.com/chapter/10.1007/978-3-030-12029-0_40. 34 | 35 | Each layer of the network has an encode and decode path with a skip connection between them. Data in the encode path 36 | is downsampled using strided convolutions (if `strides` is given values greater than 1) and in the decode path 37 | upsampled using strided transpose convolutions. These down or up sampling operations occur at the beginning of each 38 | block rather than afterwards as is typical in UNet implementations. 39 | 40 | To further explain this consider the first example network given below. This network has 3 layers with strides 41 | of 2 for each of the middle layers (the last layer is the bottom connection which does not down/up sample). Input 42 | data to this network is immediately reduced in the spatial dimensions by a factor of 2 by the first convolution of 43 | the residual unit defining the first layer of the encode part. The last layer of the decode part will upsample its 44 | input (data from the previous layer concatenated with data from the skip connection) in the first convolution. This 45 | ensures the final output of the network has the same shape as the input. 46 | 47 | Padding values for the convolutions are chosen to ensure output sizes are even divisors/multiples of the input 48 | sizes if the `strides` value for a layer is a factor of the input sizes. A typical case is to use `strides` values 49 | of 2 and inputs that are multiples of powers of 2. An input can thus be downsampled evenly however many times its 50 | dimensions can be divided by 2, so for the example network inputs would have to have dimensions that are multiples 51 | of 4. In the second example network given below the input to the bottom layer will have shape (1, 64, 15, 15) for 52 | an input of shape (1, 1, 240, 240) demonstrating the input being reduced in size spatially by 2**4. 53 | 54 | Args: 55 | spatial_dims: number of spatial dimensions. 56 | in_channels: number of input channels. 57 | out_channels: number of output channels. 58 | channels: sequence of channels. Top block first. The length of `channels` should be no less than 2. 59 | strides: sequence of convolution strides. The length of `strides` should equal to `len(channels) - 1`. 60 | kernel_size: convolution kernel size, the value(s) should be odd. If sequence, 61 | its length should equal to dimensions. Defaults to 3. 62 | up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence, 63 | its length should equal to dimensions. Defaults to 3. 64 | num_res_units: number of residual units. Defaults to 0. 65 | act: activation type and arguments. Defaults to PReLU. 66 | norm: feature normalization type and arguments.
Defaults to instance norm. 67 | dropout: dropout ratio. Defaults to no dropout. 68 | bias: whether to have a bias term in convolution blocks. Defaults to True. 69 | According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#disable-bias-for-convolutions-directly-followed-by-a-batch-norm>`_, 70 | if a conv layer is directly followed by a batch norm layer, bias should be False. 71 | adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D). 72 | Defaults to "NDA". See also: :py:class:`monai.networks.blocks.ADN`. 73 | 74 | Examples:: 75 | 76 | from monai.networks.nets import UNet 77 | 78 | # 3 layer network with down/upsampling by a factor of 2 at each layer with 2-convolution residual units 79 | net = UNet( 80 | spatial_dims=2, 81 | in_channels=1, 82 | out_channels=1, 83 | channels=(4, 8, 16), 84 | strides=(2, 2), 85 | num_res_units=2 86 | ) 87 | 88 | # 5 layer network with simple convolution/normalization/dropout/activation blocks defining the layers 89 | net = UNet( 90 | spatial_dims=2, 91 | in_channels=1, 92 | out_channels=1, 93 | channels=(4, 8, 16, 32, 64), 94 | strides=(2, 2, 2, 2), 95 | ) 96 | 97 | .. deprecated:: 0.6.0 98 | ``dimensions`` is deprecated, use ``spatial_dims`` instead. 99 | 100 | Note: The acceptable spatial size of input data depends on the parameters of the network, 101 | to set appropriate spatial size, please check the tutorial for more details: 102 | https://github.com/Project-MONAI/tutorials/blob/master/modules/UNet_input_size_constrains.ipynb. 103 | Typically, when using a stride of 2 in down / up sampling, the output dimensions are either half of the 104 | input when downsampling, or twice when upsampling. In this case with N numbers of layers in the network, 105 | the inputs must have spatial dimensions that are all multiples of 2^N. 106 | Usually, applying `resize`, `pad` or `crop` transforms can help adjust the spatial size of input data. 107 | 108 | """ 109 | 110 | @deprecated_arg( 111 | name="dimensions", new_name="spatial_dims", since="0.6", msg_suffix="Please use `spatial_dims` instead."
112 | ) 113 | def __init__( 114 | self, 115 | spatial_dims: int, 116 | in_channels: int, 117 | out_channels: int, 118 | channels: Sequence[int], 119 | strides: Sequence[int], 120 | kernel_size: Union[Sequence[int], int] = 3, 121 | up_kernel_size: Union[Sequence[int], int] = 3, 122 | num_res_units: int = 0, 123 | act: Union[Tuple, str] = Act.PRELU, 124 | norm: Union[Tuple, str] = Norm.INSTANCE, 125 | dropout: float = 0.0, 126 | bias: bool = True, 127 | adn_ordering: str = "NDA", 128 | dimensions: Optional[int] = None, 129 | ) -> None: 130 | 131 | super().__init__() 132 | 133 | if len(channels) < 2: 134 | raise ValueError("the length of `channels` should be no less than 2.") 135 | delta = len(strides) - (len(channels) - 1) 136 | if delta < 0: 137 | raise ValueError("the length of `strides` should equal to `len(channels) - 1`.") 138 | if delta > 0: 139 | warnings.warn(f"`len(strides) > len(channels) - 1`, the last {delta} values of strides will not be used.") 140 | if dimensions is not None: 141 | spatial_dims = dimensions 142 | if isinstance(kernel_size, Sequence): 143 | if len(kernel_size) != spatial_dims: 144 | raise ValueError("the length of `kernel_size` should equal to `dimensions`.") 145 | if isinstance(up_kernel_size, Sequence): 146 | if len(up_kernel_size) != spatial_dims: 147 | raise ValueError("the length of `up_kernel_size` should equal to `dimensions`.") 148 | 149 | self.dimensions = spatial_dims 150 | self.in_channels = in_channels 151 | self.out_channels = out_channels 152 | self.channels = channels 153 | self.strides = strides 154 | self.kernel_size = kernel_size 155 | self.up_kernel_size = up_kernel_size 156 | self.num_res_units = num_res_units 157 | self.act = act 158 | self.norm = norm 159 | self.dropout = dropout 160 | self.bias = bias 161 | self.adn_ordering = adn_ordering 162 | 163 | def _create_block( 164 | inc: int, outc: int, channels: Sequence[int], strides: Sequence[int], is_top: bool 165 | ) -> nn.Module: 166 | """ 167 | Builds the UNet structure from the bottom up by recursing down to the bottom block, then creating sequential 168 | blocks containing the downsample path, a skip connection around the previous block, and the upsample path. 169 | 170 | Args: 171 | inc: number of input channels. 172 | outc: number of output channels. 173 | channels: sequence of channels. Top block first. 174 | strides: convolution stride. 175 | is_top: True if this is the top block. 176 | """ 177 | c = channels[0] 178 | s = strides[0] 179 | 180 | subblock: nn.Module 181 | 182 | if len(channels) > 2: 183 | subblock = _create_block(c, c, channels[1:], strides[1:], False) # continue recursion down 184 | upc = c * 2 185 | else: 186 | # the next layer is the bottom so stop recursion, create the bottom layer as the subblock for this layer 187 | subblock = self._get_bottom_layer(c, channels[1]) 188 | upc = c + channels[1] 189 | 190 | down = self._get_down_layer(inc, c, s, is_top) # create layer in downsampling path 191 | up = self._get_up_layer(upc, outc, s, is_top) # create layer in upsampling path 192 | 193 | return self._get_connection_block(down, up, subblock) 194 | 195 | self.model = _create_block(in_channels, out_channels, self.channels, self.strides, True) 196 | 197 | def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module: 198 | """ 199 | Returns the block object defining a layer of the UNet structure including the implementation of the skip 200 | between encoding (down) and decoding (up) sides of the network.
201 | 202 | Args: 203 | down_path: encoding half of the layer 204 | up_path: decoding half of the layer 205 | subblock: block defining the next layer in the network. 206 | Returns: block for this layer: `nn.Sequential(down_path, SkipConnection(subblock), up_path)` 207 | """ 208 | return nn.Sequential(down_path, SkipConnection(subblock), up_path) 209 | 210 | def _get_down_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: 211 | """ 212 | Returns the encoding (down) part of a layer of the network. This typically will downsample data at some point 213 | in its structure. Its output is used as input to the next layer down and is concatenated with output from the 214 | next layer to form the input for the decode (up) part of the layer. 215 | 216 | Args: 217 | in_channels: number of input channels. 218 | out_channels: number of output channels. 219 | strides: convolution stride. 220 | is_top: True if this is the top block. 221 | """ 222 | mod: nn.Module 223 | if self.num_res_units > 0: 224 | 225 | mod = ResidualUnit( 226 | self.dimensions, 227 | in_channels, 228 | out_channels, 229 | strides=strides, 230 | kernel_size=self.kernel_size, 231 | subunits=self.num_res_units, 232 | act=self.act, 233 | norm=self.norm, 234 | dropout=self.dropout, 235 | bias=self.bias, 236 | adn_ordering=self.adn_ordering, 237 | ) 238 | return mod 239 | mod = Convolution( 240 | self.dimensions, 241 | in_channels, 242 | out_channels, 243 | strides=strides, 244 | kernel_size=self.kernel_size, 245 | act=self.act, 246 | norm=self.norm, 247 | dropout=self.dropout, 248 | bias=self.bias, 249 | adn_ordering=self.adn_ordering, 250 | ) 251 | return mod 252 | 253 | def _get_bottom_layer(self, in_channels: int, out_channels: int) -> nn.Module: 254 | """ 255 | Returns the bottom or bottleneck layer at the bottom of the network linking encode to decode halves. 256 | 257 | Args: 258 | in_channels: number of input channels. 259 | out_channels: number of output channels. 260 | """ 261 | return self._get_down_layer(in_channels, out_channels, 1, False) 262 | 263 | def _get_up_layer(self, in_channels: int, out_channels: int, strides: int, is_top: bool) -> nn.Module: 264 | """ 265 | Returns the decoding (up) part of a layer of the network. This typically will upsample data at some point 266 | in its structure. Its output is used as input to the next layer up. 267 | 268 | Args: 269 | in_channels: number of input channels. 270 | out_channels: number of output channels. 271 | strides: convolution stride. 272 | is_top: True if this is the top block. 
273 | """ 274 | conv: Union[Convolution, nn.Sequential] 275 | 276 | conv = Convolution( 277 | self.dimensions, 278 | in_channels, 279 | out_channels, 280 | strides=strides, 281 | kernel_size=self.up_kernel_size, 282 | act=self.act, 283 | norm=self.norm, 284 | dropout=self.dropout, 285 | bias=self.bias, 286 | conv_only=is_top and self.num_res_units == 0, 287 | is_transposed=True, 288 | adn_ordering=self.adn_ordering, 289 | ) 290 | 291 | if self.num_res_units > 0: 292 | ru = ResidualUnit( 293 | self.dimensions, 294 | out_channels, 295 | out_channels, 296 | strides=1, 297 | kernel_size=self.kernel_size, 298 | subunits=1, 299 | act=self.act, 300 | norm=self.norm, 301 | dropout=self.dropout, 302 | bias=self.bias, 303 | last_conv_only=is_top, 304 | adn_ordering=self.adn_ordering, 305 | ) 306 | conv = nn.Sequential(conv, ru) 307 | 308 | return conv 309 | 310 | def forward(self, x: torch.Tensor) -> torch.Tensor: 311 | x = self.model(x) 312 | return x 313 | 314 | 315 | Unet = UNet 316 | --------------------------------------------------------------------------------
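A minimal usage sketch for RA_Seg (illustrative only, not a file in this repository): it assumes the repo root is on PYTHONPATH, and uses the channel/stride configuration implied by the shape comments in RA_Seg.__init__ ("(1, 64, 2)" through "(512, 1024)"). The dummy input sizes below are arbitrary; each spatial dimension only needs to be a multiple of 2**4 = 16 to survive the four stride-2 encoders, and the organ mask must share the CT's shape so the trilinear downsampling in forward lines up with each encoder output.

import torch
from radiogenomics.models.ra_seg import RA_Seg  # assumed import path

net = RA_Seg(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(64, 128, 256, 512, 1024),  # mirrors the shape comments in __init__
    strides=(2, 2, 2, 2),
    num_res_units=2,
)

ct = torch.randn(1, 1, 32, 64, 64)    # (B, C, D, H, W) CT ROI; dims divisible by 16
organ = torch.rand(1, 1, 32, 64, 64)  # organ (lung) mask at the same resolution as ct
seg, feat = net(ct, organ)            # (sigmoid segmentation, decoder2 feature map)
print(seg.shape)                      # torch.Size([1, 1, 32, 64, 64])
print(feat.shape)                     # torch.Size([1, 64, 16, 32, 32])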
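Equally illustrative for unet.py: _create_block assembles the network bottom-up, wrapping each level as nn.Sequential(down, SkipConnection(subblock), up), which is easiest to see by printing a small instance (the 3-layer example from the class docstring; the local models/unet.py mirrors the upstream MONAI class used here).

import torch
from monai.networks.nets import UNet

net = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(4, 8, 16),
    strides=(2, 2),
    num_res_units=2,
)
print(net.model)               # Sequential(ResidualUnit, SkipConnection(...), Sequential(...))

x = torch.randn(1, 1, 96, 96)  # spatial dims must be multiples of 2**2 = 4
print(net(x).shape)            # torch.Size([1, 1, 96, 96])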