├── .gitattributes ├── .gitignore ├── README.md ├── configs ├── __init__.py ├── bav_config.py ├── coord_config.py ├── deeplab_config.py ├── fcn_config.py ├── i_config.py ├── ibav_config.py ├── l_config.py ├── lbav_config.py └── unet_config.py ├── data ├── __init__.py ├── dataset.py └── transforms.py ├── evaluate.py ├── figures └── network.png ├── models ├── __init__.py ├── conv.py ├── deeplab.py ├── fcn.py ├── implicit_autoencoder.py ├── losses.py ├── resnet18.py └── unet.py ├── predict.py ├── predict_unet.py ├── sample_data ├── inference │ ├── pulse_00002.npz │ └── test.csv └── training │ ├── pulse_00002.npz │ └── train.csv ├── train.py ├── train_unet.py └── utils ├── __init__.py ├── logger.py └── metrics.py /.gitattributes: -------------------------------------------------------------------------------- 1 | sample_data/inference/pulse_00002.npz filter=lfs diff=lfs merge=lfs -text 2 | sample_data/inference/test.csv filter=lfs diff=lfs merge=lfs -text 3 | sample_data/training/pulse_00002.npz filter=lfs diff=lfs merge=lfs -text 4 | sample_data/training/train.csv filter=lfs diff=lfs merge=lfs -text 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
141 | # pytype static type analyzer
142 | .pytype/
143 |
144 | # Cython debug symbols
145 | cython_debug/
146 |
147 | # PyCharm
148 | # The JetBrains-specific template is maintained in a separate JetBrains.gitignore that can
149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
150 | # and can be added to the global gitignore or merged into this file. For a more nuclear
151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
152 | #.idea/
153 |
154 | # scripts
155 | scripts/
156 |
157 | # vscode
158 | .vscode/
159 | logs/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Note that the extended version is on the [v2 branch](https://github.com/M3DV/ImPulSe/tree/impulse_V2), which will be finalized soon.
2 |
3 | # MICCAI 2022: What Makes for Automatic Reconstruction of Pulmonary Segments
4 |
5 | This is the official implementation of the MICCAI 2022 paper: *What Makes for Automatic Reconstruction of Pulmonary Segments*, created by Kaiming Kuang\*, Li Zhang\*, Jingyu Li, Hongwei Li, Jiajun Chen, Bo Du and Jiancheng Yang\*\*.
6 |
7 | \*: Equal contributions.
8 | \*\*: Corresponding author.
9 |
10 | ![Network figure](./figures/network.png)
11 |
12 | # Usage
13 |
14 | ## Data Preparation
15 | Since the ImPulSe model and the related dataset are proprietary to Dianei Technologies, we cannot open-source the trained model or the training dataset. However, we provide a sample dataset in `sample_data/`, whose format you should follow before running our scripts.
16 |
17 | The sample training data is in `sample_data/training`. It should contain multiple `.npz` files (one per case) holding the training images and labels. Each `.npz` file has the following keys:
18 | - airway: A 3D binary airway mask resampled to `128×128×128`, cropped around the lung area.
19 | - artery: A 3D binary artery mask resampled to `128×128×128`, cropped around the lung area.
20 | - vein: A 3D binary vein mask resampled to `128×128×128`, cropped around the lung area.
21 | - lungsegment: A 3D integer lung segment mask in the original shape of the CT image, cropped around the lung area. There are 19 classes in total: the background class (0) and 18 lung segments (1-18).
22 | - image: A 3D CT image resampled to `128×128×128`, cropped around the lung area. The original voxel values are kept.
23 | - lobe: A 3D integer lung lobe mask resampled to `128×128×128`, cropped around the lung area. There are 6 classes in total: the background class (0) and 5 lung lobes (1-5).
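To make the expected layout concrete, here is a minimal sketch (not part of the repository) that loads one training sample with NumPy and checks it against the key and shape conventions above. The helper name `check_training_sample` is hypothetical, and the example assumes the Git LFS sample file has been pulled:
```python
import numpy as np


def check_training_sample(path):
    """Sanity-check one training .npz against the conventions above."""
    data = np.load(path)
    # binary masks, the CT image, and the lobe mask are stored at 128^3
    for key in ("airway", "artery", "vein", "image", "lobe"):
        assert data[key].shape == (128, 128, 128), (key, data[key].shape)
    # the segment mask keeps the cropped CT shape; labels are background (0)
    # plus segments 1-18, and lobe labels are background (0) plus lobes 1-5
    assert 0 <= data["lungsegment"].min() and data["lungsegment"].max() <= 18
    assert data["lobe"].max() <= 5
    print({k: (data[k].shape, data[k].dtype) for k in data.files})


check_training_sample("sample_data/training/pulse_00002.npz")
```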
24 |
25 | The sample test data is in `sample_data/inference`, which follows the same conventions as the training dataset. However, there are two major differences to note:
26 | - All masks and images in the test data should be in the original shape of the CT image.
27 | - An additional key `lung_bbox` is included in each `.npz` file, which gives the lung bounding box as a `3×2` matrix. For example:
28 | ```
29 | data["lung_bbox"]
30 | array([[  5, 170],
31 |        [176, 409],
32 |        [ 87, 442]], dtype=int32)
33 | ```
34 | means that the lung bounding box starts at `5, 176, 87` and ends at `170, 409, 442` (both bounds inclusive) in the x, y, z axes, respectively. A short code sketch showing how this bounding box is used appears after the Inference section below.
35 |
36 | Both the training and test folders should include a `.csv` file as well. For example:
37 | ```
38 | pid,subset
39 | pulse_00001,train
40 | pulse_00002,val
41 | ```
42 | This `.csv` file should include two columns: `pid` and `subset`. In the training `.csv` file, the subset column should be `train` or `val`. In the test `.csv` file, the subset column should be `val` or `test`.
43 |
44 | ## ImPulSe Training
45 | To train the proposed ImPulSe model, run:
46 | ```bash
47 | python train.py --cfg=ibav --data_dir=sample_data/training --df_path=sample_data/training/train.csv --log_dir=logs
48 | ```
49 | The `--data_dir` argument indicates the training data folder, and the `--df_path` argument indicates the training csv file path. The `--log_dir` argument indicates the logging directory, where the trained model weights are saved.
50 |
51 | ## ImPulSe Inference
52 | To run inference with the trained ImPulSe, run:
53 | ```bash
54 | python predict.py --cfg=ibav --data_dir=/data/directory --df_path=/data/info/path --weight_path=/path/to/trained/model --output_dir=/prediction/output/directory
55 | ```
56 | The `--data_dir` argument indicates the inference data folder, and the `--df_path` argument indicates the inference csv file path. The `--weight_path` argument indicates the trained model weight path. The `--output_dir` argument indicates where the output predictions are saved.
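As referenced above, here is a minimal sketch (not part of the repository) of how the inclusive `lung_bbox` maps a prediction on the cropped lung region back into full CT coordinates, mirroring the paste-back logic in `predict.py` and the cropping in `data/transforms.py`. The `pred_lung` array is a dummy placeholder standing in for a real prediction:
```python
import numpy as np

data = np.load("sample_data/inference/pulse_00002.npz")
bbox = data["lung_bbox"]  # shape (3, 2): inclusive [start, end] per axis

# placeholder for a prediction on the cropped lung region
lung_shape = tuple(bbox[:, 1] - bbox[:, 0] + 1)
pred_lung = np.zeros(lung_shape, dtype=np.uint8)

# paste the cropped prediction back into a full-size volume
y_pred = np.zeros(data["image"].shape, dtype=np.uint8)
y_pred[bbox[0, 0]:bbox[0, 1] + 1,
       bbox[1, 0]:bbox[1, 1] + 1,
       bbox[2, 0]:bbox[2, 1] + 1] = pred_lung
```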
57 | 58 | # Citation 59 | 60 | If you find our work useful, please consider citing as follows: 61 | ``` 62 | @inproceedings{Kuang2022WhatMF, 63 | title={What Makes for Automatic Reconstruction of Pulmonary Segments}, 64 | author={Kaiming Kuang and Li Zhang and Jingyu Li and Hongwei Li and Jiajun Chen and Bo Du and Jiancheng Yang}, 65 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 66 | year={2022} 67 | } 68 | ``` 69 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HINTLab/ImPulSe/c2f5ccd63651d2b7d5dd99a5f8870b176fed525b/configs/__init__.py -------------------------------------------------------------------------------- /configs/bav_config.py: -------------------------------------------------------------------------------- 1 | # preprocessing configs 2 | in_res = (128, 128, 128) 3 | out_res = (16, 16, 16) 4 | win_min, win_max = -1000, 400 5 | input_keys = ["airway", "artery", "vein"] 6 | resample_configs = { 7 | "image": {"size": in_res, "interp": "linear"}, 8 | "airway": {"size": in_res, "interp": "nearest"}, 9 | "artery": {"size": in_res, "interp": "nearest"}, 10 | "vein": {"size": in_res, "interp": "nearest"}, 11 | "lobe": {"size": in_res, "interp": "nearest"} 12 | } 13 | 14 | # model configs 15 | enc_cfgs = { 16 | "in_channels": 3 17 | } 18 | dec_cfgs = { 19 | "num_channels": [963, 64], 20 | "num_classes": 19, 21 | "num_layers": 1, 22 | "drop_prob": 0.3 23 | } 24 | 25 | # training configs 26 | w_ce = 0.5 27 | w_dice = 1 28 | batch_size = 8 29 | num_workers = 4 30 | max_lr = 1e-3 31 | min_lr = 1e-6 32 | epochs = 50 33 | 34 | # evaluation configs 35 | eval_freq = 5 36 | eval_res = (128, 128, 128) 37 | eval_batch_size = 4 38 | 39 | # inference configs 40 | weights_path = "/media/dntech/_mnt_storage/kaiming/data/lung_segment/logs/bav/model_44.pth" 41 | infer_batch_size = 1 42 | window_size = (96, 96, 96) 43 | -------------------------------------------------------------------------------- /configs/coord_config.py: -------------------------------------------------------------------------------- 1 | # preprocessing configs 2 | in_res = (128, 128, 128) 3 | out_res = in_res 4 | win_min, win_max = -1000, 400 5 | input_keys = ["grids"] 6 | resample_cfgs = { 7 | "image": {"size": out_res, "interp": "linear"}, 8 | "lungsegment": {"size": out_res, "interp": "nearest"}, 9 | } 10 | 11 | # model configs 12 | enc_cfgs = { 13 | "in_channels": 3 14 | } 15 | dec_cfgs = { 16 | "num_channels": [512, 256, 128, 64], 17 | "num_layers": 1, 18 | "num_classes": 19 19 | } 20 | 21 | # training configs 22 | batch_size = 2 23 | num_workers = 4 24 | max_lr = 1e-3 25 | min_lr = 1e-6 26 | epochs = 50 27 | w_ce = 0.5 28 | w_dice = 1 29 | 30 | # evaluation configs 31 | eval_freq = 5 32 | eval_res = (256, 256, 256) 33 | eval_batch_size = 1 34 | window_size = (128, 128, 128) 35 | weights_path = "/media/dntech/_mnt_storage/kaiming/data/lung_segment/logs/coord/model_49.pth" 36 | -------------------------------------------------------------------------------- /configs/deeplab_config.py: -------------------------------------------------------------------------------- 1 | # preprocessing configs 2 | in_res = (128, 128, 128) 3 | out_res = in_res 4 | win_min, win_max = -1000, 400 5 | input_keys = ["image"] 6 | resample_cfgs = { 7 | "image": {"size": out_res, "interp": "linear"}, 8 | "lungsegment": {"size": 
out_res, "interp": "nearest"}, 9 | } 10 | 11 | # model configs 12 | enc_cfgs = { 13 | "in_channels": 1 14 | } 15 | dec_cfgs = { 16 | "in_channels": 512, 17 | "atrous_rates": [2, 4, 6], 18 | "drop_prob": 0, 19 | "out_channels": 256, 20 | "num_classes": 19, 21 | "scale_factor": 8 22 | } 23 | 24 | # training configs 25 | batch_size = 4 26 | num_workers = 4 27 | max_lr = 1e-3 28 | min_lr = 1e-6 29 | epochs = 50 30 | w_ce = 0.5 31 | w_dice = 1 32 | 33 | # evaluation configs 34 | eval_freq = 5 35 | eval_res = (256, 256, 256) 36 | eval_batch_size = 1 37 | window_size = (128, 128, 128) 38 | weights_path = "/media/dntech/_mnt_storage/kaiming/data/lung_segment/logs/deeplab/model_44.pth" 39 | -------------------------------------------------------------------------------- /configs/fcn_config.py: -------------------------------------------------------------------------------- 1 | # preprocessing configs 2 | in_res = (128, 128, 128) 3 | out_res = in_res 4 | win_min, win_max = -1000, 400 5 | input_keys = ["image"] 6 | resample_cfgs = { 7 | "image": {"size": out_res, "interp": "linear"}, 8 | "lungsegment": {"size": out_res, "interp": "nearest"}, 9 | } 10 | 11 | # model configs 12 | enc_cfgs = { 13 | "in_channels": 1 14 | } 15 | dec_cfgs = { 16 | "in_channels": 512, 17 | "scale_factor": 8, 18 | "num_classes": 19, 19 | } 20 | 21 | # training configs 22 | batch_size = 4 23 | num_workers = 4 24 | max_lr = 1e-3 25 | min_lr = 1e-6 26 | epochs = 50 27 | w_ce = 0.5 28 | w_dice = 1 29 | 30 | # evaluation configs 31 | eval_freq = 5 32 | eval_res = (256, 256, 256) 33 | eval_batch_size = 1 34 | window_size = (128, 128, 128) 35 | weights_path = "/media/dntech/_mnt_storage/kaiming/data/lung_segment/logs/fcn/model_49.pth" 36 | -------------------------------------------------------------------------------- /configs/i_config.py: -------------------------------------------------------------------------------- 1 | # preprocessing configs 2 | in_res = (128, 128, 128) 3 | out_res = (16, 16, 16) 4 | win_min, win_max = -1000, 400 5 | input_keys = ["image"] 6 | resample_configs = { 7 | "image": {"size": in_res, "interp": "linear"}, 8 | "airway": {"size": in_res, "interp": "nearest"}, 9 | "artery": {"size": in_res, "interp": "nearest"}, 10 | "vein": {"size": in_res, "interp": "nearest"}, 11 | "lobe": {"size": in_res, "interp": "nearest"} 12 | } 13 | 14 | # model configs 15 | enc_cfgs = { 16 | "in_channels": 1 17 | } 18 | dec_cfgs = { 19 | "num_channels": [963, 64], 20 | "num_classes": 19, 21 | "num_layers": 1, 22 | "drop_prob": 0.3 23 | } 24 | 25 | # training configs 26 | w_ce = 0.5 27 | w_dice = 1 28 | batch_size = 8 29 | num_workers = 4 30 | max_lr = 1e-3 31 | min_lr = 1e-6 32 | epochs = 50 33 | 34 | # evaluation configs 35 | eval_freq = 5 36 | eval_res = (128, 128, 128) 37 | eval_batch_size = 4 38 | 39 | # inference configs 40 | weights_path = "/media/dntech/_mnt_storage/kaiming/data/lung_segment/logs/i/model_44.pth" 41 | infer_batch_size = 1 42 | window_size = (96, 96, 96) 43 | -------------------------------------------------------------------------------- /configs/ibav_config.py: -------------------------------------------------------------------------------- 1 | # preprocessing configs 2 | in_res = (128, 128, 128) 3 | out_res = (16, 16, 16) 4 | win_min, win_max = -1000, 400 5 | input_keys = ["airway", "artery", "vein", "image"] 6 | resample_configs = { 7 | "image": {"size": in_res, "interp": "linear"}, 8 | "airway": {"size": in_res, "interp": "nearest"}, 9 | "artery": {"size": in_res, "interp": "nearest"}, 10 | "vein": 
{"size": in_res, "interp": "nearest"}, 11 | "lobe": {"size": in_res, "interp": "nearest"} 12 | } 13 | 14 | # model configs 15 | enc_cfgs = { 16 | "in_channels": 4 17 | } 18 | dec_cfgs = { 19 | "num_channels": [963, 64], 20 | "num_classes": 19, 21 | "num_layers": 1, 22 | "drop_prob": 0.3 23 | } 24 | 25 | # training configs 26 | w_ce = 0.5 27 | w_dice = 1 28 | batch_size = 8 29 | num_workers = 4 30 | max_lr = 1e-3 31 | min_lr = 1e-6 32 | epochs = 50 33 | 34 | # evaluation configs 35 | eval_freq = 5 36 | eval_res = (128, 128, 128) 37 | eval_batch_size = 4 38 | 39 | # inference configs 40 | weights_path = "/media/dntech/_mnt_storage/kaiming/data/lung_segment/logs/ibav/model_49.pth" 41 | infer_batch_size = 1 42 | window_size = (96, 96, 96) 43 | -------------------------------------------------------------------------------- /configs/l_config.py: -------------------------------------------------------------------------------- 1 | # preprocessing configs 2 | in_res = (128, 128, 128) 3 | out_res = (16, 16, 16) 4 | win_min, win_max = -1000, 400 5 | input_keys = ["lobe"] 6 | resample_configs = { 7 | "image": {"size": in_res, "interp": "linear"}, 8 | "airway": {"size": in_res, "interp": "nearest"}, 9 | "artery": {"size": in_res, "interp": "nearest"}, 10 | "vein": {"size": in_res, "interp": "nearest"}, 11 | "lobe": {"size": in_res, "interp": "nearest"} 12 | } 13 | 14 | # model configs 15 | enc_cfgs = { 16 | "in_channels": 6 17 | } 18 | dec_cfgs = { 19 | "num_channels": [963, 64], 20 | "num_classes": 19, 21 | "num_layers": 1, 22 | "drop_prob": 0.3 23 | } 24 | 25 | # training configs 26 | w_ce = 0.5 27 | w_dice = 1 28 | batch_size = 8 29 | num_workers = 4 30 | max_lr = 1e-3 31 | min_lr = 1e-6 32 | epochs = 50 33 | 34 | # evaluation configs 35 | eval_freq = 5 36 | eval_res = (128, 128, 128) 37 | eval_batch_size = 4 38 | 39 | # inference configs 40 | weights_path = "/media/dntech/_mnt_storage/kaiming/data/lung_segment/logs/l/model_49.pth" 41 | infer_batch_size = 1 42 | window_size = (96, 96, 96) 43 | -------------------------------------------------------------------------------- /configs/lbav_config.py: -------------------------------------------------------------------------------- 1 | # preprocessing configs 2 | in_res = (128, 128, 128) 3 | out_res = (16, 16, 16) 4 | win_min, win_max = -1000, 400 5 | input_keys = ["airway", "artery", "vein", "lobe"] 6 | resample_configs = { 7 | "image": {"size": in_res, "interp": "linear"}, 8 | "airway": {"size": in_res, "interp": "nearest"}, 9 | "artery": {"size": in_res, "interp": "nearest"}, 10 | "vein": {"size": in_res, "interp": "nearest"}, 11 | "lobe": {"size": in_res, "interp": "nearest"} 12 | } 13 | 14 | # model configs 15 | enc_cfgs = { 16 | "in_channels": 9 17 | } 18 | dec_cfgs = { 19 | "num_channels": [963, 64], 20 | "num_classes": 19, 21 | "num_layers": 1, 22 | "drop_prob": 0.3 23 | } 24 | 25 | # training configs 26 | w_ce = 0.5 27 | w_dice = 1 28 | batch_size = 8 29 | num_workers = 4 30 | max_lr = 1e-3 31 | min_lr = 1e-6 32 | epochs = 50 33 | 34 | # evaluation configs 35 | eval_freq = 5 36 | eval_res = (128, 128, 128) 37 | eval_batch_size = 4 38 | 39 | # inference configs 40 | weights_path = "/media/dntech/_mnt_storage/kaiming/data/lung_segment/logs/lbav/model_44.pth" 41 | infer_batch_size = 1 42 | window_size = (96, 96, 96) 43 | -------------------------------------------------------------------------------- /configs/unet_config.py: -------------------------------------------------------------------------------- 1 | # preprocessing configs 2 | in_res = 
(128, 128, 128) 3 | out_res = in_res 4 | win_min, win_max = -1000, 400 5 | input_keys = ["image"] 6 | resample_cfgs = { 7 | "image": {"size": out_res, "interp": "linear"}, 8 | "lungsegment": {"size": out_res, "interp": "nearest"}, 9 | } 10 | 11 | # model configs 12 | enc_cfgs = { 13 | "in_channels": 1 14 | } 15 | dec_cfgs = { 16 | "num_channels": [512, 256, 128, 64], 17 | "num_layers": 1, 18 | "num_classes": 19 19 | } 20 | 21 | # training configs 22 | batch_size = 4 23 | num_workers = 4 24 | max_lr = 1e-3 25 | min_lr = 1e-6 26 | epochs = 50 27 | w_ce = 0.5 28 | w_dice = 1 29 | 30 | # evaluation configs 31 | eval_freq = 5 32 | eval_res = (256, 256, 256) 33 | eval_batch_size = 1 34 | window_size = (128, 128, 128) 35 | weights_path = "/media/dntech/_mnt_storage/kaiming/data/lung_segment/logs/unet/model_29.pth" 36 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HINTLab/ImPulSe/c2f5ccd63651d2b7d5dd99a5f8870b176fed525b/data/__init__.py -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader, Dataset 6 | 7 | 8 | def _collate_train(samples): 9 | inputs = torch.stack([x["inputs"] for x in samples]) 10 | targets = torch.cat([x["targets"] for x in samples]).squeeze_(dim=1) 11 | grids = torch.stack([x["grids"] for x in samples]) 12 | 13 | return inputs, targets, grids 14 | 15 | 16 | def _collate_infer(samples): 17 | inputs = torch.stack([x["inputs"] for x in samples]) 18 | grids = [x["grids"] for x in samples] 19 | pids = [x["pid"] for x in samples] 20 | bboxes = np.stack([x["lung_bbox"] for x in samples]) 21 | shapes = [x["raw_res"] for x in samples] 22 | 23 | return inputs, grids, pids, bboxes, shapes 24 | 25 | 26 | def _collate_unet(samples): 27 | inputs = torch.stack([x["inputs"] for x in samples]) 28 | targets = torch.stack([x["targets"] for x in samples]) 29 | 30 | return inputs, targets 31 | 32 | 33 | def _collate_unet_infer(samples): 34 | inputs = torch.stack([x["inputs"] for x in samples]) 35 | targets = torch.stack([x["targets"] for x in samples]) 36 | pids = [x["pid"] for x in samples] 37 | bboxes = np.stack([x["lung_bbox"] for x in samples]) 38 | shapes = [x["raw_res"] for x in samples] 39 | 40 | return inputs, targets, pids, bboxes, shapes 41 | 42 | 43 | class LungSegmentDataset(Dataset): 44 | 45 | def __init__(self, df, data_dir, transforms, subset): 46 | self.df = df.loc[df.subset == subset].reset_index(drop=True) 47 | self.data_dir = data_dir 48 | self.transforms = transforms 49 | 50 | def __len__(self): 51 | return len(self.df) 52 | 53 | def _apply_transforms(self, data): 54 | for t in self.transforms: 55 | data = t(data) 56 | 57 | return data 58 | 59 | def __getitem__(self, idx): 60 | data_path = os.path.join(self.data_dir, f"{self.df.pid[idx]}.npz") 61 | data = np.load(data_path) 62 | data = {k: v for k, v in data.items()} 63 | data["pid"] = self.df.pid[idx] 64 | 65 | data = self._apply_transforms(data) 66 | 67 | return data 68 | 69 | @staticmethod 70 | def get_dataloader(dataset, batch_size, shuffle=False, num_workers=0, 71 | mode="train"): 72 | if mode == "train": 73 | return DataLoader(dataset, batch_size, shuffle, 74 | num_workers=num_workers, 75 | collate_fn=_collate_train) 76 | elif mode 
== "infer": 77 | return DataLoader(dataset, batch_size, shuffle, 78 | num_workers=num_workers, 79 | collate_fn=_collate_infer) 80 | elif mode == "unet": 81 | return DataLoader(dataset, batch_size, shuffle, 82 | num_workers=num_workers, 83 | collate_fn=_collate_unet) 84 | elif mode == "unet_infer": 85 | return DataLoader(dataset, batch_size, shuffle, 86 | num_workers=num_workers, 87 | collate_fn=_collate_unet_infer) 88 | else: 89 | raise ValueError(f"Unrecognized mode {mode}.") 90 | -------------------------------------------------------------------------------- /data/transforms.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | import SimpleITK as sitk 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.distributions.uniform import Uniform 8 | 9 | 10 | INTERPOLATIONS = { 11 | "nearest": sitk.sitkNearestNeighbor, 12 | "linear": sitk.sitkLinear, 13 | "bspline": sitk.sitkBSpline, 14 | } 15 | 16 | 17 | def _resample(image, target_spacing, target_shape, interpolation): 18 | # set up resampling parameters 19 | resampler = sitk.ResampleImageFilter() 20 | resampler.SetInterpolator(INTERPOLATIONS[interpolation]) 21 | resampler.SetOutputSpacing(target_spacing) 22 | resampler.SetSize(target_shape) 23 | resampler.SetOutputOrigin(image.GetOrigin()) 24 | resampler.SetOutputDirection(image.GetDirection()) 25 | 26 | # execute the resampling 27 | image = resampler.Execute(image) 28 | 29 | return image 30 | 31 | 32 | def _resample_to_shape(arr, target_shape, interpolation): 33 | # convert np.ndarray to sitk.Image 34 | original_type = arr.dtype 35 | image = sitk.GetImageFromArray(arr.astype(float)) 36 | 37 | # calculate the target spacing, assuming the original spacing is 1x1x1 38 | target_spacing = tuple([arr.shape[i] / target_shape[i] 39 | for i in range(len(target_shape))]) 40 | 41 | # reverse spacing and shape to xyz format 42 | target_spacing = tuple(reversed(target_spacing)) 43 | target_shape = tuple(reversed(target_shape)) 44 | 45 | # resampling 46 | image = _resample(image, target_spacing, target_shape, interpolation) 47 | 48 | # convert sitk.Image back to np.ndarray 49 | new_arr = sitk.GetArrayFromImage(image).astype(original_type) 50 | 51 | return new_arr 52 | 53 | 54 | class CropLung: 55 | 56 | def __call__(self, data): 57 | bbox = data["lung_bbox"] 58 | data["raw_res"] = data["image"].shape 59 | data["raw_lungsegment"] = deepcopy(data["lungsegment"]) 60 | data["raw_airway"] = deepcopy(data["airway"]) 61 | data["raw_artery"] = deepcopy(data["artery"]) 62 | data["raw_vein"] = deepcopy(data["vein"]) 63 | keys = ["image", "airway", "artery", "vein", "lungsegment"] 64 | for k in keys: 65 | data[k] = data[k][ 66 | bbox[0, 0]:bbox[0, 1] + 1, 67 | bbox[1, 0]:bbox[1, 1] + 1, 68 | bbox[2, 0]:bbox[2, 1] + 1 69 | ] 70 | 71 | return data 72 | 73 | 74 | class GetLobe: 75 | 76 | def __init__(self): 77 | self.lobe_mapping = { 78 | (1, 3): 1, 79 | (4, 5): 2, 80 | (6, 9): 3, 81 | (10, 13): 4, 82 | (14, 17): 5 83 | } 84 | 85 | def __call__(self, data): 86 | if "lobe" not in data: 87 | data["lobe"] = np.zeros_like(data["lungsegment"]) 88 | for seg_rng, lobe_idx in self.lobe_mapping.items(): 89 | data["lobe"][ 90 | np.logical_and(data["lungsegment"] >= seg_rng[0], 91 | data["lungsegment"] <= seg_rng[1]) 92 | ] = lobe_idx 93 | 94 | return data 95 | 96 | 97 | class Resample: 98 | 99 | def __init__(self, configs): 100 | self.configs = configs 101 | 102 | def __call__(self, data): 103 | data["original_res"] = 
data["image"].shape 104 | for k in self.configs.keys(): 105 | if k in ["airway", "artery"]: 106 | data[f"original_{k}"] = deepcopy(data[k]) 107 | data[k] = _resample_to_shape(data[k], self.configs[k]["size"], 108 | self.configs[k]["interp"]) 109 | 110 | return data 111 | 112 | 113 | class ClipValue: 114 | 115 | def __init__(self, min_val, max_val): 116 | self.min_val = min_val 117 | self.max_val = max_val 118 | 119 | def __call__(self, data): 120 | data["image"] = np.clip(data["image"], self.min_val, self.max_val) 121 | 122 | return data 123 | 124 | 125 | class MinMaxNormalize: 126 | 127 | def __init__(self, min_val, max_val): 128 | self.min_val = min_val 129 | self.max_val = max_val 130 | 131 | def __call__(self, data): 132 | data["image"] = (data["image"] - self.min_val)\ 133 | / (self.max_val - self.min_val) 134 | 135 | return data 136 | 137 | 138 | class OnehotEncode: 139 | 140 | def __init__(self, key, num_classes): 141 | self.key = key 142 | self.num_classes = num_classes 143 | 144 | def __call__(self, data): 145 | d, h, w = data[self.key].shape 146 | flattened = data[self.key].reshape(-1) 147 | onehot_flattened = np.eye(self.num_classes)[flattened] 148 | data[self.key] = onehot_flattened.reshape((self.num_classes, d, h, w)) 149 | 150 | return data 151 | 152 | 153 | class ConcatInputs: 154 | 155 | def __init__(self, keys): 156 | self.keys = keys 157 | 158 | def __call__(self, data): 159 | if "grids" in self.keys: 160 | data["grids"] = data["grids"].permute(3, 0, 1, 2) 161 | data["inputs"] = np.concatenate([data[k] if data[k].ndim == 4 162 | else data[k][np.newaxis] for k in self.keys]) 163 | 164 | return data 165 | 166 | 167 | class SampleGrid: 168 | 169 | def __init__(self, resolution=None, mode="regular"): 170 | self.resolution = resolution 171 | if resolution is None: 172 | self.reg_grid = None 173 | else: 174 | axes = [torch.linspace(-1, 1, resolution[i]) for i in range(3)] 175 | self.reg_grid = torch.stack(torch.meshgrid(axes[2], axes[1], 176 | axes[0], indexing="ij"), -1) 177 | 178 | assert mode in ["regular", "perturbed", "random", "weighted"] 179 | self.mode = mode 180 | if self.mode == "perturbed": 181 | self.sampler = Uniform(-(2 / resolution), 2 / resolution) 182 | elif self.mode == "random": 183 | self.sampler = Uniform(-1, 1) 184 | elif self.mode == "weighted": 185 | self.random_sampler = Uniform(-1, 1) 186 | 187 | def __call__(self, data): 188 | if self.resolution is None: 189 | resolution = data["lungsegment"].shape 190 | axes = [torch.linspace(-1, 1, resolution[i]) for i in range(3)] 191 | self.reg_grid = torch.stack(torch.meshgrid(axes[2], axes[1], 192 | axes[0], indexing="ij"), -1) 193 | if self.mode == "regular": 194 | data["grids"] = self.reg_grid 195 | elif self.mode == "perturbed": 196 | pertubations = self.sampler.sample(self.reg_grid.size()) 197 | data["grids"] = self.reg_grid + pertubations 198 | elif self.mode == "weighted": 199 | n_pts = self.reg_grid.nelement() // 3 200 | rnd_coords = self.random_sampler.sample((n_pts // 2, 3)) 201 | lungseg_mask = torch.from_numpy(data["lungsegment"] > 0)\ 202 | .permute(2, 1, 0).reshape(-1) 203 | axes = [torch.linspace(-1, 1, data["lungsegment"].shape[i]) 204 | for i in range(3)] 205 | full_grid = torch.stack(torch.meshgrid(axes[2], axes[1], 206 | axes[0], indexing="ij"), -1) 207 | lungseg_coords = full_grid.view(-1, 3)[lungseg_mask, :] 208 | lungseg_coords = lungseg_coords[torch.randint(0, 209 | lungseg_coords.size(0), (n_pts // 2, ))] 210 | data["grids"] = torch.cat((rnd_coords, lungseg_coords), dim=0) 211 | d, h, w = 
self.resolution 212 | data["grids"] = data["grids"].view(w, h, d, 3) 213 | else: 214 | data["grids"] = self.sampler.sample(self.reg_grid.size()) 215 | 216 | return data 217 | 218 | 219 | class SampleTarget: 220 | 221 | def __init__(self, key, interpolation="nearest"): 222 | self.key = key 223 | self.interpolation = interpolation 224 | 225 | def __call__(self, data): 226 | data[self.key] = torch.from_numpy(data[self.key]).float() 227 | if data["grids"].size()[-2:0:-1] == data[self.key].size(): 228 | data["targets"] = data[self.key][None, None] 229 | else: 230 | data["grids"] = torch.flip(data["grids"], dims=(-1,)) 231 | data["targets"] = F.grid_sample(data[self.key][None, None], 232 | data["grids"][None], self.interpolation, align_corners=True) 233 | # data["targets"] = data["targets"].permute(0, 1, 4, 3, 2) 234 | data["grids"] = torch.flip(data["grids"], dims=(-1,)) 235 | 236 | return data 237 | 238 | 239 | class ToTensor: 240 | 241 | def __call__(self, data): 242 | data["inputs"] = torch.from_numpy(data["inputs"]).float() 243 | if data.get("targets") is not None: 244 | data["targets"] = data["targets"].long() 245 | else: 246 | data["targets"] = torch.from_numpy(data["lungsegment"]).long() 247 | 248 | return data 249 | 250 | 251 | class SetLobeTarget: 252 | 253 | def __call__(self, data): 254 | data["targets"] = torch.from_numpy(data["lobe"]) 255 | 256 | return data 257 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from dntk import medimg 7 | from tqdm import tqdm 8 | 9 | from utils.metrics import foreground_dice_score 10 | 11 | 12 | def _parse_cmd_args(): 13 | arg_parser = ArgumentParser() 14 | arg_parser.add_argument("--gt_dir", required=True, 15 | help="Ground-truth directory.") 16 | arg_parser.add_argument("--pred_dir", required=True, 17 | help="Prediction directory.") 18 | arg_parser.add_argument("--df_path", required=True, 19 | help="Data info csv path.") 20 | args = arg_parser.parse_args() 21 | 22 | return args 23 | 24 | 25 | def _evaluate(pid, gt_dir, pred_dir): 26 | gt_path = os.path.join(gt_dir, pid, f"{pid}_lungsegment.nii.gz") 27 | pred_path = os.path.join(pred_dir, f"{pid}_pred.nii.gz") 28 | airway_path = os.path.join(gt_dir, pid, f"{pid}_airway.nii.gz") 29 | artery_path = os.path.join(gt_dir, pid, f"{pid}_artery.nii.gz") 30 | vein_path = os.path.join(gt_dir, pid, f"{pid}_vein.nii.gz") 31 | interseg_path = os.path.join(gt_dir, pid, f"{pid}_interseg.nii.gz") 32 | gt = medimg.read_image(gt_path).array 33 | pred = medimg.read_image(pred_path).array 34 | airway = medimg.read_image(airway_path).array 35 | artery = medimg.read_image(artery_path).array 36 | vein = medimg.read_image(vein_path).array 37 | interseg_vein = medimg.read_image(interseg_path).array 38 | intraseg_vein = np.logical_and(vein > 0, interseg_vein == 0) 39 | 40 | total_dice = foreground_dice_score(gt, pred, 18) 41 | airway_dice = foreground_dice_score(gt[airway > 0], pred[airway > 0], 18) 42 | artery_dice = foreground_dice_score(gt[artery > 0], pred[artery > 0], 18) 43 | vein_dice = foreground_dice_score(gt[vein > 0], pred[vein > 0], 18) 44 | interseg_vein_dice = foreground_dice_score(gt[interseg_vein > 0], 45 | pred[interseg_vein > 0], 18) 46 | intraseg_vein_dice = foreground_dice_score(gt[intraseg_vein], 47 | pred[intraseg_vein], 18) 48 | result = { 49 | "pid": pid, 50 | 
"total": total_dice, 51 | "artery": artery_dice, 52 | "airway": airway_dice, 53 | "vein": vein_dice, 54 | "interseg_vein": interseg_vein_dice, 55 | "intraseg_vein": intraseg_vein_dice 56 | } 57 | 58 | return result 59 | 60 | 61 | def main(): 62 | args = _parse_cmd_args() 63 | 64 | info = pd.read_csv(args.df_path) 65 | info_val = info.loc[info.subset == "val"].reset_index(drop=True) 66 | info_test = info.loc[info.subset == "test"].reset_index(drop=True) 67 | 68 | cols = [ 69 | "total", 70 | "airway", 71 | "artery", 72 | "vein", 73 | "interseg_vein", 74 | "intraseg_vein" 75 | ] 76 | 77 | result_val = [] 78 | pid_val = sorted(info_val.pid.tolist()) 79 | for pid in tqdm(pid_val): 80 | result_val.append(_evaluate(pid, args.gt_dir, 81 | os.path.join(args.pred_dir, "val"))) 82 | result_val = pd.DataFrame(result_val) 83 | print(result_val[cols].mean(axis=0)) 84 | result_val.to_csv(os.path.join(args.pred_dir, "results_val.csv"), 85 | index=False) 86 | 87 | result_test = [] 88 | pid_test = sorted(info_test.pid.tolist()) 89 | for pid in tqdm(pid_test): 90 | result_test.append(_evaluate(pid, args.gt_dir, 91 | os.path.join(args.pred_dir, "test"))) 92 | result_test = pd.DataFrame(result_test) 93 | print(result_test[cols].mean(axis=0)) 94 | result_test.to_csv(os.path.join(args.pred_dir, "results_test.csv"), 95 | index=False) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /figures/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HINTLab/ImPulSe/c2f5ccd63651d2b7d5dd99a5f8870b176fed525b/figures/network.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HINTLab/ImPulSe/c2f5ccd63651d2b7d5dd99a5f8870b176fed525b/models/__init__.py -------------------------------------------------------------------------------- /models/conv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class ConvLayer3d(nn.Sequential): 6 | 7 | def __init__(self, in_channels, out_channels, kernel, bias=False, 8 | norm=nn.InstanceNorm3d, actv=nn.LeakyReLU, drop_prob=0): 9 | layers = [ 10 | nn.Conv3d(in_channels, out_channels, kernel, padding=kernel // 2, 11 | bias=bias) 12 | ] 13 | if norm is not None: 14 | layers.append(norm(out_channels)) 15 | if actv is not None: 16 | layers.append(actv()) 17 | if drop_prob > 0: 18 | layers.append(nn.Dropout(drop_prob)) 19 | super().__init__(*layers) 20 | 21 | 22 | class ConvBlock3d(nn.Module): 23 | 24 | def __init__(self, in_channels, out_channels, kernel, norm, num_layers, 25 | drop_prob, upsample=False, residual=True): 26 | super().__init__() 27 | 28 | conv_layers = [ConvLayer3d(in_channels, out_channels, kernel, 29 | norm=norm, drop_prob=drop_prob) if i == 0 else 30 | ConvLayer3d(out_channels, out_channels, kernel, 31 | norm=norm, drop_prob=drop_prob) 32 | for i in range(num_layers)] 33 | self.conv_layers = nn.Sequential(*conv_layers) 34 | self.residual = residual 35 | if self.residual: 36 | self.res_layer = ConvLayer3d(in_channels, out_channels, 1, 37 | norm=norm, actv=None, drop_prob=drop_prob) 38 | self.upsample = upsample 39 | 40 | def forward(self, x): 41 | output = self.conv_layers(x) 42 | 43 | if self.residual: 44 | res = self.res_layer(x) 45 | output = 
res + output 46 | 47 | if self.upsample: 48 | output = F.interpolate(output, scale_factor=2, mode="trilinear", 49 | align_corners=True) 50 | 51 | return output 52 | -------------------------------------------------------------------------------- /models/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class _ASPPConv3d(nn.Sequential): 7 | 8 | def __init__(self, in_channels, out_channels, dilation): 9 | modules = [ 10 | nn.Conv3d(in_channels, out_channels, 3, padding=dilation, 11 | dilation=dilation, bias=False), 12 | nn.BatchNorm3d(out_channels), 13 | nn.LeakyReLU(inplace=True) 14 | ] 15 | super().__init__(*modules) 16 | 17 | 18 | class _ASPPPooling3d(nn.Sequential): 19 | 20 | def __init__(self, in_channels, out_channels): 21 | super(_ASPPPooling3d, self).__init__( 22 | nn.AdaptiveAvgPool3d(2), 23 | nn.Conv3d(in_channels, out_channels, 1, bias=False), 24 | nn.BatchNorm3d(out_channels), 25 | nn.LeakyReLU(inplace=True)) 26 | 27 | def forward(self, x): 28 | size = x.shape[-3:] 29 | for mod in self: 30 | x = mod(x) 31 | return F.interpolate(x, size=size, mode='trilinear', 32 | align_corners=False) 33 | 34 | 35 | class _ASPP3d(nn.Module): 36 | 37 | def __init__(self, in_channels, atrous_rates, drop_prob, out_channels): 38 | super().__init__() 39 | modules = [] 40 | modules.append(nn.Sequential( 41 | nn.Conv3d(in_channels, out_channels, 1, bias=False), 42 | nn.BatchNorm3d(out_channels), 43 | nn.LeakyReLU(inplace=True))) 44 | 45 | rates = tuple(atrous_rates) 46 | for rate in rates: 47 | modules.append(_ASPPConv3d(in_channels, out_channels, rate)) 48 | 49 | modules.append(_ASPPPooling3d(in_channels, out_channels)) 50 | 51 | self.convs = nn.ModuleList(modules) 52 | 53 | self.project = nn.Sequential( 54 | nn.Conv3d(len(self.convs) * out_channels, out_channels, 1, 55 | bias=False), 56 | nn.BatchNorm3d(out_channels), 57 | nn.LeakyReLU(inplace=True), 58 | nn.Dropout(drop_prob) 59 | ) 60 | 61 | def forward(self, x): 62 | res = [] 63 | for conv in self.convs: 64 | res.append(conv(x)) 65 | res = torch.cat(res, dim=1) 66 | 67 | return self.project(res) 68 | 69 | 70 | class DeepLabDecoder3d(nn.Sequential): 71 | 72 | def __init__(self, in_channels, atrous_rates, drop_prob, out_channels, 73 | num_classes, scale_factor): 74 | layers = [ 75 | _ASPP3d(in_channels, atrous_rates, drop_prob, out_channels), 76 | nn.Conv3d(out_channels, num_classes, 1), 77 | nn.Upsample(scale_factor=scale_factor, mode="trilinear", 78 | align_corners=True) 79 | ] 80 | 81 | super().__init__(*layers) 82 | 83 | 84 | class DeepLab3d(nn.Module): 85 | 86 | def __init__(self, encoder, decoder): 87 | super().__init__() 88 | 89 | self.encoder = encoder 90 | self.decoder = decoder 91 | 92 | def forward(self, x): 93 | features = self.encoder(x) 94 | output = self.decoder(features[-1]) 95 | 96 | return output 97 | -------------------------------------------------------------------------------- /models/fcn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FCNDecoder3d(nn.Sequential): 5 | 6 | def __init__(self, in_channels, scale_factor, num_classes): 7 | layers = [ 8 | nn.Conv3d(in_channels, num_classes, 1), 9 | nn.ConvTranspose3d(num_classes, num_classes, scale_factor, 10 | scale_factor) 11 | ] 12 | 13 | super().__init__(*layers) 14 | 15 | 16 | class FCN3d(nn.Module): 17 | 18 | def __init__(self, encoder, decoder): 19 | 
super().__init__()
20 |
21 |         self.encoder = encoder
22 |         self.decoder = decoder
23 |
24 |     def forward(self, x):
25 |         features = self.encoder(x)
26 |         output = self.decoder(features[-1])
27 |
28 |         return output
29 |
--------------------------------------------------------------------------------
/models/implicit_autoencoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .conv import ConvBlock3d, ConvLayer3d
6 |
7 |
8 | class ImplicitEncoder(nn.Module):
9 |
10 |     def __init__(self, in_channels, num_channels, num_layers):
11 |         super().__init__()
12 |
13 |         self.stem = ConvLayer3d(in_channels, num_channels[0], 7, True)
14 |         self.conv_blocks = nn.ModuleList([ConvBlock3d(num_channels[i],
15 |             num_channels[i + 1], 3, nn.InstanceNorm3d, num_layers, 0)
16 |             for i in range(len(num_channels) - 1)])
17 |         self.downsample = nn.MaxPool3d(2)
18 |
19 |     def forward(self, x):
20 |         in_feature = self.stem(x)
21 |
22 |         features = [in_feature]
23 |         for i in range(len(self.conv_blocks)):
24 |             out_feature = self.conv_blocks[i](in_feature)
25 |             in_feature = self.downsample(out_feature)
26 |             features.append(in_feature)
27 |
28 |         return features
29 |
30 |
31 | class _LocalNorm3d(nn.Module):
32 |
33 |     def __init__(self, num_features):
34 |         super().__init__()
35 |         self.norm = nn.LayerNorm(num_features)
36 |
37 |     def forward(self, x):
38 |         x = x.permute(0, 2, 3, 4, 1)
39 |         output = self.norm(x)
40 |         output = output.permute(0, 4, 1, 2, 3)
41 |
42 |         return output.contiguous()
43 |
44 |
45 | class ImplicitDecoder(nn.Module):
46 |
47 |     def __init__(self, num_channels, num_classes, num_layers, drop_prob):
48 |         super().__init__()
49 |         # MLP made up of consecutive 1x1x1 conv layers
50 |         self.conv_blocks = nn.ModuleList([ConvBlock3d(num_channels[i],
51 |             num_channels[i + 1], 1, _LocalNorm3d, num_layers, drop_prob)
52 |             for i in range(len(num_channels) - 1)])
53 |         self.output_layer = ConvLayer3d(num_channels[-1], num_classes, 1,
54 |             True, None, None, drop_prob)
55 |
56 |     def forward(self, x):
57 |         for i in range(len(self.conv_blocks)):
58 |             x = self.conv_blocks[i](x)
59 |
60 |         output = self.output_layer(x)
61 |
62 |         return output
63 |
64 |
65 | class ImplicitAutoEncoder(nn.Module):
66 |
67 |     def __init__(self, encoder, decoder):
68 |         super().__init__()
69 |
70 |         self.encoder = encoder
71 |         self.decoder = decoder
72 |
73 |     def make_point_encoding(self, features, grids):
74 |         # interpolate features at continuous locations and concatenate
75 |         # zyx -> xyz
76 |         grids = torch.flip(grids, dims=(-1,))
77 |         point_encodings = [F.grid_sample(features[i], grids, "bilinear",
78 |             align_corners=True) for i in range(len(features))]
79 |         # flip back: xyz -> zyx
80 |         grids = torch.flip(grids, dims=(-1,))
81 |         point_encodings.insert(0, grids.permute(0, 4, 1, 2, 3))
82 |         point_encodings = torch.cat(point_encodings, dim=1)
83 |
84 |         return point_encodings
85 |
86 |     def forward(self, x, grids):
87 |         # encode spatial features
88 |         features = self.encoder(x)
89 |
90 |         # calculate point encodings with features and coordinates
91 |         point_encodings = self.make_point_encoding(features, grids)
92 |
93 |         # decode point encodings into per-point class logits
94 |         output = self.decoder(point_encodings)
95 |
96 |         return output
97 |
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
import torch.nn.functional as F
4 |
5 |
6 | class _TotalVarLoss3d(nn.Module):
7 |
8 |     def _calculate_variation(self, x, dim):
9 |         assert dim in (2, 3, 4)
10 |
11 |         length = x.size(dim)
12 |         x0 = torch.index_select(x, dim, torch.arange(length - 1, device=x.device))
13 |         x1 = torch.index_select(x, dim, torch.arange(1, length, device=x.device))
14 |
15 |         return x0, x1
16 |
17 |     def forward(self, output):
18 |         x0, x1 = self._calculate_variation(output, 2)
19 |         y0, y1 = self._calculate_variation(output, 3)
20 |         z0, z1 = self._calculate_variation(output, 4)
21 |         total_var = F.l1_loss(x0, x1) + F.l1_loss(y0, y1)\
22 |             + F.l1_loss(z0, z1)
23 |
24 |         return total_var
25 |
26 |
27 | class JointLoss(nn.Module):
28 |
29 |     def __init__(self, w_ce, w_var):
30 |         super().__init__()
31 |
32 |         self.w_ce = w_ce
33 |         self.w_var = w_var
34 |         self.ce_fn = nn.CrossEntropyLoss()
35 |         self.var_fn = _TotalVarLoss3d()
36 |
37 |     def forward(self, output, target):
38 |         ce_loss = self.ce_fn(output, target)
39 |         var_loss = self.var_fn(output)
40 |         joint_loss = self.w_ce * ce_loss + self.w_var * var_loss
41 |
42 |         return joint_loss
43 |
44 |
45 | class DiceLoss(nn.Module):
46 |
47 |     def __init__(self, num_classes):
48 |         super().__init__()
49 |
50 |         self.num_classes = num_classes
51 |
52 |     def forward(self, output, target):
53 |         eps = 1e-8
54 |         fg_output = torch.softmax(output, dim=1)[:, 1:, ...]
55 |         fg_target = F.one_hot(target, self.num_classes)[..., 1:]
56 |         fg_target = fg_target.permute(0, 4, 1, 2, 3)
57 |         dice_loss = 1 - 2 * (fg_output * fg_target).sum()\
58 |             / (fg_output.sum() + fg_target.sum() + eps)
59 |
60 |         return dice_loss
61 |
62 |
63 | class SegLoss(nn.Module):
64 |
65 |     def __init__(self, w_ce, w_dice, num_classes):
66 |         super().__init__()
67 |
68 |         self.w_ce = w_ce
69 |         self.w_dice = w_dice
70 |         self.ce_fn = nn.CrossEntropyLoss()
71 |         self.dice_fn = DiceLoss(num_classes)
72 |
73 |     def forward(self, output, target):
74 |         ce_loss = self.ce_fn(output, target)
75 |         dice_loss = self.dice_fn(output, target)
76 |         total_loss = self.w_ce * ce_loss + self.w_dice * dice_loss
77 |
78 |         return total_loss
79 |
--------------------------------------------------------------------------------
/models/resnet18.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.utils.checkpoint import checkpoint
3 | from torchvision.models.video import r3d_18
4 |
5 |
6 | class ResNet3d18Backbone(nn.Module):
7 |
8 |     def __init__(self, in_channels):
9 |         super().__init__()
10 |         layers = r3d_18()
11 |
12 |         for name, layer in layers.named_children():
13 |             if name == "stem":
14 |                 self.add_module(name, nn.Sequential(
15 |                     nn.Conv3d(in_channels, 64, 7, 1, 3, bias=False),
16 |                     nn.BatchNorm3d(64),
17 |                     nn.ReLU()
18 |                 ))
19 |             elif name == "fc":
20 |                 break
21 |             else:
22 |                 self.add_module(name, layer)
23 |
24 |         self.num_features = self.layer4[-1].conv2[0].out_channels
25 |         self.memory_efficient = False
26 |
27 |     def forward(self, x):
28 |         if self.memory_efficient:
29 |             features = self._forward_efficient(x)
30 |         else:
31 |             features = self._forward_fast(x)
32 |
33 |         return features
34 |
35 |     def _forward_efficient(self, x):
36 |         x = self.stem(x)
37 |
38 |         feature_0 = checkpoint(self.layer1, x)
39 |         feature_1 = checkpoint(self.layer2, feature_0)
40 |         feature_2 = checkpoint(self.layer3, feature_1)
41 |         feature_3 = self.layer4(feature_2)
42 |         features = [feature_0, feature_1, feature_2, feature_3]
43 |
44 |         return features
45 |
46 |     def _forward_fast(self, x):
47 |         x = self.stem(x)
48 |
49 |         feature_0 = self.layer1(x)
50 | feature_1 = self.layer2(feature_0) 51 | feature_2 = self.layer3(feature_1) 52 | feature_3 = self.layer4(feature_2) 53 | features = [feature_0, feature_1, feature_2, feature_3] 54 | 55 | return features 56 | 57 | 58 | if __name__ == "__main__": 59 | import torch 60 | 61 | 62 | model = ResNet3d18Backbone(1) 63 | inputs = torch.rand(1, 1, 128, 128, 128) 64 | features = model(inputs) 65 | print(features[-1].size()) 66 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .conv import ConvBlock3d, ConvLayer3d 5 | 6 | 7 | class UNetDecoder(nn.Module): 8 | 9 | def __init__(self, num_channels, num_layers, num_classes): 10 | super().__init__() 11 | 12 | self.conv_blocks = nn.ModuleList() 13 | 14 | for i in range(len(num_channels) - 1): 15 | in_channels = num_channels[i] if i == 0 else num_channels[i] * 2 16 | out_channels = num_channels[i + 1] 17 | self.conv_blocks.append( 18 | ConvBlock3d(in_channels, out_channels, 3, 19 | nn.BatchNorm3d, num_layers, 0, True, False) 20 | ) 21 | 22 | self.output_layer = ConvLayer3d(num_channels[-1] * 2, num_classes, 1, 23 | True, nn.BatchNorm3d, None) 24 | 25 | def forward(self, features): 26 | in_features = features[-1] 27 | for i in range(len(self.conv_blocks)): 28 | out_features = self.conv_blocks[i](in_features) 29 | in_features = torch.cat((out_features, features[-(i + 2)]), dim=1) 30 | 31 | output = self.output_layer(in_features) 32 | 33 | return output 34 | 35 | 36 | class UNet(nn.Module): 37 | 38 | def __init__(self, encoder, decoder): 39 | super().__init__() 40 | 41 | self.encoder = encoder 42 | self.decoder = decoder 43 | 44 | def forward(self, x): 45 | features = self.encoder(x) 46 | output = self.decoder(features) 47 | 48 | return output 49 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from argparse import ArgumentParser 4 | from importlib import import_module 5 | 6 | import SimpleITK as sitk 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | import torch.nn as nn 11 | from tqdm import tqdm 12 | 13 | from data.dataset import LungSegmentDataset 14 | from data import transforms as aug 15 | from models.implicit_autoencoder import ImplicitAutoEncoder, ImplicitDecoder 16 | from models.resnet18 import ResNet3d18Backbone 17 | from utils.logger import logger 18 | 19 | 20 | def _parse_cmd_args(): 21 | arg_parser = ArgumentParser() 22 | arg_parser.add_argument("--gpu", default="0", help="GPU ID.") 23 | arg_parser.add_argument("--cfg", required=True, 24 | help="Python config module.") 25 | arg_parser.add_argument("--data_dir", required=True, 26 | help="Data directory") 27 | arg_parser.add_argument("--df_path", required=True, 28 | help="Data info csv path.") 29 | arg_parser.add_argument("--weight_path", required=True, 30 | help="Train model weight path.") 31 | arg_parser.add_argument("--output_dir", required=True, 32 | help="Prediction output directory.") 33 | args = arg_parser.parse_args() 34 | 35 | return args 36 | 37 | 38 | def _set_rng_seed(seed): 39 | random.seed(seed) 40 | np.random.seed(seed) 41 | torch.manual_seed(seed) 42 | 43 | 44 | def _init_dataloaders(args): 45 | data_dir = args.data_dir 46 | df = pd.read_csv(args.df_path) 47 | 48 | transforms = [ 49 | aug.CropLung(), 50 | 
aug.GetLobe(), 51 | aug.Resample(cfg.resample_configs), 52 | aug.OnehotEncode("lobe", 6), 53 | aug.MinMaxNormalize(cfg.win_min, cfg.win_max), 54 | aug.SampleGrid(mode="regular"), 55 | aug.SampleTarget("lungsegment"), 56 | aug.ConcatInputs(cfg.input_keys), 57 | aug.ToTensor() 58 | ] 59 | 60 | ds_val = LungSegmentDataset(df, data_dir, transforms, "val") 61 | dl_val = LungSegmentDataset.get_dataloader(ds_val, cfg.infer_batch_size, 62 | False, cfg.num_workers, "infer") 63 | ds_test = LungSegmentDataset(df, data_dir, transforms, "test") 64 | dl_test = LungSegmentDataset.get_dataloader(ds_test, cfg.infer_batch_size, 65 | False, cfg.num_workers, "infer") 66 | 67 | return dl_val, dl_test 68 | 69 | 70 | def _init_model(args): 71 | encoder = ResNet3d18Backbone(**cfg.enc_cfgs) 72 | decoder = ImplicitDecoder(**cfg.dec_cfgs) 73 | model = ImplicitAutoEncoder(encoder, decoder) 74 | 75 | devices = [torch.device(f"cuda:{i}") for i in args.gpu.split(",")] 76 | if len(devices) > 1: 77 | model = nn.DataParallel(model.cuda(), devices) 78 | else: 79 | model = model.cuda() 80 | 81 | return model 82 | 83 | 84 | def _calculate_bboxes(image_shape, crop_size): 85 | steps = [np.arange(0, image_shape[i], crop_size[i]).tolist() 86 | + [image_shape[i]] for i in range(3)] 87 | begs = [steps[i][:-1] for i in range(3)] 88 | ends = [steps[i][1:] for i in range(3)] 89 | bboxes = [] 90 | for i in range(len(begs[0])): 91 | for j in range(len(begs[1])): 92 | for k in range(len(begs[2])): 93 | bboxes.append(np.array([ 94 | [begs[0][i], ends[0][i]], 95 | [begs[1][j], ends[1][j]], 96 | [begs[2][k], ends[2][k]], 97 | ])) 98 | bboxes = np.stack(bboxes) 99 | 100 | return bboxes 101 | 102 | 103 | def _sliding_window_predict(model, inputs, grids, window_size): 104 | resolution = grids.size()[-2:0:-1] 105 | bboxes = _calculate_bboxes(resolution, window_size) 106 | output = np.zeros(resolution, dtype=np.uint8) 107 | features = model.encoder(inputs) 108 | 109 | for i in range(bboxes.shape[0]): 110 | grid_patch = grids[ 111 | :, 112 | bboxes[i, 2, 0]:bboxes[i, 2, 1], 113 | bboxes[i, 1, 0]:bboxes[i, 1, 1], 114 | bboxes[i, 0, 0]:bboxes[i, 0, 1], 115 | :].cuda() 116 | point_encodings = model.make_point_encoding(features, grid_patch) 117 | output_patch = model.decoder(point_encodings) 118 | output_patch = output_patch.cpu().numpy().squeeze(axis=0) 119 | output_patch = output_patch.argmax(axis=0).astype(np.uint8) 120 | output[ 121 | bboxes[i, 0, 0]:bboxes[i, 0, 1], 122 | bboxes[i, 1, 0]:bboxes[i, 1, 1], 123 | bboxes[i, 2, 0]:bboxes[i, 2, 1] 124 | ] = output_patch.transpose((2, 1, 0)) 125 | 126 | return output 127 | 128 | 129 | @logger 130 | @torch.no_grad() 131 | def _predict(model, dataloader, output_dir, subset): 132 | model.eval() 133 | progress = tqdm(total=len(dataloader)) 134 | os.makedirs(os.path.join(output_dir, subset), exist_ok=True) 135 | 136 | for i, sample in enumerate(dataloader): 137 | inputs, grids, pids, bboxes, shapes = sample 138 | pid = pids[0] 139 | bbox = bboxes[0] 140 | grids = grids[0][None] 141 | original_shape = shapes[0] 142 | inputs = inputs.cuda() 143 | 144 | y_pred_lung = _sliding_window_predict(model, inputs, grids, 145 | cfg.window_size) 146 | y_pred = np.zeros(original_shape, dtype=np.uint8) 147 | y_pred[ 148 | bbox[0, 0]:bbox[0, 1] + 1, 149 | bbox[1, 0]:bbox[1, 1] + 1, 150 | bbox[2, 0]:bbox[2, 1] + 1 151 | ] = y_pred_lung 152 | 153 | y_pred_img = sitk.GetImageFromArray(y_pred.astype(np.uint8)) 154 | sitk.WriteImage(y_pred_img, os.path.join(output_dir, subset, 155 | f"{pid}_pred.nii.gz")) 156 | 157 | 
progress.update()
158 |
159 |     progress.close()
160 |
161 |
162 | def _load_weights(weight_path):
163 |     model_weights = torch.load(weight_path)
164 |     new_model_weights = {}
165 |     for k in model_weights.keys():
166 |         # strip the DataParallel "module." prefix; keep other keys as-is
167 |         new_k = k[7:] if k.startswith("module.") else k
168 |         new_model_weights[new_k] = model_weights[k]
169 |
170 |     return new_model_weights
171 |
172 |
173 | def main():
174 |     _set_rng_seed(42)
175 |
176 |     args = _parse_cmd_args()
177 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
178 |     global cfg
179 |     cfg = import_module(f"configs.{args.cfg}_config")
180 |
181 |     dl_val, dl_test = _init_dataloaders(args)
182 |
183 |     model = _init_model(args)
184 |     model_weights = _load_weights(args.weight_path)
185 |     model.load_state_dict(model_weights)
186 |
187 |     _predict(model, dl_val, args.output_dir, "val")
188 |     _predict(model, dl_test, args.output_dir, "test")
189 |
190 |
191 | if __name__ == "__main__":
192 |     main()
193 |
--------------------------------------------------------------------------------
/predict_unet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from argparse import ArgumentParser
4 | from importlib import import_module
5 |
6 | import SimpleITK as sitk
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | import torch.nn as nn
11 | from tqdm import tqdm
12 |
13 | from data.dataset import LungSegmentDataset
14 | from data import transforms as aug
15 | from models.deeplab import DeepLab3d, DeepLabDecoder3d
16 | from models.fcn import FCN3d, FCNDecoder3d
17 | from models.resnet18 import ResNet3d18Backbone
18 | from models.unet import UNet, UNetDecoder
19 | from utils.logger import logger
20 |
21 |
22 | def _parse_cmd_args():
23 |     arg_parser = ArgumentParser()
24 |     arg_parser.add_argument("--gpu", default="0,1,2,3", help="GPU ID.")
25 |     arg_parser.add_argument("--cfg", required=True,
26 |                             help="Python config module.")
27 |     arg_parser.add_argument("--data_dir", required=True,
28 |                             help="Data directory")
29 |     arg_parser.add_argument("--df_path", required=True,
30 |                             help="Data info csv path.")
31 |     arg_parser.add_argument("--weight_path", required=True,
32 |                             help="Train model weight path.")
33 |     arg_parser.add_argument("--output_dir", required=True,
34 |                             help="Prediction output directory.")
35 |     args = arg_parser.parse_args()
36 |
37 |     return args
38 |
39 |
40 | def _set_rng_seed(seed):
41 |     random.seed(seed)
42 |     np.random.seed(seed)
43 |     torch.manual_seed(seed)
44 |
45 |
46 | def _init_dataloaders(args):
47 |     data_dir = args.data_dir
48 |     df = pd.read_csv(args.df_path)
49 |
50 |     transforms = [
51 |         aug.CropLung(),
52 |         aug.Resample(cfg.resample_cfgs),
53 |         aug.SampleGrid(cfg.in_res, "regular"),
54 |         aug.MinMaxNormalize(cfg.win_min, cfg.win_max),
55 |         aug.ConcatInputs(cfg.input_keys),
56 |         aug.ToTensor()
57 |     ]
58 |
59 |     ds_val = LungSegmentDataset(df, data_dir, transforms, "val")
60 |     dl_val = LungSegmentDataset.get_dataloader(ds_val, cfg.eval_batch_size,
61 |         False, cfg.num_workers, "unet_infer")
62 |     ds_test = LungSegmentDataset(df, data_dir, transforms, "test")
63 |     dl_test = LungSegmentDataset.get_dataloader(ds_test, cfg.eval_batch_size,
64 |         False, cfg.num_workers, "unet_infer")
65 |
66 |     return dl_val, dl_test
67 |
68 |
69 | def _init_model(args):
70 |     encoder = ResNet3d18Backbone(**cfg.enc_cfgs)
71 |
72 |     if args.cfg in ["unet", "coord"]:
73 |         decoder = UNetDecoder(**cfg.dec_cfgs)
74 |         model = UNet(encoder, decoder)
75 |     elif args.cfg == "fcn":
76 |         decoder =
--------------------------------------------------------------------------------
/predict_unet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from argparse import ArgumentParser
4 | from importlib import import_module
5 |
6 | import SimpleITK as sitk
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | import torch.nn as nn
11 | from tqdm import tqdm
12 |
13 | from data.dataset import LungSegmentDataset
14 | from data import transforms as aug
15 | from models.deeplab import DeepLab3d, DeepLabDecoder3d
16 | from models.fcn import FCN3d, FCNDecoder3d
17 | from models.resnet18 import ResNet3d18Backbone
18 | from models.unet import UNet, UNetDecoder
19 | from utils.logger import logger
20 |
21 |
22 | def _parse_cmd_args():
23 |     arg_parser = ArgumentParser()
24 |     arg_parser.add_argument("--gpu", default="0,1,2,3", help="GPU ID.")
25 |     arg_parser.add_argument("--cfg", required=True,
26 |                             help="Python config module.")
27 |     arg_parser.add_argument("--data_dir", required=True,
28 |                             help="Data directory.")
29 |     arg_parser.add_argument("--df_path", required=True,
30 |                             help="Data info csv path.")
31 |     arg_parser.add_argument("--weight_path", required=True,
32 |                             help="Trained model weight path.")
33 |     arg_parser.add_argument("--output_dir", required=True,
34 |                             help="Prediction output directory.")
35 |     args = arg_parser.parse_args()
36 |
37 |     return args
38 |
39 |
40 | def _set_rng_seed(seed):
41 |     random.seed(seed)
42 |     np.random.seed(seed)
43 |     torch.manual_seed(seed)
44 |
45 |
46 | def _init_dataloaders(args):
47 |     data_dir = args.data_dir
48 |     df = pd.read_csv(args.df_path)
49 |
50 |     transforms = [
51 |         aug.CropLung(),
52 |         aug.Resample(cfg.resample_cfgs),
53 |         aug.SampleGrid(cfg.in_res, "regular"),
54 |         aug.MinMaxNormalize(cfg.win_min, cfg.win_max),
55 |         aug.ConcatInputs(cfg.input_keys),
56 |         aug.ToTensor()
57 |     ]
58 |
59 |     ds_val = LungSegmentDataset(df, data_dir, transforms, "val")
60 |     dl_val = LungSegmentDataset.get_dataloader(ds_val, cfg.eval_batch_size,
61 |                                                False, cfg.num_workers, "unet_infer")
62 |     ds_test = LungSegmentDataset(df, data_dir, transforms, "test")
63 |     dl_test = LungSegmentDataset.get_dataloader(ds_test, cfg.eval_batch_size,
64 |                                                 False, cfg.num_workers, "unet_infer")
65 |
66 |     return dl_val, dl_test
67 |
68 |
69 | def _init_model(args):
70 |     encoder = ResNet3d18Backbone(**cfg.enc_cfgs)
71 |
72 |     if args.cfg in ["unet", "coord"]:
73 |         decoder = UNetDecoder(**cfg.dec_cfgs)
74 |         model = UNet(encoder, decoder)
75 |     elif args.cfg == "fcn":
76 |         decoder = FCNDecoder3d(**cfg.dec_cfgs)
77 |         model = FCN3d(encoder, decoder)
78 |     elif args.cfg == "deeplab":
79 |         decoder = DeepLabDecoder3d(**cfg.dec_cfgs)
80 |         model = DeepLab3d(encoder, decoder)
81 |
82 |     devices = [torch.device(f"cuda:{i}") for i in args.gpu.split(",")]
83 |     if len(devices) > 1:
84 |         model = nn.DataParallel(model.cuda(), devices)
85 |     else:
86 |         model = model.cuda()
87 |
88 |     return model
89 |
90 |
91 | @logger
92 | @torch.no_grad()
93 | def _predict(model, dataloader, output_dir, subset, interpolation):
94 |     model.eval()
95 |     progress = tqdm(total=len(dataloader))
96 |     os.makedirs(os.path.join(output_dir, subset), exist_ok=True)
97 |
98 |     for i, sample in enumerate(dataloader):
99 |         inputs, _, pids, bboxes, original_shapes = sample
100 |         inputs = inputs.cuda()
101 |
102 |         y_prob_lung = model(inputs).cpu().numpy()
103 |         n_classes = y_prob_lung.shape[1]
104 |         y_pred_lung = np.argmax(y_prob_lung, axis=1).astype(np.uint8)
105 |         batch_size = y_pred_lung.shape[0]
106 |         for j in range(batch_size):
107 |             y_pred = np.zeros(original_shapes[j], dtype=np.uint8)
108 |             lung_shape = tuple((np.diff(bboxes[j], axis=-1) + 1).squeeze())
109 |             lung_shape = tuple([int(x) for x in lung_shape])
110 |             d, h, w = y_pred_lung[j].shape
111 |             y_pred_lung_onehot = np.eye(n_classes)[y_pred_lung[j].reshape(-1)]  # one-hot so labels resample per class
112 |             y_pred_lung_onehot = y_pred_lung_onehot.reshape((d, h,
113 |                 w, n_classes)).transpose((3, 0, 1, 2))
114 |             pred_lung = np.argmax(np.stack([aug._resample_to_shape(
115 |                 y_pred_lung_onehot[c], lung_shape, interpolation)
116 |                 for c in range(n_classes)]), axis=0)  # argmax over smoothly resampled class maps
117 |             y_pred[
118 |                 bboxes[j, 0, 0]:bboxes[j, 0, 1] + 1,
119 |                 bboxes[j, 1, 0]:bboxes[j, 1, 1] + 1,
120 |                 bboxes[j, 2, 0]:bboxes[j, 2, 1] + 1
121 |             ] = pred_lung
122 |
123 |             y_pred_img = sitk.GetImageFromArray(y_pred)
124 |             sitk.WriteImage(y_pred_img, os.path.join(output_dir, subset,
125 |                                                      f"{pids[j]}_pred.nii.gz"))
126 |
127 |         progress.update()
128 |
129 |     progress.close()
130 |
131 |
132 | def _load_weights(weight_path):
133 |     raw_weights = torch.load(weight_path)
134 |     weights = {}
135 |     for k in raw_weights.keys():
136 |         new_k = k.replace("module.", "")
137 |         weights[new_k] = raw_weights[k]
138 |
139 |     return weights
140 |
141 |
142 | def main():
143 |     _set_rng_seed(42)
144 |
145 |     args = _parse_cmd_args()
146 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
147 |     global cfg
148 |     cfg = import_module(f"configs.{args.cfg}_config")
149 |
150 |     dl_val, dl_test = _init_dataloaders(args)
151 |
152 |     model = _init_model(args)
153 |     model_weights = _load_weights(args.weight_path)
154 |     model.load_state_dict(model_weights)
155 |
156 |     os.makedirs(args.output_dir, exist_ok=True)
157 |     interp = "linear"
158 |     _predict(model, dl_val, args.output_dir, "val", interp)
159 |     _predict(model, dl_test, args.output_dir, "test", interp)
160 |
161 |
162 | if __name__ == "__main__":
163 |     main()
164 |
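The `_predict` above resizes hard labels by one-hot encoding them, resampling each class channel separately, and taking an argmax: interpolating integer label IDs directly would invent meaningless in-between classes. A standalone sketch of that trick, using `scipy.ndimage.zoom` as a stand-in for the repository's `aug._resample_to_shape` (which is not shown in this excerpt):

```python
import numpy as np
from scipy.ndimage import zoom

def resize_labels(labels, target_shape, n_classes):
    # One-hot the hard labels, resize each class channel smoothly, then argmax.
    onehot = np.eye(n_classes)[labels.reshape(-1)]
    onehot = onehot.reshape(labels.shape + (n_classes,)).transpose(3, 0, 1, 2)
    factors = [t / s for t, s in zip(target_shape, labels.shape)]
    resized = np.stack([zoom(onehot[c], factors, order=1)
                        for c in range(n_classes)])
    return resized.argmax(axis=0).astype(np.uint8)

labels = np.random.randint(0, 4, size=(8, 8, 8))    # toy label volume
print(resize_labels(labels, (11, 13, 9), 4).shape)  # -> (11, 13, 9)
```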
--------------------------------------------------------------------------------
/sample_data/inference/pulse_00002.npz:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:5aeab67ce9ea838cd477f065776da2d3d1524702e13c1cfb5f38e9d07ab57cf3
3 | size 445908704
4 |
--------------------------------------------------------------------------------
/sample_data/inference/test.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:0a912f5c3f198bcd3e8a8e0dc7c937e2b7908888a5b6035d8cd0b7da8163f772
3 | size 43
4 |
--------------------------------------------------------------------------------
/sample_data/training/pulse_00002.npz:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:0209b481a8cbab6689ceeed79e9979aa8e0846cde652e334ca624d0b98a0d5bc
3 | size 26232373
4 |
--------------------------------------------------------------------------------
/sample_data/training/train.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1a331778ffb01945c1e4deb11b69760857957399ea6e64d3ee46d053a18dc19a
3 | size 44
4 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from argparse import ArgumentParser
4 | from datetime import datetime
5 | from importlib import import_module
6 | from time import perf_counter
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | import torch.nn as nn
12 | from torch import optim
13 |
14 | from data.dataset import LungSegmentDataset
15 | from data import transforms as aug
16 | from models.implicit_autoencoder import ImplicitAutoEncoder, ImplicitDecoder
17 | from models.losses import SegLoss
18 | from models.resnet18 import ResNet3d18Backbone
19 | from utils.logger import logger
20 | from utils.metrics import foreground_dice_score
21 |
22 |
23 | def _parse_cmd_args():
24 |     arg_parser = ArgumentParser()
25 |     arg_parser.add_argument("--gpu", default="0,1,2,3", help="GPU ID.")
26 |     arg_parser.add_argument("--cfg", required=True,
27 |                             help="Python config module.")
28 |     arg_parser.add_argument("--data_dir", required=True,
29 |                             help="Data directory.")
30 |     arg_parser.add_argument("--df_path", required=True,
31 |                             help="Data info csv path.")
32 |     arg_parser.add_argument("--log_dir", required=True,
33 |                             help="Tensorboard log directory.")
34 |     args = arg_parser.parse_args()
35 |
36 |     return args
37 |
38 |
39 | def _set_rng_seed(seed):
40 |     random.seed(seed)
41 |     np.random.seed(seed)
42 |     torch.manual_seed(seed)
43 |
44 |
45 | def _init_dataloaders(args):
46 |     data_dir = args.data_dir
47 |     df = pd.read_csv(args.df_path)
48 |
49 |     transforms_train = [
50 |         aug.OnehotEncode("lobe", 6),
51 |         aug.MinMaxNormalize(cfg.win_min, cfg.win_max),
52 |         aug.SampleGrid(cfg.out_res, "random"),
53 |         aug.SampleTarget("lungsegment"),
54 |         aug.ConcatInputs(cfg.input_keys),
55 |         aug.ToTensor()
56 |     ]
57 |     transforms_val = [
58 |         aug.OnehotEncode("lobe", 6),
59 |         aug.MinMaxNormalize(cfg.win_min, cfg.win_max),
60 |         aug.SampleGrid(cfg.eval_res, "regular"),
61 |         aug.SampleTarget("lungsegment"),
62 |         aug.ConcatInputs(cfg.input_keys),
63 |         aug.ToTensor()
64 |     ]
65 |     ds_train = LungSegmentDataset(df, data_dir, transforms_train, "train")
66 |     dl_train = LungSegmentDataset.get_dataloader(ds_train, cfg.batch_size,
67 |                                                  True, cfg.num_workers)
68 |     ds_val = LungSegmentDataset(df, data_dir, transforms_val, "val")
69 |     dl_val = LungSegmentDataset.get_dataloader(ds_val, cfg.eval_batch_size,
70 |                                                False, cfg.num_workers)
71 |
72 |     return dl_train, dl_val
73 |
74 |
75 | def _init_model(args):
76 |     encoder = ResNet3d18Backbone(**cfg.enc_cfgs)
77 |     decoder = ImplicitDecoder(**cfg.dec_cfgs)
78 |     model = ImplicitAutoEncoder(encoder, decoder)
79 |
80 |     devices = [int(x) for x in args.gpu.split(",")]
81 |     if len(devices) > 1:
82 |         model = nn.DataParallel(model.cuda(), devices)
83 |     else:
84 |         model = model.cuda()
85 |
86 |     return model
87 |
88 |
89 | @logger
90 | def _train_epoch(model, dataloader, criterion, optimizer, scheduler):
91 |     model.train()
92 |     loss_train = 0
93 |     fg_acc_train = 0
94 |
95 |     for _, sample in enumerate(dataloader):
96 |         optimizer.zero_grad()
97 |
98 |         inputs, targets, grids = sample
99 |         inputs = inputs.cuda()
100 |         targets = targets.cuda()
101 |         grids = grids.cuda()
102 |         outputs = model(inputs, grids)
103 |         loss = criterion(outputs, targets)
104 |
105 |         loss.backward()
106 |         optimizer.step()
107 |         scheduler.step()  # per-iteration LR update (see the note below)
108 |
109 |         with torch.no_grad():
110 |             loss_train += loss.cpu().item()
111 |             y_true = targets.cpu().numpy().reshape(-1)
112 |             y_pred = np.argmax(outputs.cpu().numpy(), axis=1).reshape(-1)
113 |             fg_acc_train += (y_true[y_true > 0] == y_pred[y_true > 0]).mean()  # accuracy over foreground voxels only
114 |
115 |     loss_train /= len(dataloader)
116 |     fg_acc_train /= len(dataloader)
117 |
118 |     results = {
119 |         "loss": loss_train,
120 |         "fg_accuracy": fg_acc_train
121 |     }
122 |
123 |     return results
124 |
125 |
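Note that `_train_epoch` steps the scheduler after every optimizer step, not once per epoch; accordingly, `main()` below sizes `CosineAnnealingLR` with `T_max = len(dl_train) * cfg.epochs`, so the learning rate decays smoothly from `max_lr` to `min_lr` over the whole run. A toy illustration with invented numbers (not taken from the configs):

```python
import torch
from torch import nn, optim

# T_max is the total number of optimizer steps, not the number of epochs.
steps_per_epoch, epochs = 100, 10
params = [nn.Parameter(torch.zeros(1))]
opt = optim.AdamW(params, lr=1e-3)
sched = optim.lr_scheduler.CosineAnnealingLR(opt, steps_per_epoch * epochs, 1e-5)
for _ in range(steps_per_epoch * epochs):
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])  # ~1e-5 after the final step
```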
126 | @logger
127 | @torch.no_grad()
128 | @torch.cuda.amp.autocast()
129 | def _eval_epoch(model, dataloader, criterion):
130 |     torch.cuda.empty_cache()
131 |     model.eval()
132 |     loss_val = 0
133 |     fg_acc_val = 0
134 |     fg_dice_val = 0
135 |
136 |     for _, sample in enumerate(dataloader):
137 |         inputs, targets, grids = sample
138 |         inputs = inputs.cuda()
139 |         targets = targets.cuda()
140 |         grids = grids.cuda()
141 |         outputs = model(inputs, grids)
142 |         loss = criterion(outputs, targets)
143 |
144 |         loss_val += loss.cpu().item()
145 |         y_true = targets.cpu().numpy().reshape(-1)
146 |         y_pred = np.argmax(outputs.cpu().numpy(), axis=1).reshape(-1)
147 |         fg_acc_val += (y_true[y_true > 0] == y_pred[y_true > 0]).mean()
148 |         fg_dice_val += foreground_dice_score(y_true, y_pred, 18)  # Dice over the 18 foreground segments
149 |
150 |     loss_val /= len(dataloader)
151 |     fg_acc_val /= len(dataloader)
152 |     fg_dice_val /= len(dataloader)
153 |
154 |     results = {
155 |         "loss": loss_val,
156 |         "fg_accuracy": fg_acc_val,
157 |         "fg_dice": fg_dice_val,
158 |     }
159 |
160 |     return results
161 |
162 |
163 | def _log_metrics(results_train, results_val):
164 |     metrics = {"train": results_train, "val": results_val}
165 |     metrics = pd.DataFrame(metrics)
166 |     print(metrics)
167 |
168 |
169 | def _log_tensorboard(tb_writer, epoch, results_train, results_val):
170 |     for k in results_train.keys():
171 |         tb_writer.add_scalars(k, {"train": results_train[k],
172 |                                   "val": results_val[k]}, epoch)
173 |     for k in results_val.keys():
174 |         if "dice" in k:
175 |             tb_writer.add_scalar(k, results_val[k], epoch)
176 |
177 |     tb_writer.flush()
178 |
179 |
180 | def main():
181 |     _set_rng_seed(42)
182 |
183 |     args = _parse_cmd_args()
184 |     torch.cuda.set_device(int(args.gpu.split(",")[0]))
185 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
186 |     print(args.cfg)
187 |     global cfg
188 |     cfg = import_module(f"configs.{args.cfg}_config")
189 |
190 |     dl_train, dl_val = _init_dataloaders(args)
191 |
192 |     model = _init_model(args)
193 |     criterion = SegLoss(cfg.w_ce, cfg.w_dice, 19)  # 19 = background + 18 pulmonary segments
194 |     optimizer = optim.AdamW(model.parameters(), cfg.max_lr)
195 |     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
196 |                                                      len(dl_train) * cfg.epochs, cfg.min_lr)
197 |
198 |     # set up logging
199 |     log_dir = args.log_dir
200 |     cur_time = datetime.now().strftime("%Y%m%d-%H%M%S")
201 |     print(cur_time)
202 |     log_dir = os.path.join(log_dir, cur_time)
203 |     os.makedirs(log_dir)
204 |     time_train = 0
205 |
206 |     for i in range(cfg.epochs):
207 |         print(f"Epoch {i}")
208 |         epoch_start = perf_counter()
209 |         res_train = _train_epoch(model, dl_train, criterion,
210 |                                  optimizer, scheduler)
211 |         time_train += (perf_counter() - epoch_start)
212 |
213 |         if (i + 1) % cfg.eval_freq == 0:
214 |             res_val = _eval_epoch(model, dl_val, criterion)
215 |             _log_metrics(res_train, res_val)
216 |
217 |             torch.save(model.state_dict(), os.path.join(log_dir,
218 |                                                         f"model_{i + 1}.pth"))
219 |
220 |     print(f"Total training time: {time_train:.4f}")
221 |
222 |
223 | if __name__ == "__main__":
224 |     main()
225 |
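`models/losses.py` is not part of this excerpt. Judging only from the call `SegLoss(cfg.w_ce, cfg.w_dice, 19)` and the integer-label targets above, it plausibly combines weighted cross-entropy with a soft Dice term. A hypothetical minimal sketch, not the repository's actual implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SegLossSketch(nn.Module):
    # Hypothetical reimplementation: w_ce * CE + w_dice * (1 - soft Dice).
    def __init__(self, w_ce, w_dice, num_classes, eps=1e-8):
        super().__init__()
        self.w_ce, self.w_dice = w_ce, w_dice
        self.num_classes, self.eps = num_classes, eps

    def forward(self, logits, targets):
        # logits: (N, C, ...) raw scores; targets: (N, ...) int64 labels.
        ce = F.cross_entropy(logits, targets)
        probs = F.softmax(logits, dim=1)
        onehot = F.one_hot(targets, self.num_classes).movedim(-1, 1).float()
        dims = tuple(range(2, logits.ndim))
        inter = (probs * onehot).sum(dims)
        denom = probs.sum(dims) + onehot.sum(dims)
        dice = ((2 * inter + self.eps) / (denom + self.eps)).mean()
        return self.w_ce * ce + self.w_dice * (1 - dice)
```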
--------------------------------------------------------------------------------
/train_unet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from argparse import ArgumentParser
4 | from datetime import datetime
5 | from importlib import import_module
6 | from time import perf_counter
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | import torch.nn as nn
12 | from torch import optim
13 | from torch.utils.tensorboard import SummaryWriter
14 |
15 | from data.dataset import LungSegmentDataset
16 | from data import transforms as aug
17 | from models.deeplab import DeepLab3d, DeepLabDecoder3d
18 | from models.fcn import FCN3d, FCNDecoder3d
19 | from models.losses import SegLoss
20 | from models.resnet18 import ResNet3d18Backbone
21 | from models.unet import UNet, UNetDecoder
22 | from utils.logger import logger
23 | from utils.metrics import foreground_dice_score
24 |
25 |
26 | def _parse_cmd_args():
27 |     arg_parser = ArgumentParser()
28 |     arg_parser.add_argument("--gpu", default="0,1,2,3", help="GPU ID.")
29 |     arg_parser.add_argument("--cfg", required=True,
30 |                             help="Python config module.")
31 |     arg_parser.add_argument("--data_dir", required=True,
32 |                             help="Data directory.")
33 |     arg_parser.add_argument("--df_path", required=True,
34 |                             help="Data info csv path.")
35 |     arg_parser.add_argument("--log_dir", required=True,
36 |                             help="Tensorboard log directory.")
37 |     args = arg_parser.parse_args()
38 |
39 |     return args
40 |
41 |
42 | def _set_rng_seed(seed):
43 |     random.seed(seed)
44 |     np.random.seed(seed)
45 |     torch.manual_seed(seed)
46 |
47 |
48 | def _init_dataloaders(args):
49 |     data_dir = args.data_dir
50 |     df = pd.read_csv(args.df_path)
51 |
52 |     transforms_train = [
53 |         aug.Resample(cfg.resample_cfgs),
54 |         aug.SampleGrid(cfg.in_res, "regular"),
55 |         aug.MinMaxNormalize(cfg.win_min, cfg.win_max),
56 |         aug.ConcatInputs(cfg.input_keys),
57 |         aug.ToTensor()
58 |     ]
59 |     transforms_val = [
60 |         aug.Resample(cfg.resample_cfgs),
61 |         aug.SampleGrid(cfg.in_res, "regular"),
62 |         aug.MinMaxNormalize(cfg.win_min, cfg.win_max),
63 |         aug.ConcatInputs(cfg.input_keys),
64 |         aug.ToTensor()
65 |     ]
66 |     ds_train = LungSegmentDataset(df, data_dir, transforms_train, "train")
67 |     dl_train = LungSegmentDataset.get_dataloader(ds_train, cfg.batch_size,
68 |                                                  True, cfg.num_workers, "unet")
69 |     ds_val = LungSegmentDataset(df, data_dir, transforms_val, "val")
70 |     dl_val = LungSegmentDataset.get_dataloader(ds_val, cfg.batch_size, False,
71 |                                                cfg.num_workers, "unet")
72 |
73 |     return dl_train, dl_val
74 |
75 |
76 | def _init_model(args):
77 |     encoder = ResNet3d18Backbone(**cfg.enc_cfgs)
78 |
79 |     if args.cfg in ["unet", "coord"]:
80 |         decoder = UNetDecoder(**cfg.dec_cfgs)
81 |         model = UNet(encoder, decoder)
82 |     elif args.cfg == "fcn":
83 |         decoder = FCNDecoder3d(**cfg.dec_cfgs)
84 |         model = FCN3d(encoder, decoder)
85 |     elif args.cfg == "deeplab":
86 |         decoder = DeepLabDecoder3d(**cfg.dec_cfgs)
87 |         model = DeepLab3d(encoder, decoder)
88 |
89 |     devices = [int(x) for x in args.gpu.split(",")]
90 |     if len(devices) > 1:
91 |         model = nn.DataParallel(model.cuda(), devices)
92 |     else:
93 |         model = model.cuda()
94 |
95 |     return model
96 |
97 |
98 | @logger
99 | def _train_epoch(model, dataloader, criterion, optimizer, scheduler):
100 |     model.train()
101 |     loss_train = 0
102 |     dice_train = 0
103 |
104 |     for i, sample in enumerate(dataloader):
105 |         optimizer.zero_grad()
106 |
107 |         inputs, targets = sample
108 |         inputs = inputs.cuda()
109 |         targets = targets.cuda()
110 |         outputs = model(inputs)
111 |         loss = criterion(outputs, targets)
112 |
113 |         loss.backward()
114 |         optimizer.step()
115 |         scheduler.step()
116 |
117 |         with torch.no_grad():
118 |             loss_train += loss.detach().cpu().item()
119 |             y_true = targets.detach().cpu().numpy()
120 |             y_pred = outputs.detach().argmax(dim=1).cpu().numpy()
121 |             dice_train += foreground_dice_score(y_true, y_pred,
122 |                                                 cfg.dec_cfgs["num_classes"])
123 |
124 |     loss_train /= len(dataloader)
125 |     dice_train /= len(dataloader)
126 |
127 |     results = {
128 |         "loss": loss_train,
129 |         "dice": dice_train
130 |     }
131 |
132 |     return results
133 |
134 |
135 | @logger
136 | @torch.no_grad()
137 | def _eval_epoch(model, dataloader, criterion):
138 |     model.eval()
139 |     loss_val = 0
140 |     dice_val = 0
141 |
142 |     for i, sample in enumerate(dataloader):
143 |         inputs, targets = sample
144 |         inputs = inputs.cuda()
145 |         targets = targets.cuda()
146 |         outputs = model(inputs)
147 |         loss = criterion(outputs, targets)
148 |
149 |         loss_val += loss.detach().cpu().item()
150 |         y_true = targets.cpu().numpy()
151 |         y_pred = outputs.argmax(dim=1).cpu().numpy()
152 |         dice_val += foreground_dice_score(y_true, y_pred,
153 |                                           cfg.dec_cfgs["num_classes"])
154 |
155 |     loss_val /= len(dataloader)
156 |     dice_val /= len(dataloader)
157 |
158 |     results = {
159 |         "loss": loss_val,
160 |         "dice": dice_val
161 |     }
162 |
163 |     return results
164 |
165 |
166 | def _log_metrics(results_train, results_val):
167 |     metrics = {"train": results_train, "val": results_val}
168 |     metrics = pd.DataFrame(metrics)
169 |     print(metrics)
170 |
171 |
172 | def _log_tensorboard(tb_writer, epoch, results_train, results_val):
173 |     for k in results_train.keys():
174 |         tb_writer.add_scalars(k, {"train": results_train[k],
175 |                                   "val": results_val[k]}, epoch)
176 |
177 |     tb_writer.flush()
178 |
179 |
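This script imports `SummaryWriter` and defines `_log_tensorboard`, but `main()` below never instantiates a writer, so no TensorBoard events are written. If curves are wanted, a minimal sketch of how the writer could be used (the log directory and metric values are invented; it assumes `_log_tensorboard` above is in scope):

```python
from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter("logs/demo")  # hypothetical run directory
_log_tensorboard(tb_writer, epoch=0,
                 results_train={"loss": 0.9, "dice": 0.1},
                 results_val={"loss": 1.0, "dice": 0.1})
tb_writer.close()
```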
180 | def main():
181 |     _set_rng_seed(42)
182 |
183 |     args = _parse_cmd_args()
184 |     torch.cuda.set_device(int(args.gpu.split(",")[0]))
185 |     os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
186 |     global cfg
187 |     cfg = import_module(f"configs.{args.cfg}_config")
188 |     print(args.cfg)
189 |
190 |     dl_train, dl_val = _init_dataloaders(args)
191 |
192 |     model = _init_model(args)
193 |     criterion = SegLoss(cfg.w_ce, cfg.w_dice, cfg.dec_cfgs["num_classes"])
194 |     optimizer = optim.AdamW(model.parameters(), cfg.max_lr)
195 |     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
196 |                                                      len(dl_train) * cfg.epochs, cfg.min_lr)
197 |
198 |     # set up tensorboard
199 |     log_dir = args.log_dir
200 |     cur_time = datetime.now().strftime("%Y%m%d-%H%M%S")
201 |     print(cur_time)
202 |     log_dir = os.path.join(log_dir, cur_time)
203 |     os.makedirs(log_dir)  # create the run directory before any checkpoint is saved
204 |     time_train = 0
205 |
206 |     for i in range(cfg.epochs):
207 |         print(f"Epoch {i}")
208 |         epoch_start = perf_counter()
209 |         res_train = _train_epoch(model, dl_train, criterion,
210 |                                  optimizer, scheduler)
211 |         time_train += (perf_counter() - epoch_start)
212 |
213 |         if (i + 1) % cfg.eval_freq == 0:
214 |             res_val = _eval_epoch(model, dl_val, criterion)
215 |             _log_metrics(res_train, res_val)
216 |
217 |             torch.save(model.state_dict(), os.path.join(log_dir,
218 |                                                         f"model_{i}.pth"))
219 |
220 |     print(f"Total training time: {time_train:.4f}")
221 |
222 |
223 | if __name__ == "__main__":
224 |     main()
225 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HINTLab/ImPulSe/c2f5ccd63651d2b7d5dd99a5f8870b176fed525b/utils/__init__.py
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from time import perf_counter
3 |
4 |
5 | def logger(func):
6 |     @wraps(func)
7 |     def _decorated(*args, **kwargs):
8 |         start = perf_counter()
9 |         res = func(*args, **kwargs)
10 |         end = perf_counter()
11 |         time_elapsed = round(end - start)
12 |         mins = time_elapsed // 60
13 |         secs = time_elapsed % 60
14 |         print(f"{func.__name__}: {mins} min {secs} sec.")
15 |
16 |         return res
17 |
18 |     return _decorated
19 |
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | EPS = 1e-8
5 |
6 |
7 | def _oneclass_dice(y_true, y_pred):
8 |     dice = 2 * np.logical_and(y_true, y_pred).sum()\
9 |         / (y_true.sum() + y_pred.sum() + EPS)
10 |
11 |     return dice
12 |
13 |
14 | def dice_score(y_true, y_pred, num_classes):
15 |     dice = np.mean([_oneclass_dice(y_true == i, y_pred == i) for i
16 |                     in range(num_classes)])
17 |
18 |     return dice
19 |
20 |
21 | def foreground_dice_score(y_true, y_pred, num_classes):
22 |     # `num_classes` counts foreground classes only; class 0 (background) is excluded.
23 |     dice = np.mean([_oneclass_dice(y_true == i, y_pred == i) for i
24 |                     in range(1, num_classes + 1)])
25 |
26 |     return dice
27 |
--------------------------------------------------------------------------------
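A toy check of the metric semantics (values invented): `foreground_dice_score` averages per-class Dice over classes `1..num_classes`, ignoring the background class entirely.

```python
import numpy as np

# Two foreground classes; background (0) is excluded from the average.
y_true = np.array([0, 1, 1, 2, 2, 0])
y_pred = np.array([0, 1, 0, 2, 2, 2])
# class 1: 2*1/(2+1) = 0.667; class 2: 2*2/(2+3) = 0.8
print(foreground_dice_score(y_true, y_pred, num_classes=2))  # ~0.733
```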