├── .github └── workflows │ └── format.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── bin ├── C2C ├── C2C-slurm └── install.sh ├── comp2comp ├── __init__.py ├── aaa │ └── aaa.py ├── aortic_calcium │ ├── aortic_calcium.py │ ├── aortic_calcium_visualization.py │ └── visualization_utils.py ├── contrast_phase │ ├── contrast_inf.py │ ├── contrast_phase.py │ └── xgboost.pkl ├── hip │ ├── hip.py │ ├── hip_utils.py │ ├── hip_visualization.py │ └── tunnelvision.ipynb ├── inference_class_base.py ├── inference_pipeline.py ├── io │ ├── io.py │ └── io_utils.py ├── liver_spleen_pancreas │ ├── liver_spleen_pancreas.py │ ├── liver_spleen_pancreas_visualization.py │ └── visualization_utils.py ├── metrics │ └── metrics.py ├── models │ └── models.py ├── muscle_adipose_tissue │ ├── data.py │ ├── muscle_adipose_tissue.py │ └── muscle_adipose_tissue_visualization.py ├── spine │ ├── spine.py │ ├── spine_utils.py │ └── spine_visualization.py ├── utils │ ├── __init__.py │ ├── colormap.py │ ├── dl_utils.py │ ├── env.py │ ├── logger.py │ ├── orientation.py │ ├── process.py │ └── run.py └── visualization │ ├── detectron_visualizer.py │ ├── dicom.py │ └── linear_planar_reformation.py ├── docs ├── Local Implementation @ M1 arm64 Silicon.md ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── conf.py │ └── index.rst ├── figures ├── aaa_diameter_graph.png ├── aaa_segmentation_video.gif ├── aortic_aneurysm_example.png ├── aortic_calcium_overview.png ├── hip_example.png ├── liver_spleen_pancreas_example.png ├── muscle_adipose_tissue_example.png ├── spine_example.png └── spine_muscle_adipose_tissue_example.png ├── logo.png ├── setup.cfg └── setup.py /.github/workflows/format.yml: -------------------------------------------------------------------------------- 1 | name: Autoformat code 2 | 3 | on: 4 | push: 5 | branches: [ 'main' ] 6 | pull_request: 7 | branches: [ 'main' ] 8 | 9 | jobs: 10 | format: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 
14 | - name: Format code 15 | run: | 16 | pip install black 17 | black . 18 | - name: Sort imports 19 | run: | 20 | pip install isort 21 | isort . 22 | - name: Remove unused imports 23 | run: | 24 | pip install autoflake 25 | autoflake --in-place --remove-all-unused-imports --remove-unused-variables --recursive . 26 | - name: Commit changes 27 | uses: EndBug/add-and-commit@v4 28 | with: 29 | author_name: ${{ github.actor }} 30 | author_email: ${{ github.actor }}@users.noreply.github.com 31 | message: "Autoformat code" 32 | add: "." 33 | branch: ${{ github.ref }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore project files 2 | **/.idea 3 | **/.DS_Store 4 | **/.vscode 5 | 6 | # Ignore cache 7 | **/__pycache__ 8 | 9 | # Ignore egg files 10 | **/*.egg-info 11 | 12 | # Docs build files 13 | docs/_build 14 | 15 | # Ignore tensorflow logs 16 | **/tf_log 17 | 18 | # Ignore results 19 | **/pik_data 20 | **/preds 21 | 22 | # Ignore test_data 23 | **/test_data 24 | **/testing_data 25 | **/sample_data 26 | **/test_results 27 | 28 | # Ignore images 29 | **/model_imgs 30 | 31 | # Ignore data visualization scripts/images 32 | **/data_visualization 33 | **/OAI-iMorphics 34 | 35 | # temp files 36 | ._* 37 | # ignore checkpoint files 38 | **/.ipynb_checkpoints/ 39 | **/.comp2comp/ 40 | 41 | # ignore cross validation files 42 | *.cv 43 | 44 | # ignore yml file 45 | *.yml 46 | *.yaml 47 | !.github/workflows/format.yml 48 | 49 | # ignore images 50 | *.png 51 | !panel_example.png 52 | !logo.png 53 | # except for pngs in the figures folder 54 | !figures/*.png 55 | 56 | # ignore any weights files 57 | weights/ 58 | 59 | # preferences file 60 | comp2comp/preferences.yaml 61 | 62 | # model directory 63 | **/.comp2comp_model_dir/ 64 | 65 | # slurm outputs 66 | **/slurm/ 67 | 68 | # ignore outputs file 69 | **/outputs/ 70 | 71 | **/models/ 72 | 
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | COPY . /Comp2Comp 3 | WORKDIR /Comp2Comp 4 | RUN pip install -e . 5 | RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Comp2Comp 2 | [![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 3 | ![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/StanfordMIMI/Comp2Comp/format.yml?branch=master) 4 | [![Documentation Status](https://readthedocs.org/projects/comp2comp/badge/?version=latest)](https://comp2comp.readthedocs.io/en/latest/?badge=latest) 5 | 6 | [**Paper**](https://arxiv.org/abs/2302.06568) 7 | | [**Installation**](#installation) 8 | | [**Basic Usage**](#basic_usage) 9 | | [**Inference Pipelines**](#basic_usage) 10 | | [**Contribute**](#contribute) 11 | | [**Citation**](#citation) 12 | 13 | Comp2Comp is a library for extracting clinical insights from computed tomography scans. 14 | 15 | ## Installation 16 | 17 | ```bash 18 | git clone https://github.com/StanfordMIMI/Comp2Comp/ 19 | 20 | # Install script requires Anaconda/Miniconda. 
21 | cd Comp2Comp && bin/install.sh 22 | ``` 23 | 24 | Alternatively, Comp2Comp can be installed with `pip`: 25 | ```bash 26 | git clone https://github.com/StanfordMIMI/Comp2Comp/ 27 | cd Comp2Comp 28 | conda create -n c2c_env python=3.9 29 | conda activate c2c_env 30 | pip install -e . 31 | ``` 32 | 33 | For installing on the Apple M1 chip, see [these instructions](https://github.com/StanfordMIMI/Comp2Comp/blob/master/docs/Local%20Implementation%20%40%20M1%20arm64%20Silicon.md). 34 | 35 | ## Basic Usage 36 | 37 | ```bash 38 | bin/C2C -i 39 | ``` 40 | 41 | For running on slurm, modify the above commands as follow: 42 | ```bash 43 | bin/C2C-slurm -i 44 | ``` 45 | 46 | ## Inference Pipelines 47 | 48 | We have designed Comp2Comp to be highly extensible and to enable the development of complex clinically-relevant applications. We observed that many clinical applications require chaining several machine learning or other computational modules together to generate complex insights. The inference pipeline system is designed to make this easy. Furthermore, we seek to make the code readable and modular, so that the community can easily contribute to the project. 49 | 50 | The [`InferencePipeline` class](comp2comp/inference_pipeline.py) is used to create inference pipelines, which are made up of a sequence of [`InferenceClass` objects](comp2comp/inference_class_base.py). When the `InferencePipeline` object is called, it sequentially calls the `InferenceClasses` that were provided to the constructor. 51 | 52 | The first argument of the `__call__` function of `InferenceClass` must be the `InferencePipeline` object. This allows each `InferenceClass` object to access or set attributes of the `InferencePipeline` object that can be accessed by the subsequent `InferenceClass` objects in the pipeline. Each `InferenceClass` object should return a dictionary where the keys of the dictionary should match the keyword arguments of the subsequent `InferenceClass's` `__call__` function. 
If an `InferenceClass` object only sets attributes of the `InferencePipeline` object but does not return any value, an empty dictionary can be returned. 53 | 54 | Below are the inference pipelines currently supported by Comp2Comp. 55 | 56 | ## End-to-End Spine, Muscle, and Adipose Tissue Analysis at T12-L5 57 | 58 | ### Usage 59 | ```bash 60 | bin/C2C spine_muscle_adipose_tissue -i 61 | ``` 62 | - input_path should contain a DICOM series or subfolders that contain DICOM series. 63 | 64 | ### Example Output Image 65 |

66 | 67 |

68 | 69 | ## Spine Bone Mineral Density from 3D Trabecular Bone Regions at T12-L5 70 | 71 | ### Usage 72 | ```bash 73 | bin/C2C spine -i 74 | ``` 75 | - input_path should contain a DICOM series or subfolders that contain DICOM series. 76 | 77 | ### Example Output Image 78 |

79 | 80 |

81 | 82 | ## Abdominal Aortic Calcification Segmentation 83 | 84 | ### Usage 85 | ```bash 86 | bin/C2C aortic_calcium -i -o --threshold --mosaic-type 87 | ``` 88 | The input path should contain a DICOM series or subfolders that contain DICOM series or a nifty file. 89 | - The threshold can be controlled with `--threshold` and be either an integer HU threshold, "adataptive" or "agatson". 90 | - If "agatson" is used, agatson score is calculated and a threshold of 130 HU is used 91 | - Aortic calcifications are divided into abdominal and thoracic at the end of the T12 level 92 | - Segmentation masks for the aortic calcium, the dilated aorta mask, and the T12 seperation plane are saved in ./segmentation_masks/ 93 | - Metrics on an aggregated and individual level for the calcifications are written to .csv files in ./metrics/ 94 | - Visualizations are saved to ./images/ 95 | - The visualization presents coronal and sagittal MIP projections with the aorta overlay, featuring a heat map of calcifications alongside extracted calcification metrics. Below is a mosaic of each aortic slice with calcifications. 96 | - The mosaic will default show all slices with califications but a subset at each vertebra level can be used instead with `--mosaic-type vertebrae` 97 | 98 |

99 | 100 |

101 | 102 | ### Example Output 103 | ``` 104 | Statistics on aortic calcifications: 105 | Abdominal: 106 | Total number: 21 107 | Total volume (cm³): 1.042 108 | Mean HU: 218.6+/-91.4 109 | Median HU: 195.6+/-65.8 110 | Max HU: 449.4+/-368.6 111 | Mean volume (cm³): 0.050+/-0.100 112 | Median volume (cm³): 0.006 113 | Max volume (cm³): 0.425 114 | Min volume (cm³): 0.002 115 | Threshold (HU): 130.000 116 | % Calcified aorta 3.429 117 | Agatston score: 4224.7 118 | 119 | 120 | Thoracic: 121 | Total number: 5 122 | Total volume (cm³): 0.012 123 | Mean HU: 171.6+/-41.0 124 | Median HU: 168.5+/-42.7 125 | Max HU: 215.8+/-87.1 126 | Mean volume (cm³): 0.002+/-0.001 127 | Median volume (cm³): 0.002 128 | Max volume (cm³): 0.004 129 | Min volume (cm³): 0.002 130 | Threshold (HU): 130.000 131 | % Calcified aorta 0.026 132 | Agatston score: 21.1 133 | ``` 134 | 135 | ## AAA Segmentation and Maximum Diameter Measurement 136 | 137 | ### Usage 138 | ```bash 139 | bin/C2C aaa -i 140 | ``` 141 | - input_path should contain a DICOM series or subfolders that contain DICOM series. 142 | 143 | ### Example Output Image (slice with largest diameter) 144 |

145 | 146 |

147 | 148 |
149 | 150 | | Example Output Video | Example Output Graph | 151 | |-----------------------------|----------------------------| 152 | |

|

| 153 | 154 |
155 | 156 | ## Contrast Phase Detection 157 | 158 | ### Usage 159 | ```bash 160 | bin/C2C contrast_phase -i 161 | ``` 162 | - input_path should contain a DICOM series or subfolders that contain DICOM series. 163 | - This package has extra dependencies. To install those, run: 164 | ```bash 165 | cd Comp2Comp 166 | pip install -e '.[contrast_phase]' 167 | ``` 168 | 169 | ## 3D Analysis of Liver, Spleen, and Pancreas 170 | 171 | ### Usage 172 | ```bash 173 | bin/C2C liver_spleen_pancreas -i 174 | ``` 175 | - input_path should contain a DICOM series or subfolders that contain DICOM series. 176 | 177 | ### Example Output Image 178 |

179 | 180 |

181 | 182 | 183 | ## Contribute 184 | 185 | We welcome all pull requests. If you have any issues, suggestions, or feedback, please open a new issue. 186 | 187 | ## Citation 188 | 189 | ``` 190 | @article{blankemeier2023comp2comp, 191 | title={Comp2Comp: Open-Source Body Composition Assessment on Computed Tomography}, 192 | author={Blankemeier, Louis and Desai, Arjun and Chaves, Juan Manuel Zambrano and Wentland, Andrew and Yao, Sally and Reis, Eduardo and Jensen, Malte and Bahl, Bhanushree and Arora, Khushboo and Patel, Bhavik N and others}, 193 | journal={arXiv preprint arXiv:2302.06568}, 194 | year={2023} 195 | } 196 | ``` 197 | 198 | In addition to Comp2Comp, please consider citing TotalSegmentator: 199 | ``` 200 | @article{wasserthal2022totalsegmentator, 201 | title={TotalSegmentator: robust segmentation of 104 anatomical structures in CT images}, 202 | author={Wasserthal, Jakob and Meyer, Manfred and Breit, Hanns-Christian and Cyriac, Joshy and Yang, Shan and Segeroth, Martin}, 203 | journal={arXiv preprint arXiv:2208.05868}, 204 | year={2022} 205 | } 206 | ``` 207 | 208 | 209 | -------------------------------------------------------------------------------- /bin/C2C: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os 4 | 5 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 6 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" 7 | 8 | from comp2comp.aaa import aaa 9 | from comp2comp.aortic_calcium import ( 10 | aortic_calcium, 11 | aortic_calcium_visualization, 12 | ) 13 | from comp2comp.contrast_phase.contrast_phase import ContrastPhaseDetection 14 | from comp2comp.hip import hip 15 | from comp2comp.inference_pipeline import InferencePipeline 16 | from comp2comp.io import io 17 | from comp2comp.liver_spleen_pancreas import ( 18 | liver_spleen_pancreas, 19 | liver_spleen_pancreas_visualization, 20 | ) 21 | from comp2comp.muscle_adipose_tissue import ( 22 | muscle_adipose_tissue, 
23 | muscle_adipose_tissue_visualization, 24 | ) 25 | from comp2comp.spine import spine 26 | from comp2comp.utils import orientation 27 | from comp2comp.utils.process import process_3d 28 | 29 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 30 | 31 | ### AAA Pipeline 32 | 33 | def AAAPipelineBuilder(path, args): 34 | pipeline = InferencePipeline( 35 | [ 36 | AxialCropperPipelineBuilder(path, args), 37 | aaa.AortaSegmentation(), 38 | aaa.AortaDiameter(), 39 | aaa.AortaMetricsSaver() 40 | ] 41 | ) 42 | return pipeline 43 | 44 | def MuscleAdiposeTissuePipelineBuilder(args): 45 | pipeline = InferencePipeline( 46 | [ 47 | muscle_adipose_tissue.MuscleAdiposeTissueSegmentation( 48 | 16, args.muscle_fat_model 49 | ), 50 | muscle_adipose_tissue.MuscleAdiposeTissuePostProcessing(), 51 | muscle_adipose_tissue.MuscleAdiposeTissueComputeMetrics(), 52 | muscle_adipose_tissue_visualization.MuscleAdiposeTissueVisualizer(), 53 | muscle_adipose_tissue.MuscleAdiposeTissueH5Saver(), 54 | muscle_adipose_tissue.MuscleAdiposeTissueMetricsSaver(), 55 | ] 56 | ) 57 | return pipeline 58 | 59 | 60 | def MuscleAdiposeTissueFullPipelineBuilder(args): 61 | pipeline = InferencePipeline( 62 | [io.DicomFinder(args.input_path), MuscleAdiposeTissuePipelineBuilder(args)] 63 | ) 64 | return pipeline 65 | 66 | 67 | def SpinePipelineBuilder(path, args): 68 | pipeline = InferencePipeline( 69 | [ 70 | io.DicomToNifti(path), 71 | spine.SpineSegmentation(args.spine_model, save=True), 72 | orientation.ToCanonical(), 73 | spine.SpineComputeROIs(args.spine_model), 74 | spine.SpineMetricsSaver(), 75 | spine.SpineCoronalSagittalVisualizer(format="png"), 76 | spine.SpineReport(format="png"), 77 | ] 78 | ) 79 | return pipeline 80 | 81 | 82 | def AxialCropperPipelineBuilder(path, args): 83 | pipeline = InferencePipeline( 84 | [ 85 | io.DicomToNifti(path, "aaa"), 86 | spine.SpineSegmentation(args.spine_model), 87 | orientation.ToCanonical(), 88 | spine.AxialCropper(lower_level="L5", upper_level="L1", save=True), 89 
| ] 90 | ) 91 | return pipeline 92 | 93 | 94 | def SpineMuscleAdiposeTissuePipelineBuilder(path, args): 95 | pipeline = InferencePipeline( 96 | [ 97 | SpinePipelineBuilder(path, args), 98 | spine.SpineFindDicoms(), 99 | MuscleAdiposeTissuePipelineBuilder(args), 100 | spine.SpineMuscleAdiposeTissueReport(), 101 | ] 102 | ) 103 | return pipeline 104 | 105 | 106 | def LiverSpleenPancreasPipelineBuilder(path, args): 107 | pipeline = InferencePipeline( 108 | [ 109 | io.DicomToNifti(path), 110 | liver_spleen_pancreas.LiverSpleenPancreasSegmentation(), 111 | orientation.ToCanonical(), 112 | liver_spleen_pancreas_visualization.LiverSpleenPancreasVisualizer(), 113 | liver_spleen_pancreas_visualization.LiverSpleenPancreasMetricsPrinter(), 114 | ] 115 | ) 116 | return pipeline 117 | 118 | 119 | def AorticCalciumPipelineBuilder(path, args): 120 | pipeline = InferencePipeline( 121 | [ 122 | io.DicomToNifti(path), 123 | spine.SpineSegmentation(model_name=args.spine_model), 124 | orientation.ToCanonical(), 125 | aortic_calcium.AortaSegmentation(), 126 | orientation.ToCanonical(), 127 | aortic_calcium.AorticCalciumSegmentation(), 128 | aortic_calcium.AorticCalciumMetrics(), 129 | aortic_calcium_visualization.AorticCalciumVisualizer(), 130 | aortic_calcium_visualization.AorticCalciumPrinter(), 131 | ], 132 | args=args 133 | ) 134 | return pipeline 135 | 136 | 137 | def ContrastPhasePipelineBuilder(path, args): 138 | pipeline = InferencePipeline([io.DicomToNifti(path), ContrastPhaseDetection(path)]) 139 | return pipeline 140 | 141 | 142 | def HipPipelineBuilder(path, args): 143 | pipeline = InferencePipeline( 144 | [ 145 | io.DicomToNifti(path), 146 | hip.HipSegmentation(args.hip_model), 147 | orientation.ToCanonical(), 148 | hip.HipComputeROIs(args.hip_model), 149 | hip.HipMetricsSaver(), 150 | hip.HipVisualizer(), 151 | ] 152 | ) 153 | return pipeline 154 | 155 | 156 | def AllPipelineBuilder(path, args): 157 | pipeline = InferencePipeline( 158 | [ 159 | io.DicomToNifti(path), 160 
| SpineMuscleAdiposeTissuePipelineBuilder(path, args), 161 | LiverSpleenPancreasPipelineBuilder(path, args), 162 | HipPipelineBuilder(path, args), 163 | ] 164 | ) 165 | return pipeline 166 | 167 | 168 | def argument_parser(): 169 | base_parser = argparse.ArgumentParser(add_help=False) 170 | base_parser.add_argument("--input_path", "-i", type=str, required=True) 171 | base_parser.add_argument("--output_path", "-o", type=str) 172 | base_parser.add_argument("--save_segmentations", action="store_true") 173 | base_parser.add_argument("--overwrite_outputs", action="store_true") 174 | 175 | parser = argparse.ArgumentParser() 176 | subparsers = parser.add_subparsers(dest="pipeline", help="Pipeline to run") 177 | 178 | # Add the help option to each subparser 179 | muscle_adipose_tissue_parser = subparsers.add_parser( 180 | "muscle_adipose_tissue", parents=[base_parser] 181 | ) 182 | muscle_adipose_tissue_parser.add_argument( 183 | "--muscle_fat_model", default="abCT_v0.0.1", type=str 184 | ) 185 | 186 | # Spine 187 | spine_parser = subparsers.add_parser("spine", parents=[base_parser]) 188 | spine_parser.add_argument("--spine_model", default="ts_spine", type=str) 189 | 190 | # Spine + muscle + fat 191 | spine_muscle_adipose_tissue_parser = subparsers.add_parser( 192 | "spine_muscle_adipose_tissue", parents=[base_parser] 193 | ) 194 | spine_muscle_adipose_tissue_parser.add_argument( 195 | "--muscle_fat_model", default="stanford_v0.0.2", type=str 196 | ) 197 | spine_muscle_adipose_tissue_parser.add_argument( 198 | "--spine_model", default="ts_spine", type=str 199 | ) 200 | 201 | # Liver spleen pancreas 202 | liver_spleen_pancreas = subparsers.add_parser( 203 | "liver_spleen_pancreas", parents=[base_parser] 204 | ) 205 | 206 | # Aortic calcium 207 | aortic_calcium = subparsers.add_parser( 208 | "aortic_calcium", parents=[base_parser]) 209 | 210 | aortic_calcium.add_argument( 211 | "--threshold", default="adaptive", type=str 212 | ) 213 | aortic_calcium.add_argument( 214 | 
"--spine-model", default="ts_spine", type=str, help='Chose the model to perfom the spine segmentation' 215 | ) 216 | aortic_calcium.add_argument( 217 | "--mosaic-type", default='all', type=str, help='Chose the the type of axial mosaic in the overview image' 218 | ) 219 | 220 | # Contrast phase 221 | contrast_phase_parser = subparsers.add_parser( 222 | "contrast_phase", parents=[base_parser] 223 | ) 224 | 225 | hip_parser = subparsers.add_parser("hip", parents=[base_parser]) 226 | hip_parser.add_argument( 227 | "--hip_model", 228 | default="ts_hip", 229 | type=str, 230 | ) 231 | 232 | # AAA 233 | aorta_diameter_parser = subparsers.add_parser("aaa", help="aorta diameter", parents=[base_parser]) 234 | 235 | aorta_diameter_parser.add_argument( 236 | "--aorta_model", 237 | default="ts_spine", 238 | type=str, 239 | help="aorta model to use for inference", 240 | ) 241 | 242 | aorta_diameter_parser.add_argument( 243 | "--spine_model", 244 | default="ts_spine", 245 | type=str, 246 | help="spine model to use for inference", 247 | ) 248 | 249 | all_parser = subparsers.add_parser("all", parents=[base_parser]) 250 | all_parser.add_argument( 251 | "--muscle_fat_model", 252 | default="abCT_v0.0.1", 253 | type=str, 254 | ) 255 | all_parser.add_argument( 256 | "--spine_model", 257 | default="ts_spine", 258 | type=str, 259 | ) 260 | all_parser.add_argument( 261 | "--hip_model", 262 | default="ts_hip", 263 | type=str, 264 | ) 265 | return parser 266 | 267 | 268 | def main(): 269 | args = argument_parser().parse_args() 270 | if args.pipeline == "spine_muscle_adipose_tissue": 271 | process_3d(args, SpineMuscleAdiposeTissuePipelineBuilder) 272 | elif args.pipeline == "spine": 273 | process_3d(args, SpinePipelineBuilder) 274 | elif args.pipeline == "contrast_phase": 275 | process_3d(args, ContrastPhasePipelineBuilder) 276 | elif args.pipeline == "liver_spleen_pancreas": 277 | process_3d(args, LiverSpleenPancreasPipelineBuilder) 278 | elif args.pipeline == "aortic_calcium": 279 | 
process_3d(args, AorticCalciumPipelineBuilder) 280 | elif args.pipeline == "hip": 281 | process_3d(args, HipPipelineBuilder) 282 | elif args.pipeline == "aaa": 283 | process_3d(args, AAAPipelineBuilder) 284 | elif args.pipeline == "all": 285 | process_3d(args, AllPipelineBuilder) 286 | else: 287 | raise AssertionError("{} command not supported".format(args.action)) 288 | 289 | 290 | if __name__ == "__main__": 291 | main() 292 | -------------------------------------------------------------------------------- /bin/C2C-slurm: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import pipes 4 | import subprocess 5 | import sys 6 | from pathlib import Path 7 | 8 | exec_file = sys.argv[0].split("-")[0] 9 | command = exec_file + " " + " ".join([pipes.quote(s) for s in sys.argv[1:]]) 10 | 11 | def submit_command(command): 12 | subprocess.run(command.split(" "), check=True, capture_output=False) 13 | 14 | 15 | def python_submit(command, node=None): 16 | bash_file = open("./slurm.sh", "w") 17 | bash_file.write(f"#!/bin/bash\n{command}") 18 | bash_file.close() 19 | slurm_output_path = Path("./slurm/") 20 | slurm_output_path.mkdir(parents=True, exist_ok=True) 21 | 22 | try: 23 | if node is None: 24 | command = "sbatch --ntasks=1 --cpus-per-task=8 --output ./slurm/slurm-%j.out \ 25 | --mem-per-cpu=3G -p gpu --gpus 1 --time=1:00:00 slurm.sh" 26 | submit_command(command) 27 | print(f'Submitted the command --- "{command}" --- to slurm.') 28 | else: 29 | command = f"sbatch --ntasks=1 --cpus-per-task=8 --output ./slurm/slurm-%j.out \ 30 | --nodelist={node} --mem-per-cpu=3G -p gpu --gpus 1 --time=1:00:00 slurm.sh" 31 | submit_command(command) 32 | print(f'Submitted the command --- "{command}" --- to slurm.') 33 | except subprocess.CalledProcessError: 34 | if node == None: 35 | command = f"sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --mem=60gb --time=100-00:00:00 slurm.sh " 36 | submit_command(command) 
37 | print(f'Submitted the command --- "{command}" --- to slurm.') 38 | else: 39 | # command = f"sbatch -c 8 --gres=gpu:titanrtx:1 --output ./slurm/slurm-%j.out --nodelist={node} --mem=60gb --time=100-00:00:00 slurm.sh" 40 | command = f"sbatch -c 8 --gres=gpu:1 --output ./slurm/slurm-%j.out --nodelist={node} --mem=60gb --time=100-00:00:00 slurm.sh" 41 | submit_command(command) 42 | print(f'Submitted the command --- "{command}" --- to slurm.') 43 | os.remove("./slurm.sh") 44 | 45 | 46 | python_submit(command, node='amalfi') 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /bin/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ============================================================================== 4 | # Auto-installation for abCTSeg for Linux and Mac machines. 5 | # This setup script is adapted from DOSMA: 6 | # https://github.com/ad12/DOSMA 7 | # ============================================================================== 8 | 9 | BIN_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 10 | 11 | ANACONDA_KEYWORD="anaconda" 12 | ANACONDA_DOWNLOAD_URL="https://www.anaconda.com/distribution/" 13 | MINICONDA_KEYWORD="miniconda" 14 | 15 | # FIXME: Update the name. 
16 | ABCT_ENV_NAME="c2c_env" 17 | 18 | hasAnaconda=0 19 | updateEnv=0 20 | updatePath=1 21 | pythonVersion="3.9" 22 | cudaVersion="" 23 | 24 | while [[ $# -gt 0 ]]; do 25 | key="$1" 26 | case $key in 27 | -h|--help) 28 | echo "Batch evaluation with ss_recon" 29 | echo "" 30 | echo "Usage:" 31 | echo " --python Python version" 32 | echo " -f, --force Force environment update" 33 | exit 34 | ;; 35 | --python) 36 | pythonVersion=$2 37 | shift # past argument 38 | shift # past value 39 | ;; 40 | --cuda) 41 | cudaVersion=$2 42 | shift # past argument 43 | shift # past value 44 | ;; 45 | -f|--force) 46 | updateEnv=1 47 | shift # past argument 48 | ;; 49 | *) 50 | echo "Unknown option: $key" 51 | exit 1 52 | ;; 53 | esac 54 | done 55 | 56 | # Initial setup 57 | source ~/.bashrc 58 | currDir=`pwd` 59 | 60 | 61 | if echo $PATH | grep -q $ANACONDA_KEYWORD; then 62 | hasAnaconda=1 63 | echo "Conda found in path" 64 | fi 65 | 66 | if echo $PATH | grep -q $MINICONDA_KEYWORD; then 67 | hasAnaconda=1 68 | echo "Miniconda found in path" 69 | fi 70 | 71 | if [[ $hasAnaconda -eq 0 ]]; then 72 | echo "Anaconda/Miniconda not installed - install from $ANACONDA_DOWNLOAD_URL" 73 | openURL $ANACONDA_DOWNLOAD_URL 74 | exit 125 75 | fi 76 | 77 | # Hacky way of finding the conda base directory 78 | condaPath=`which conda` 79 | condaPath=`dirname ${condaPath}` 80 | condaPath=`dirname ${condaPath}` 81 | # Source conda 82 | source $condaPath/etc/profile.d/conda.sh 83 | 84 | # Check if OS is supported 85 | if [[ "$OSTYPE" != "linux-gnu" && "$OSTYPE" != "darwin"* ]]; then 86 | echo "Only Linux and MacOS are supported" 87 | exit 125 88 | fi 89 | 90 | # Create Anaconda environment (dosma_env) 91 | if [[ `conda env list | grep $ABCT_ENV_NAME` ]]; then 92 | if [[ ${updateEnv} -eq 0 ]]; then 93 | echo "Environment '${ABCT_ENV_NAME}' is installed. Run 'conda activate ${ABCT_ENV_NAME}' to get started." 
94 | exit 0 95 | else 96 | conda env remove -n $ABCT_ENV_NAME 97 | conda create -y -n $ABCT_ENV_NAME python=3.9 98 | fi 99 | else 100 | conda create -y -n $ABCT_ENV_NAME python=3.9 101 | fi 102 | 103 | conda activate $ABCT_ENV_NAME 104 | 105 | # Install tensorflow and keras 106 | # https://www.tensorflow.org/install/source#gpu 107 | # pip install tensorflow 108 | 109 | # Install pytorch 110 | # FIXME: PyTorch has to be installed with pip to respect setup.py files from nn UNet 111 | # pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu 112 | # if [[ "$OSTYPE" == "darwin"* ]]; then 113 | # # Mac 114 | # if [[ $cudaVersion != "" ]]; then 115 | # # CPU 116 | # echo "Cannot install PyTorch with CUDA support on Mac" 117 | # exit 1 118 | # fi 119 | # conda install -y pytorch torchvision torchaudio -c pytorch 120 | # else 121 | # # Linux 122 | # if [[ $cudaVersion == "" ]]; then 123 | # cudatoolkit="cpuonly" 124 | # else 125 | # cudatoolkit="cudatoolkit=${cudaVersion}" 126 | # fi 127 | # conda install -y pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 $cudatoolkit -c pytorch 128 | # fi 129 | 130 | # Install detectron2 131 | # FIXME: Remove dependency on detectron2 132 | #pip3 install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html 133 | 134 | # Install totalSegmentor 135 | # FIXME: Add this to the setup.py file 136 | # pip3 install git+https://github.com/StanfordMIMI/TotalSegmentator.git 137 | 138 | # cd $currDir/.. 139 | # echo $currDir 140 | # exit 1 141 | 142 | pip install -e . --no-cache-dir 143 | 144 | echo "" 145 | echo "" 146 | echo "Run 'conda activate ${ABCT_ENV_NAME}' to get started." 
-------------------------------------------------------------------------------- /comp2comp/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils.env import setup_environment 2 | 3 | setup_environment() 4 | 5 | 6 | # This line will be programatically read/write by setup.py. 7 | # Leave them at the bottom of this file and don't touch them. 8 | __version__ = "0.0.1" 9 | -------------------------------------------------------------------------------- /comp2comp/aaa/aaa.py: -------------------------------------------------------------------------------- 1 | import math 2 | import operator 3 | import os 4 | import zipfile 5 | from pathlib import Path 6 | from time import time 7 | from tkinter import Tcl 8 | from typing import Union 9 | 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | import moviepy.video.io.ImageSequenceClip 13 | import nibabel as nib 14 | import numpy as np 15 | import pandas as pd 16 | import pydicom 17 | import wget 18 | from totalsegmentator.libs import nostdout 19 | 20 | from comp2comp.inference_class_base import InferenceClass 21 | 22 | 23 | class AortaSegmentation(InferenceClass): 24 | """Spine segmentation.""" 25 | 26 | def __init__(self, save=True): 27 | super().__init__() 28 | self.model_name = "totalsegmentator" 29 | self.save_segmentations = save 30 | 31 | def __call__(self, inference_pipeline): 32 | # inference_pipeline.dicom_series_path = self.input_path 33 | self.output_dir = inference_pipeline.output_dir 34 | self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") 35 | if not os.path.exists(self.output_dir_segmentations): 36 | os.makedirs(self.output_dir_segmentations) 37 | 38 | self.model_dir = inference_pipeline.model_dir 39 | 40 | seg, mv = self.spine_seg( 41 | os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), 42 | self.output_dir_segmentations + "spine.nii.gz", 43 | inference_pipeline.model_dir, 44 | ) 45 | 46 | seg = 
seg.get_fdata() 47 | medical_volume = mv.get_fdata() 48 | 49 | axial_masks = [] 50 | ct_image = [] 51 | 52 | for i in range(seg.shape[2]): 53 | axial_masks.append(seg[:, :, i]) 54 | 55 | for i in range(medical_volume.shape[2]): 56 | ct_image.append(medical_volume[:, :, i]) 57 | 58 | # Save input axial slices to pipeline 59 | inference_pipeline.ct_image = ct_image 60 | 61 | # Save aorta masks to pipeline 62 | inference_pipeline.axial_masks = axial_masks 63 | 64 | return {} 65 | 66 | def setup_nnunet_c2c(self, model_dir: Union[str, Path]): 67 | """Adapted from TotalSegmentator.""" 68 | 69 | model_dir = Path(model_dir) 70 | config_dir = model_dir / Path("." + self.model_name) 71 | (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir( 72 | exist_ok=True, parents=True 73 | ) 74 | (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True) 75 | weights_dir = config_dir / "nnunet/results" 76 | self.weights_dir = weights_dir 77 | 78 | os.environ["nnUNet_raw_data_base"] = str( 79 | weights_dir 80 | ) # not needed, just needs to be an existing directory 81 | os.environ["nnUNet_preprocessed"] = str( 82 | weights_dir 83 | ) # not needed, just needs to be an existing directory 84 | os.environ["RESULTS_FOLDER"] = str(weights_dir) 85 | 86 | def download_spine_model(self, model_dir: Union[str, Path]): 87 | download_dir = Path( 88 | os.path.join( 89 | self.weights_dir, 90 | "nnUNet/3d_fullres/Task253_Aorta/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1", 91 | ) 92 | ) 93 | print(download_dir) 94 | fold_0_path = download_dir / "fold_0" 95 | if not os.path.exists(fold_0_path): 96 | download_dir.mkdir(parents=True, exist_ok=True) 97 | wget.download( 98 | "https://huggingface.co/AdritRao/aaa_test/resolve/main/fold_0.zip", 99 | out=os.path.join(download_dir, "fold_0.zip"), 100 | ) 101 | with zipfile.ZipFile( 102 | os.path.join(download_dir, "fold_0.zip"), "r" 103 | ) as zip_ref: 104 | zip_ref.extractall(download_dir) 105 | os.remove(os.path.join(download_dir, 
"fold_0.zip")) 106 | wget.download( 107 | "https://huggingface.co/AdritRao/aaa_test/resolve/main/plans.pkl", 108 | out=os.path.join(download_dir, "plans.pkl"), 109 | ) 110 | print("Spine model downloaded.") 111 | else: 112 | print("Spine model already downloaded.") 113 | 114 | def spine_seg( 115 | self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir 116 | ): 117 | """Run spine segmentation. 118 | 119 | Args: 120 | input_path (Union[str, Path]): Input path. 121 | output_path (Union[str, Path]): Output path. 122 | """ 123 | 124 | print("Segmenting spine...") 125 | st = time() 126 | os.environ["SCRATCH"] = self.model_dir 127 | 128 | print(self.model_dir) 129 | 130 | # Setup nnunet 131 | model = "3d_fullres" 132 | folds = [0] 133 | trainer = "nnUNetTrainerV2_ep4000_nomirror" 134 | crop_path = None 135 | task_id = [253] 136 | 137 | self.setup_nnunet_c2c(model_dir) 138 | self.download_spine_model(model_dir) 139 | 140 | from totalsegmentator.nnunet import nnUNet_predict_image 141 | 142 | with nostdout(): 143 | img, seg = nnUNet_predict_image( 144 | input_path, 145 | output_path, 146 | task_id, 147 | model=model, 148 | folds=folds, 149 | trainer=trainer, 150 | tta=False, 151 | multilabel_image=True, 152 | resample=1.5, 153 | crop=None, 154 | crop_path=crop_path, 155 | task_name="total", 156 | nora_tag="None", 157 | preview=False, 158 | nr_threads_resampling=1, 159 | nr_threads_saving=6, 160 | quiet=False, 161 | verbose=False, 162 | test=0, 163 | ) 164 | end = time() 165 | 166 | # Log total time for spine segmentation 167 | print(f"Total time for spine segmentation: {end-st:.2f}s.") 168 | 169 | seg_data = seg.get_fdata() 170 | seg = nib.Nifti1Image(seg_data, seg.affine, seg.header) 171 | 172 | return seg, img 173 | 174 | 175 | class AortaDiameter(InferenceClass): 176 | def __init__(self): 177 | super().__init__() 178 | 179 | def normalize_img(self, img: np.ndarray) -> np.ndarray: 180 | """Normalize the image. 
181 | Args: 182 | img (np.ndarray): Input image. 183 | Returns: 184 | np.ndarray: Normalized image. 185 | """ 186 | return (img - img.min()) / (img.max() - img.min()) 187 | 188 | def __call__(self, inference_pipeline): 189 | axial_masks = ( 190 | inference_pipeline.axial_masks 191 | ) # list of 2D numpy arrays of shape (512, 512) 192 | ct_img = ( 193 | inference_pipeline.ct_image 194 | ) # 3D numpy array of shape (512, 512, num_axial_slices) 195 | 196 | # image output directory 197 | output_dir = inference_pipeline.output_dir 198 | output_dir_slices = os.path.join(output_dir, "images/slices/") 199 | if not os.path.exists(output_dir_slices): 200 | os.makedirs(output_dir_slices) 201 | 202 | output_dir = inference_pipeline.output_dir 203 | output_dir_summary = os.path.join(output_dir, "images/summary/") 204 | if not os.path.exists(output_dir_summary): 205 | os.makedirs(output_dir_summary) 206 | 207 | DICOM_PATH = inference_pipeline.dicom_series_path 208 | dicom = pydicom.dcmread(DICOM_PATH + "/" + os.listdir(DICOM_PATH)[0]) 209 | 210 | dicom.PhotometricInterpretation = "YBR_FULL" 211 | pixel_conversion = dicom.PixelSpacing 212 | print("Pixel conversion: " + str(pixel_conversion)) 213 | RATIO_PIXEL_TO_MM = pixel_conversion[0] 214 | 215 | SLICE_COUNT = dicom["InstanceNumber"].value 216 | print(SLICE_COUNT) 217 | 218 | SLICE_COUNT = len(ct_img) 219 | diameterDict = {} 220 | 221 | for i in range(len(ct_img)): 222 | mask = axial_masks[i].astype("uint8") 223 | 224 | img = ct_img[i] 225 | 226 | img = np.clip(img, -300, 1800) 227 | img = self.normalize_img(img) * 255.0 228 | img = img.reshape((img.shape[0], img.shape[1], 1)) 229 | img = np.tile(img, (1, 1, 3)) 230 | 231 | contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) 232 | 233 | if len(contours) != 0: 234 | areas = [cv2.contourArea(c) for c in contours] 235 | sorted_areas = np.sort(areas) 236 | 237 | areas = [cv2.contourArea(c) for c in contours] 238 | sorted_areas = np.sort(areas) 239 | contours 
= contours[areas.index(sorted_areas[-1])] 240 | 241 | img.copy() 242 | 243 | back = img.copy() 244 | cv2.drawContours(back, [contours], 0, (0, 255, 0), -1) 245 | 246 | alpha = 0.25 247 | img = cv2.addWeighted(img, 1 - alpha, back, alpha, 0) 248 | 249 | ellipse = cv2.fitEllipse(contours) 250 | (xc, yc), (d1, d2), angle = ellipse 251 | 252 | cv2.ellipse(img, ellipse, (0, 255, 0), 1) 253 | 254 | xc, yc = ellipse[0] 255 | cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1) 256 | 257 | rmajor = max(d1, d2) / 2 258 | rminor = min(d1, d2) / 2 259 | 260 | ### Draw major axes 261 | 262 | if angle > 90: 263 | angle = angle - 90 264 | else: 265 | angle = angle + 90 266 | print(angle) 267 | xtop = xc + math.cos(math.radians(angle)) * rmajor 268 | ytop = yc + math.sin(math.radians(angle)) * rmajor 269 | xbot = xc + math.cos(math.radians(angle + 180)) * rmajor 270 | ybot = yc + math.sin(math.radians(angle + 180)) * rmajor 271 | cv2.line( 272 | img, (int(xtop), int(ytop)), (int(xbot), int(ybot)), (0, 0, 255), 3 273 | ) 274 | 275 | ### Draw minor axes 276 | 277 | if angle > 90: 278 | angle = angle - 90 279 | else: 280 | angle = angle + 90 281 | print(angle) 282 | x1 = xc + math.cos(math.radians(angle)) * rminor 283 | y1 = yc + math.sin(math.radians(angle)) * rminor 284 | x2 = xc + math.cos(math.radians(angle + 180)) * rminor 285 | y2 = yc + math.sin(math.radians(angle + 180)) * rminor 286 | cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 3) 287 | 288 | # pixel_length = math.sqrt( (x1-x2)**2 + (y1-y2)**2 ) 289 | pixel_length = rminor * 2 290 | 291 | print("Pixel_length_minor: " + str(pixel_length)) 292 | 293 | area_px = cv2.contourArea(contours) 294 | area_mm = round(area_px * RATIO_PIXEL_TO_MM) 295 | area_cm = area_mm / 10 296 | 297 | diameter_mm = round((pixel_length) * RATIO_PIXEL_TO_MM) 298 | diameter_cm = diameter_mm / 10 299 | 300 | diameterDict[(SLICE_COUNT - (i))] = diameter_cm 301 | 302 | img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE) 303 | 
304 | h, w, c = img.shape 305 | lbls = [ 306 | "Area (mm): " + str(area_mm) + "mm", 307 | "Area (cm): " + str(area_cm) + "cm", 308 | "Diameter (mm): " + str(diameter_mm) + "mm", 309 | "Diameter (cm): " + str(diameter_cm) + "cm", 310 | "Slice: " + str(SLICE_COUNT - (i)), 311 | ] 312 | font = cv2.FONT_HERSHEY_SIMPLEX 313 | 314 | scale = 0.03 315 | fontScale = min(w, h) / (25 / scale) 316 | 317 | cv2.putText(img, lbls[0], (10, 40), font, fontScale, (0, 255, 0), 2) 318 | 319 | cv2.putText(img, lbls[1], (10, 70), font, fontScale, (0, 255, 0), 2) 320 | 321 | cv2.putText(img, lbls[2], (10, 100), font, fontScale, (0, 255, 0), 2) 322 | 323 | cv2.putText(img, lbls[3], (10, 130), font, fontScale, (0, 255, 0), 2) 324 | 325 | cv2.putText(img, lbls[4], (10, 160), font, fontScale, (0, 255, 0), 2) 326 | 327 | cv2.imwrite( 328 | output_dir_slices + "slice" + str(SLICE_COUNT - (i)) + ".png", img 329 | ) 330 | 331 | plt.bar(list(diameterDict.keys()), diameterDict.values(), color="b") 332 | 333 | plt.title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$") 334 | 335 | plt.xlabel("Slice Number") 336 | 337 | plt.ylabel("Diameter Measurement (cm)") 338 | plt.savefig(output_dir_summary + "diameter_graph.png", dpi=500) 339 | 340 | print(diameterDict) 341 | print(max(diameterDict.items(), key=operator.itemgetter(1))[0]) 342 | print(diameterDict[max(diameterDict.items(), key=operator.itemgetter(1))[0]]) 343 | 344 | inference_pipeline.max_diameter = diameterDict[ 345 | max(diameterDict.items(), key=operator.itemgetter(1))[0] 346 | ] 347 | 348 | img = ct_img[ 349 | SLICE_COUNT - (max(diameterDict.items(), key=operator.itemgetter(1))[0]) 350 | ] 351 | img = np.clip(img, -300, 1800) 352 | img = self.normalize_img(img) * 255.0 353 | img = img.reshape((img.shape[0], img.shape[1], 1)) 354 | img2 = np.tile(img, (1, 1, 3)) 355 | img2 = cv2.rotate(img2, cv2.ROTATE_90_COUNTERCLOCKWISE) 356 | 357 | img1 = cv2.imread( 358 | output_dir_slices 359 | + "slice" 360 | + str(max(diameterDict.items(), 
key=operator.itemgetter(1))[0]) 361 | + ".png" 362 | ) 363 | 364 | border_size = 3 365 | img1 = cv2.copyMakeBorder( 366 | img1, 367 | top=border_size, 368 | bottom=border_size, 369 | left=border_size, 370 | right=border_size, 371 | borderType=cv2.BORDER_CONSTANT, 372 | value=[0, 244, 0], 373 | ) 374 | img2 = cv2.copyMakeBorder( 375 | img2, 376 | top=border_size, 377 | bottom=border_size, 378 | left=border_size, 379 | right=border_size, 380 | borderType=cv2.BORDER_CONSTANT, 381 | value=[244, 0, 0], 382 | ) 383 | 384 | vis = np.concatenate((img2, img1), axis=1) 385 | cv2.imwrite(output_dir_summary + "out.png", vis) 386 | 387 | image_folder = output_dir_slices 388 | fps = 20 389 | image_files = [ 390 | os.path.join(image_folder, img) 391 | for img in Tcl().call("lsort", "-dict", os.listdir(image_folder)) 392 | if img.endswith(".png") 393 | ] 394 | clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip( 395 | image_files, fps=fps 396 | ) 397 | clip.write_videofile(output_dir_summary + "aaa.mp4") 398 | 399 | return {} 400 | 401 | 402 | class AortaMetricsSaver(InferenceClass): 403 | """Save metrics to a CSV file.""" 404 | 405 | def __init__(self): 406 | super().__init__() 407 | 408 | def __call__(self, inference_pipeline): 409 | """Save metrics to a CSV file.""" 410 | self.max_diameter = inference_pipeline.max_diameter 411 | self.dicom_series_path = inference_pipeline.dicom_series_path 412 | self.output_dir = inference_pipeline.output_dir 413 | self.csv_output_dir = os.path.join(self.output_dir, "metrics") 414 | if not os.path.exists(self.csv_output_dir): 415 | os.makedirs(self.csv_output_dir, exist_ok=True) 416 | self.save_results() 417 | return {} 418 | 419 | def save_results(self): 420 | """Save results to a CSV file.""" 421 | _, filename = os.path.split(self.dicom_series_path) 422 | data = [[filename, str(self.max_diameter)]] 423 | df = pd.DataFrame(data, columns=["Filename", "Max Diameter"]) 424 | df.to_csv(os.path.join(self.csv_output_dir, 
"aorta_metrics.csv"), index=False) 425 | -------------------------------------------------------------------------------- /comp2comp/aortic_calcium/aortic_calcium_visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | from comp2comp.inference_class_base import InferenceClass 6 | from comp2comp.aortic_calcium.visualization_utils import createMipPlot, createCalciumMosaic, mergeMipAndMosaic 7 | 8 | class AorticCalciumVisualizer(InferenceClass): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def __call__(self, inference_pipeline): 13 | self.output_dir = inference_pipeline.output_dir 14 | self.output_dir_images_organs = os.path.join(self.output_dir, "images/") 15 | inference_pipeline.output_dir_images_organs = self.output_dir_images_organs 16 | 17 | if not os.path.exists(self.output_dir_images_organs): 18 | os.makedirs(self.output_dir_images_organs) 19 | 20 | # Create MIP part of the overview plot 21 | createMipPlot( 22 | inference_pipeline.ct, 23 | inference_pipeline.calc_mask, 24 | inference_pipeline.aorta_mask, 25 | inference_pipeline.t12_plane == 1, 26 | inference_pipeline.calcium_threshold, 27 | inference_pipeline.pix_dims, 28 | inference_pipeline.metrics, 29 | self.output_dir_images_organs, 30 | ) 31 | 32 | ab_num = inference_pipeline.metrics['Abdominal']['num_calc'] 33 | th_num = inference_pipeline.metrics['Thoracic']['num_calc'] 34 | # Create mosaic part of the overview plot 35 | if not (ab_num == 0 and th_num == 0): 36 | createCalciumMosaic( 37 | inference_pipeline.ct, 38 | inference_pipeline.calc_mask, 39 | inference_pipeline.dilated_aorta_mask, # the dilated mask is used here 40 | inference_pipeline.spine_mask, 41 | inference_pipeline.pix_dims, 42 | self.output_dir_images_organs, 43 | inference_pipeline.args.mosaic_type, 44 | ) 45 | 46 | # Merge the two images created above for the final report 47 | mergeMipAndMosaic( 48 | self.output_dir_images_organs 49 | 
) 50 | 51 | return {} 52 | 53 | 54 | class AorticCalciumPrinter(InferenceClass): 55 | def __init__(self): 56 | super().__init__() 57 | 58 | def __call__(self, inference_pipeline): 59 | 60 | all_metrics = inference_pipeline.metrics 61 | 62 | inference_pipeline.csv_output_dir = os.path.join( 63 | inference_pipeline.output_dir, "metrics" 64 | ) 65 | os.makedirs(inference_pipeline.csv_output_dir, exist_ok=True) 66 | 67 | # Write metrics to CSV file 68 | with open( 69 | os.path.join(inference_pipeline.csv_output_dir, "aortic_calcification.csv"), 70 | "w", 71 | ) as f: 72 | f.write("Volume (cm^3),Mean HU,Median HU,Max HU\n") 73 | 74 | with open( 75 | os.path.join(inference_pipeline.csv_output_dir, "aortic_calcification.csv"), 76 | "a", 77 | ) as f: 78 | 79 | for region, metrics in all_metrics.items(): 80 | f.write(region + ",,,\n") 81 | 82 | for vol, mean, median, max in zip( 83 | metrics["volume"], 84 | metrics["mean_hu"], 85 | metrics["median_hu"], 86 | metrics["max_hu"], 87 | ): 88 | f.write("{},{:.1f},{:.1f},{:.1f}\n".format(vol, mean, median, max)) 89 | 90 | # Write total results 91 | with open( 92 | os.path.join( 93 | inference_pipeline.csv_output_dir, "aortic_calcification_total.csv" 94 | ), 95 | "w", 96 | ) as f: 97 | for region, metrics in all_metrics.items(): 98 | f.write(region + ",\n") 99 | 100 | f.write("Total number,{}\n".format(metrics["num_calc"])) 101 | f.write("Total volume (cm^3),{:.3f}\n".format(metrics["volume_total"])) 102 | f.write( 103 | "Threshold (HU),{:.1f}\n".format( 104 | inference_pipeline.calcium_threshold 105 | ) 106 | ) 107 | 108 | f.write( 109 | "{},{:.1f}+/-{:.1f}\n".format( 110 | "Mean HU", 111 | np.mean(metrics["mean_hu"]), 112 | np.std(metrics["mean_hu"]), 113 | ) 114 | ) 115 | f.write( 116 | "{},{:.1f}+/-{:.1f}\n".format( 117 | "Median HU", 118 | np.mean(metrics["median_hu"]), 119 | np.std(metrics["median_hu"]), 120 | ) 121 | ) 122 | f.write( 123 | "{},{:.1f}+/-{:.1f}\n".format( 124 | "Max HU", 125 | np.mean(metrics["max_hu"]), 126 
| np.std(metrics["max_hu"]), 127 | ) 128 | ) 129 | f.write( 130 | "{},{:.3f}+/-{:.3f}\n".format( 131 | "Mean volume (cm³):", 132 | np.mean(metrics["volume"]), 133 | np.std(metrics["volume"]), 134 | ) 135 | ) 136 | f.write( 137 | "{},{:.3f}\n".format( 138 | "Median volume (cm³)", np.median(metrics["volume"]) 139 | ) 140 | ) 141 | f.write( 142 | "{},{:.3f}\n".format("Max volume (cm³)", np.max(metrics["volume"])) 143 | ) 144 | f.write( 145 | "{},{:.3f}\n".format("Min volume (cm³):", np.min(metrics["volume"])) 146 | ) 147 | f.write( 148 | "{},{:.3f}\n".format("% Calcified aorta:", metrics["perc_calcified"]) 149 | ) 150 | 151 | if inference_pipeline.args.threshold == "agatston": 152 | f.write("Agatston score,{:.1f}\n".format(metrics["agatston_score"])) 153 | 154 | distance = 25 155 | print("\n") 156 | print("Statistics on aortic calcifications:") 157 | 158 | for region, metrics in all_metrics.items(): 159 | print(region + ":") 160 | 161 | if metrics["num_calc"] == 0: 162 | print("No aortic calcifications were found.\n") 163 | else: 164 | print("{:<{}}{}".format("Total number:", distance, metrics["num_calc"])) 165 | print( 166 | "{:<{}}{:.3f}".format( 167 | "Total volume (cm³):", distance, metrics["volume_total"] 168 | ) 169 | ) 170 | print( 171 | "{:<{}}{:.1f}+/-{:.1f}".format( 172 | "Mean HU:", 173 | distance, 174 | np.mean(metrics["mean_hu"]), 175 | np.std(metrics["mean_hu"]), 176 | ) 177 | ) 178 | print( 179 | "{:<{}}{:.1f}+/-{:.1f}".format( 180 | "Median HU:", 181 | distance, 182 | np.mean(metrics["median_hu"]), 183 | np.std(metrics["median_hu"]), 184 | ) 185 | ) 186 | print( 187 | "{:<{}}{:.1f}+/-{:.1f}".format( 188 | "Max HU:", 189 | distance, 190 | np.mean(metrics["max_hu"]), 191 | np.std(metrics["max_hu"]), 192 | ) 193 | ) 194 | print( 195 | "{:<{}}{:.3f}+/-{:.3f}".format( 196 | "Mean volume (cm³):", 197 | distance, 198 | np.mean(metrics["volume"]), 199 | np.std(metrics["volume"]), 200 | ) 201 | ) 202 | print( 203 | "{:<{}}{:.3f}".format( 204 | "Median volume 
(cm³):", distance, np.median(metrics["volume"]) 205 | ) 206 | ) 207 | print( 208 | "{:<{}}{:.3f}".format( 209 | "Max volume (cm³):", distance, np.max(metrics["volume"]) 210 | ) 211 | ) 212 | print( 213 | "{:<{}}{:.3f}".format( 214 | "Min volume (cm³):", distance, np.min(metrics["volume"]) 215 | ) 216 | ) 217 | print( 218 | "{:<{}}{:.3f}".format( 219 | "Threshold (HU):", 220 | distance, 221 | inference_pipeline.calcium_threshold, 222 | ) 223 | ) 224 | print( 225 | "{:<{}}{:.3f}".format( 226 | "% Calcified aorta", 227 | distance, 228 | metrics["perc_calcified"], 229 | ) 230 | ) 231 | 232 | if inference_pipeline.args.threshold == "agatston": 233 | print( 234 | "{:<{}}{:.1f}".format( 235 | "Agatston score:", distance, metrics["agatston_score"] 236 | ) 237 | ) 238 | 239 | print("\n") 240 | 241 | return {} 242 | -------------------------------------------------------------------------------- /comp2comp/contrast_phase/contrast_phase.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from time import time 4 | from typing import Union 5 | 6 | from totalsegmentator.libs import ( 7 | download_pretrained_weights, 8 | nostdout, 9 | setup_nnunet, 10 | ) 11 | # from totalsegmentatorv2.python_api import totalsegmentator 12 | 13 | from comp2comp.contrast_phase.contrast_inf import predict_phase 14 | from comp2comp.inference_class_base import InferenceClass 15 | 16 | 17 | class ContrastPhaseDetection(InferenceClass): 18 | """Contrast Phase Detection.""" 19 | 20 | def __init__(self, input_path): 21 | super().__init__() 22 | self.input_path = input_path 23 | 24 | def __call__(self, inference_pipeline): 25 | self.output_dir = inference_pipeline.output_dir 26 | self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/") 27 | if not os.path.exists(self.output_dir_segmentations): 28 | os.makedirs(self.output_dir_segmentations) 29 | self.model_dir = inference_pipeline.model_dir 30 | 31 | 
seg, img = self.run_segmentation( 32 | os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), 33 | self.output_dir_segmentations + "s01.nii.gz", 34 | inference_pipeline.model_dir, 35 | ) 36 | 37 | # segArray, imgArray = self.convertNibToNumpy(seg, img) 38 | 39 | imgNiftiPath = os.path.join( 40 | self.output_dir_segmentations, "converted_dcm.nii.gz" 41 | ) 42 | segNiftPath = os.path.join(self.output_dir_segmentations, "s01.nii.gz") 43 | 44 | predict_phase(segNiftPath, imgNiftiPath, outputPath=self.output_dir) 45 | 46 | return {} 47 | 48 | def run_segmentation( 49 | self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir 50 | ): 51 | """Run segmentation. 52 | 53 | Args: 54 | input_path (Union[str, Path]): Input path. 55 | output_path (Union[str, Path]): Output path. 56 | """ 57 | 58 | print("Segmenting...") 59 | st = time() 60 | os.environ["SCRATCH"] = self.model_dir 61 | 62 | # Setup nnunet 63 | model = "3d_fullres" 64 | folds = [0] 65 | trainer = "nnUNetTrainerV2_ep4000_nomirror" 66 | crop_path = None 67 | task_id = [251] 68 | 69 | setup_nnunet() 70 | for task_id in [251]: 71 | download_pretrained_weights(task_id) 72 | 73 | from totalsegmentator.nnunet import nnUNet_predict_image 74 | 75 | with nostdout(): 76 | img, seg = nnUNet_predict_image( 77 | input_path, 78 | output_path, 79 | task_id, 80 | model=model, 81 | folds=folds, 82 | trainer=trainer, 83 | tta=False, 84 | multilabel_image=True, 85 | resample=1.5, 86 | crop=None, 87 | crop_path=crop_path, 88 | task_name="total", 89 | nora_tag=None, 90 | preview=False, 91 | nr_threads_resampling=1, 92 | nr_threads_saving=6, 93 | quiet=False, 94 | verbose=False, 95 | test=0, 96 | ) 97 | 98 | # seg = totalsegmentator( 99 | # input = os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"), 100 | # output = os.path.join(self.output_dir_segmentations, "segmentation.nii"), 101 | # task_ids = [293], 102 | # ml = True, 103 | # nr_thr_resamp = 1, 104 | # nr_thr_saving = 6, 105 | # 
fast = False, 106 | # nora_tag = "None", 107 | # preview = False, 108 | # task = "total", 109 | # roi_subset = None, 110 | # statistics = False, 111 | # radiomics = False, 112 | # crop_path = None, 113 | # body_seg = False, 114 | # force_split = False, 115 | # output_type = "nifti", 116 | # quiet = False, 117 | # verbose = False, 118 | # test = 0, 119 | # skip_saving = True, 120 | # device = "gpu", 121 | # license_number = None, 122 | # statistics_exclude_masks_at_border = True, 123 | # no_derived_masks = False, 124 | # v1_order = False, 125 | # ) 126 | end = time() 127 | 128 | # Log total time for spine segmentation 129 | print(f"Total time for segmentation: {end-st:.2f}s.") 130 | 131 | # return seg, img 132 | return seg, img 133 | 134 | def convertNibToNumpy(self, TSNib, ImageNib): 135 | """Convert nifti to numpy array. 136 | 137 | Args: 138 | TSNib (nibabel.nifti1.Nifti1Image): TotalSegmentator output. 139 | ImageNib (nibabel.nifti1.Nifti1Image): Input image. 140 | 141 | Returns: 142 | numpy.ndarray: TotalSegmentator output. 143 | numpy.ndarray: Input image. 
class HipSegmentation(InferenceClass):
    """Hip segmentation.

    Runs the TotalSegmentator hip model on the converted NIfTI volume and
    stores the model, segmentation, and medical volume on the pipeline.

    NOTE: the original docstring said "Spine segmentation" — a copy-paste
    error from the spine module; this class segments the hip.
    """

    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        self.model = Models.model_from_name(model_name)

    def __call__(self, inference_pipeline):
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        if not os.path.exists(self.output_dir_segmentations):
            os.makedirs(self.output_dir_segmentations)

        self.model_dir = inference_pipeline.model_dir

        seg, mv = self.hip_seg(
            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
            # was string concatenation; os.path.join is equivalent here and safer
            os.path.join(self.output_dir_segmentations, "hip.nii.gz"),
            inference_pipeline.model_dir,
        )

        inference_pipeline.model = self.model
        inference_pipeline.segmentation = seg
        inference_pipeline.medical_volume = mv

        return {}

    def hip_seg(
        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
    ):
        """Run hip segmentation.

        Args:
            input_path (Union[str, Path]): Input path.
            output_path (Union[str, Path]): Output path.
            model_dir: unused; the method reads ``self.model_dir`` instead.
                Kept for signature compatibility with existing callers.
        """

        print("Segmenting hip...")
        st = time()
        # TotalSegmentator resolves weight locations via the SCRATCH env var.
        os.environ["SCRATCH"] = self.model_dir

        # Setup nnunet
        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        crop_path = None
        task_id = [254]

        if self.model_name == "ts_hip":
            setup_nnunet()
            download_pretrained_weights(task_id[0])
        else:
            raise ValueError("Invalid model name.")

        from totalsegmentator.nnunet import nnUNet_predict_image

        # nostdout() suppresses nnUNet's verbose console output.
        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=crop_path,
                task_name="total",
                nora_tag=None,
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        # Log total time for hip segmentation
        print(f"Total time for hip segmentation: {end-st:.2f}s.")

        return seg, img
class HipComputeROIs(InferenceClass):
    """Compute the six femoral ROIs from the hip segmentation."""

    def __init__(self, hip_model):
        super().__init__()
        self.hip_model_name = hip_model
        self.hip_model_type = Models.model_from_name(self.hip_model_name)

    def __call__(self, inference_pipeline):
        # Diagnostic images from the ROI computation land in <output>/dev.
        dev_dir = os.path.join(inference_pipeline.output_dir, "dev")
        inference_pipeline.femur_results_dict = hip_utils.compute_rois(
            inference_pipeline.medical_volume,
            inference_pipeline.segmentation,
            inference_pipeline.model,
            dev_dir,
        )
        return {}


class HipMetricsSaver(InferenceClass):
    """Save metrics to a CSV file."""

    # (CSV column label, femur_results_dict key) pairs, in output column order.
    _COLUMNS = [
        ("Left Head (HU)", "left_head"),
        ("Right Head (HU)", "right_head"),
        ("Left Intertrochanter (HU)", "left_intertrochanter"),
        ("Right Intertrochanter (HU)", "right_intertrochanter"),
        ("Left Neck (HU)", "left_neck"),
        ("Right Neck (HU)", "right_neck"),
    ]

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        metrics_output_dir = os.path.join(inference_pipeline.output_dir, "metrics")
        os.makedirs(metrics_output_dir, exist_ok=True)

        results = inference_pipeline.femur_results_dict
        # One-row frame; dict insertion order fixes the column order.
        df = pd.DataFrame({label: [results[key]["hu"]] for label, key in self._COLUMNS})
        df.to_csv(os.path.join(metrics_output_dir, "hip_metrics.csv"), index=False)
        return {}
class HipVisualizer(InferenceClass):
    """Render per-region ROI images and combined left/right report images."""

    # Region order matches the original sequence of visualizer calls.
    _REGIONS = [
        "left_head",
        "left_intertrochanter",
        "left_neck",
        "right_head",
        "right_intertrochanter",
        "right_neck",
    ]

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        medical_volume = inference_pipeline.medical_volume
        results = inference_pipeline.femur_results_dict

        images_output_dir = os.path.join(inference_pipeline.output_dir, "images")
        if not os.path.exists(images_output_dir):
            os.makedirs(images_output_dir)

        # One sagittal ROI rendering per femoral region.
        for region in self._REGIONS:
            entry = results[region]
            hip_roi_visualizer(
                medical_volume,
                entry["roi"],
                entry["centroid"],
                entry["hu"],
                images_output_dir,
                region,
            )

        # One combined left+right report per anatomy level.
        for anatomy, label in [
            ("head", "Head"),
            ("intertrochanter", "Intertrochanter"),
            ("neck", "Neck"),
        ]:
            left = results[f"left_{anatomy}"]
            right = results[f"right_{anatomy}"]
            hip_report_visualizer(
                medical_volume.get_fdata(),
                left["roi"] + right["roi"],
                [left["centroid"], right["centroid"]],
                images_output_dir,
                anatomy,
                {
                    f"Left {label} HU": round(left["hu"]),
                    f"Right {label} HU": round(right["hu"]),
                },
            )
        return {}
def compute_rois(medical_volume, segmentation, model, output_dir, save=False):
    """Compute head, intertrochanter, and neck ROIs for both femurs.

    Args:
        medical_volume: nibabel-like CT volume (``get_fdata``, ``affine``).
        segmentation: nibabel-like label volume from the hip model.
        model: model object providing ``categories`` label ids.
        output_dir: directory for method/diagnostic images.
        save: if True, also write the combined ROI as NIfTI and copy the
            tunnelvision notebook next to it.

    Returns:
        dict keyed by region name, each with "roi", "centroid", and "hu".
    """
    # Binary masks for each femur, selected by the model's label ids.
    left_femur_mask = segmentation.get_fdata() == model.categories["femur_left"]
    left_femur_mask = left_femur_mask.astype(np.uint8)
    right_femur_mask = segmentation.get_fdata() == model.categories["femur_right"]
    right_femur_mask = right_femur_mask.astype(np.uint8)
    left_head_roi, left_head_centroid, left_head_hu = get_femural_head_roi(
        left_femur_mask, medical_volume, output_dir, "left_head"
    )
    right_head_roi, right_head_centroid, right_head_hu = get_femural_head_roi(
        right_femur_mask, medical_volume, output_dir, "right_head"
    )
    # The intertrochanter ROIs reuse the head-ROI routine; the anatomy string
    # switches which connected component is targeted inside it.
    (
        left_intertrochanter_roi,
        left_intertrochanter_centroid,
        left_intertrochanter_hu,
    ) = get_femural_head_roi(
        left_femur_mask, medical_volume, output_dir, "left_intertrochanter"
    )
    (
        right_intertrochanter_roi,
        right_intertrochanter_centroid,
        right_intertrochanter_hu,
    ) = get_femural_head_roi(
        right_femur_mask, medical_volume, output_dir, "right_intertrochanter"
    )
    # Neck ROIs are cylinders between the intertrochanter and head ROIs.
    (
        left_neck_roi,
        left_neck_centroid,
        left_neck_hu,
    ) = get_femural_neck_roi(
        left_femur_mask,
        medical_volume,
        left_intertrochanter_roi,
        left_intertrochanter_centroid,
        left_head_roi,
        left_head_centroid,
        output_dir,
    )
    (
        right_neck_roi,
        right_neck_centroid,
        right_neck_hu,
    ) = get_femural_neck_roi(
        right_femur_mask,
        medical_volume,
        right_intertrochanter_roi,
        right_intertrochanter_centroid,
        right_head_roi,
        right_head_centroid,
        output_dir,
    )
    # Plain sum: overlapping ROIs add up rather than keeping distinct labels
    # (the commented multipliers show a previous distinct-label scheme).
    combined_roi = (
        left_head_roi
        + (right_head_roi)  # * 2)
        + (left_intertrochanter_roi)  # * 3)
        + (right_intertrochanter_roi)  # * 4)
        + (left_neck_roi)  # * 5)
        + (right_neck_roi)  # * 6)
    )

    if save:
        # make roi directory if it doesn't exist
        parent_output_dir = os.path.dirname(output_dir)
        roi_output_dir = os.path.join(parent_output_dir, "rois")
        if not os.path.exists(roi_output_dir):
            os.makedirs(roi_output_dir)

        # NOTE(review): despite the name, this holds the COMBINED (left+right)
        # ROI, not only the left one.
        left_roi_nifti = nib.Nifti1Image(combined_roi, medical_volume.affine)
        left_roi_path = os.path.join(roi_output_dir, "roi.nii.gz")
        nib.save(left_roi_nifti, left_roi_path)
        # Ship the inspection notebook next to the saved ROI.
        shutil.copy(
            os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                "tunnelvision.ipynb",
            ),
            parent_output_dir,
        )

    return {
        "left_head": {
            "roi": left_head_roi,
            "centroid": left_head_centroid,
            "hu": left_head_hu,
        },
        "right_head": {
            "roi": right_head_roi,
            "centroid": right_head_centroid,
            "hu": right_head_hu,
        },
        "left_intertrochanter": {
            "roi": left_intertrochanter_roi,
            "centroid": left_intertrochanter_centroid,
            "hu": left_intertrochanter_hu,
        },
        "right_intertrochanter": {
            "roi": right_intertrochanter_roi,
            "centroid": right_intertrochanter_centroid,
            "hu": right_intertrochanter_hu,
        },
        "left_neck": {
            "roi": left_neck_roi,
            "centroid": left_neck_centroid,
            "hu": left_neck_hu,
        },
        "right_neck": {
            "roi": right_neck_roi,
            "centroid": right_neck_centroid,
            "hu": right_neck_hu,
        },
    }
def get_femural_head_roi(
    femur_mask,
    medical_volume,
    output_dir,
    anatomy,
    visualize_method=False,
    min_pixel_count=20,
):
    """Locate a spherical-ish ROI at the femoral head or intertrochanter.

    Scans axial slices from the top of the femur mask down until two
    connected components (the head and the greater trochanter) appear,
    then iteratively inscribes circles in sagittal/axial slices to find a
    stable centroid and radius, and intersects an ellipsoid ROI with the
    eroded femur mask.

    Args:
        femur_mask: binary uint8 mask of one femur.
        medical_volume: nibabel-like CT volume.
        output_dir: where method images are written when visualizing.
        anatomy: one of "left_head", "right_head", "left_intertrochanter",
            "right_intertrochanter"; selects which component is targeted.
        visualize_method: save intermediate method images if True.
        min_pixel_count: minimum component size to count as valid.

    Returns:
        (roi_eroded, centroid, hu) tuple.
    """
    # Start at the top-most axial slice that contains femur voxels.
    top = np.where(femur_mask.sum(axis=(0, 1)) != 0)[0].max()
    top_mask = femur_mask[:, :, top]

    print(f"======== Computing {anatomy} femur ROIs ========")

    # Walk down until exactly two sufficiently large components are seen.
    while True:
        labeled, num_features = ndi.label(top_mask)

        component_sizes = np.bincount(labeled.ravel())
        # [1:] drops the background component from the candidates.
        valid_components = np.where(component_sizes >= min_pixel_count)[0][1:]

        if len(valid_components) == 2:
            break

        top -= 1
        if top < 0:
            # Degenerate mask: fall through without a located component pair.
            print("Two connected components not found in the femur mask.")
            break
        top_mask = femur_mask[:, :, top]

    if len(valid_components) == 2:
        # Find the center of mass for each connected component
        center_of_mass_1 = list(
            ndi.center_of_mass(top_mask, labeled, valid_components[0])
        )
        center_of_mass_2 = list(
            ndi.center_of_mass(top_mask, labeled, valid_components[1])
        )

        # Assign left_center_of_mass to be the center of mass with lowest value in the first dimension
        if center_of_mass_1[0] < center_of_mass_2[0]:
            left_center_of_mass = center_of_mass_1
            right_center_of_mass = center_of_mass_2
        else:
            left_center_of_mass = center_of_mass_2
            right_center_of_mass = center_of_mass_1

        print(f"Left center of mass: {left_center_of_mass}")
        print(f"Right center of mass: {right_center_of_mass}")

        # The head of one femur and the intertrochanter of the other sit on
        # the same side of the image; pick the matching component.
        if anatomy == "left_intertrochanter" or anatomy == "right_head":
            center_of_mass = left_center_of_mass
        elif anatomy == "right_intertrochanter" or anatomy == "left_head":
            center_of_mass = right_center_of_mass

        coronal_slice = femur_mask[:, round(center_of_mass[1]), :]
        coronal_image = medical_volume.get_fdata()[:, round(center_of_mass[1]), :]
        sagittal_slice = femur_mask[round(center_of_mass[0]), :, :]
        sagittal_image = medical_volume.get_fdata()[round(center_of_mass[0]), :, :]

        # Resample the through-plane axis so slices are isotropic in-plane.
        zooms = medical_volume.header.get_zooms()
        zoom_factor = zooms[2] / zooms[1]

        coronal_slice = zoom(coronal_slice, (1, zoom_factor), order=1).round()
        coronal_image = zoom(coronal_image, (1, zoom_factor), order=3).round()
        sagittal_image = zoom(sagittal_image, (1, zoom_factor), order=3).round()

        centroid = [round(center_of_mass[0]), 0, 0]

        print(f"Starting centroid: {centroid}")

        # Alternate sagittal/axial inscribed-circle fits; 3 rounds to converge.
        for _ in range(3):
            sagittal_slice = femur_mask[centroid[0], :, :]
            sagittal_slice = zoom(sagittal_slice, (1, zoom_factor), order=1).round()
            centroid[1], centroid[2], radius_sagittal = inscribe_sagittal(
                sagittal_slice, zoom_factor
            )

            print(f"Centroid after inscribe sagittal: {centroid}")

            axial_slice = femur_mask[:, :, centroid[2]]
            # Blank out the other femur's side so the fit stays on target.
            if anatomy == "left_intertrochanter" or anatomy == "right_head":
                axial_slice[round(right_center_of_mass[0]) :, :] = 0
            elif anatomy == "right_intertrochanter" or anatomy == "left_head":
                axial_slice[: round(left_center_of_mass[0]), :] = 0
            centroid[0], centroid[1], radius_axial = inscribe_axial(axial_slice)

            print(f"Centroid after inscribe axial: {centroid}")

        axial_image = medical_volume.get_fdata()[:, :, round(centroid[2])]
        sagittal_image = medical_volume.get_fdata()[round(centroid[0]), :, :]
        sagittal_image = zoom(sagittal_image, (1, zoom_factor), order=3).round()

    # NOTE(review): if the two-component search failed above, the names used
    # below (centroid, radii) are unbound and this raises — confirm intended.
    if visualize_method:
        method_visualizer(
            sagittal_image,
            axial_image,
            axial_slice,
            sagittal_slice,
            [centroid[2], centroid[1]],
            radius_sagittal,
            [centroid[1], centroid[0]],
            radius_axial,
            output_dir,
            anatomy,
        )

    roi = compute_hip_roi(medical_volume, centroid, radius_sagittal, radius_axial)

    # Erode the femur mask to keep the ROI off the cortical shell.
    # selem = ndi.generate_binary_structure(3, 1)
    selem = ball(3)
    femur_mask_eroded = binary_erosion(femur_mask, selem)
    roi = roi * femur_mask_eroded
    roi_eroded = roi.astype(np.uint8)

    hu = get_mean_roi_hu(medical_volume, roi_eroded)

    return (roi_eroded, centroid, hu)
def get_femural_neck_roi(
    femur_mask,
    medical_volume,
    intertrochanter_roi,
    intertrochanter_centroid,
    head_roi,
    head_centroid,
    output_dir,
):
    """Build a cylindrical ROI along the femoral neck axis.

    The axis runs from the intertrochanter centroid toward the head
    centroid; the cylinder spans the gap between the two existing ROIs and
    is intersected with the eroded femur mask.

    Args:
        femur_mask: binary uint8 mask of one femur.
        medical_volume: nibabel-like CT volume.
        intertrochanter_roi / intertrochanter_centroid: from the head-ROI step.
        head_roi / head_centroid: from the head-ROI step.
        output_dir: unused here; kept for call-site symmetry.

    Returns:
        (neck_roi, centroid, hu) tuple.
    """
    zooms = medical_volume.header.get_zooms()

    # Unit axis from intertrochanter toward head (in voxel index space).
    direction_vector = np.array(head_centroid) - np.array(intertrochanter_centroid)
    unit_direction_vector = direction_vector / np.linalg.norm(direction_vector)

    # t_start: farthest extent of the intertrochanter ROI along the axis.
    z, y, x = np.where(intertrochanter_roi)
    intertrochanter_points = np.column_stack((z, y, x))
    t_start = np.dot(
        intertrochanter_points - intertrochanter_centroid, unit_direction_vector
    ).max()

    # t_end: nearest extent of the head ROI along the axis (offset from head).
    z, y, x = np.where(head_roi)
    head_points = np.column_stack((z, y, x))
    t_end = (
        np.linalg.norm(direction_vector)
        + np.dot(head_points - head_centroid, unit_direction_vector).min()
    )

    # Axial projection of every voxel onto the neck axis.
    z, y, x = np.indices(femur_mask.shape)
    coordinates = np.stack((z, y, x), axis=-1)

    distance_to_line_origin = np.dot(
        coordinates - intertrochanter_centroid, unit_direction_vector
    )

    # Radial distance is computed in physical (mm) space via the zooms.
    coordinates_zoomed = coordinates * zooms
    intertrochanter_centroid_zoomed = np.array(intertrochanter_centroid) * zooms
    unit_direction_vector_zoomed = unit_direction_vector * zooms

    # Point-to-line distance: |(p-a) x (p-b)| / |b-a| with b = a + axis.
    distance_to_line = np.linalg.norm(
        np.cross(
            coordinates_zoomed - intertrochanter_centroid_zoomed,
            coordinates_zoomed
            - (intertrochanter_centroid_zoomed + unit_direction_vector_zoomed),
        ),
        axis=-1,
    ) / np.linalg.norm(unit_direction_vector_zoomed)

    # Cylinder radius in mm (radial distance is in physical units).
    cylinder_radius = 10

    cylinder_mask = (
        (distance_to_line <= cylinder_radius)
        & (distance_to_line_origin >= t_start)
        & (distance_to_line_origin <= t_end)
    )

    # Erode the femur mask to keep the ROI off the cortical shell.
    # selem = ndi.generate_binary_structure(3, 1)
    selem = ball(3)
    femur_mask_eroded = binary_erosion(femur_mask, selem)
    roi = cylinder_mask * femur_mask_eroded
    neck_roi = roi.astype(np.uint8)

    hu = get_mean_roi_hu(medical_volume, neck_roi)

    # Midpoint of the cylinder along the axis, rounded to voxel indices.
    centroid = list(
        intertrochanter_centroid + unit_direction_vector * (t_start + t_end) / 2
    )
    centroid = [round(x) for x in centroid]

    return neck_roi, centroid, hu
def compute_hip_roi(img, centroid, radius_sagittal, radius_axial):
    """Rasterize an ellipsoid ROI around ``centroid``.

    Semi-axis lengths are 75% of the inscribed radii, converted from
    physical units to voxels via the image spacing.

    Args:
        img: nibabel-like volume (``header.get_zooms``, ``get_fdata``).
        centroid: (i, j, k) voxel center of the ellipsoid.
        radius_sagittal: inscribed-circle radius from the sagittal fit.
        radius_axial: inscribed-circle radius from the axial fit.

    Returns:
        uint8 volume, 1 inside the ellipsoid, 0 elsewhere.
    """
    spacing = img.header.get_zooms()
    semi_axes = (
        radius_axial * 0.75 / spacing[0],
        radius_axial * 0.75 / spacing[1],
        radius_sagittal * 0.75 / spacing[2],
    )

    roi = np.zeros(img.get_fdata().shape, dtype=np.uint8)
    lowers = [math.floor(c - s) for c, s in zip(centroid, semi_axes)]
    for i in range(lowers[0], lowers[0] + 2 * math.ceil(semi_axes[0]) + 1):
        for j in range(lowers[1], lowers[1] + 2 * math.ceil(semi_axes[1]) + 1):
            for k in range(lowers[2], lowers[2] + 2 * math.ceil(semi_axes[2]) + 1):
                # Standard ellipsoid inequality; <= 1 means inside.
                inside = (
                    (i - centroid[0]) ** 2 / semi_axes[0] ** 2
                    + (j - centroid[1]) ** 2 / semi_axes[1] ** 2
                    + (k - centroid[2]) ** 2 / semi_axes[2] ** 2
                )
                if inside <= 1:
                    roi[i, j, k] = 1
    return roi


def inscribe_axial(axial_mask):
    """Largest inscribed circle of an axial mask via a distance transform.

    Returns:
        (left_right_center, posterior_anterior_center, radius) — note cv2
        reports the max location as (x, y), hence the swapped indexing.
    """
    dist = cv2.distanceTransform(axial_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    _, radius, _, center = cv2.minMaxLoc(dist)
    return round(center[1]), round(center[0]), radius


def inscribe_sagittal(sagittal_mask, zoom_factor):
    """Largest inscribed circle of a (resampled) sagittal mask.

    Returns:
        (posterior_anterior_center, inferior_superior_center, radius); the
        IS coordinate is mapped back to the unresampled slice index by
        dividing out ``zoom_factor``.
    """
    dist = cv2.distanceTransform(sagittal_mask, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)
    _, radius, _, center = cv2.minMaxLoc(dist)
    pa_center = round(center[1])
    is_center = round(round(center[0]) / zoom_factor)
    return pa_center, is_center, radius
def get_mean_roi_hu(medical_volume, roi):
    """Mean intensity of the voxels selected by ``roi``.

    NOTE: masking then filtering on != 0 also drops in-ROI voxels whose
    intensity is exactly 0 — presumably rare for bone HU; verify upstream.
    """
    values = medical_volume.get_fdata() * roi
    return values[values != 0].mean()


def method_visualizer(
    sagittal_image,
    axial_image,
    axial_slice,
    sagittal_slice,
    center_sagittal,
    radius_sagittal,
    center_axial,
    radius_axial,
    output_dir,
    anatomy,
):
    """Save sagittal and axial method images (mask + inscribed circle)."""
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    def to_rgb(image2d):
        # Window to a bone-friendly range, normalize, and tile to 3 channels.
        windowed = np.clip(image2d, -300, 1800)
        scaled = normalize_img(windowed) * 255.0
        scaled = scaled.reshape((scaled.shape[0], scaled.shape[1], 1))
        return np.tile(scaled, (1, 1, 3))

    vis = Visualizer(to_rgb(sagittal_image))
    vis.draw_circle(
        circle_coord=center_sagittal, color=[0, 1, 0], radius=radius_sagittal
    )
    vis.draw_binary_mask(sagittal_slice)
    vis.get_output().save(os.path.join(output_dir, f"{anatomy}_sagittal_method.png"))

    vis = Visualizer(to_rgb(axial_image))
    vis.draw_circle(circle_coord=center_axial, color=[0, 1, 0], radius=radius_axial)
    vis.draw_binary_mask(axial_slice)
    vis.get_output().save(os.path.join(output_dir, f"{anatomy}_axial_method.png"))
def hip_report_visualizer(medical_volume, roi, centroids, output_dir, anatomy, labels):
    """Render a combined left/right axial report image with HU labels.

    Args:
        medical_volume: CT intensity array (ndarray, not a nibabel image).
        roi: summed left+right ROI mask.
        centroids: [left_centroid, right_centroid] for the reformation.
        output_dir: directory for the output PNG.
        anatomy: name used in the output filename.
        labels: {text label: value} pairs drawn onto the image.
    """
    _ROI_COLOR = np.array([1.000, 0.340, 0.200])
    image, mask = linear_planar_reformation(
        medical_volume, roi, centroids, dimension="axial"
    )
    # add 3rd dim to image
    image = np.flip(image.T)
    mask = np.flip(mask.T)
    # Overlapping ROIs were summed; collapse back to a binary mask.
    mask[mask > 1] = 1
    # mask = np.expand_dims(mask, axis=2)
    image = np.expand_dims(image, axis=2)
    image = np.clip(image, -300, 1800)
    image = normalize_img(image) * 255.0
    img_rgb = np.tile(image, (1, 1, 3))
    vis = Visualizer(img_rgb)
    vis.draw_binary_mask(
        mask, color=_ROI_COLOR, edge_color=_ROI_COLOR, alpha=0.0, area_threshold=0
    )
    pos_idx = 0
    for key, value in labels.items():
        vis.draw_text(
            text=f"{key}: {value}",
            position=(310, 10 + pos_idx * 17),
            color=_ROI_COLOR,
            font_size=9,
            horizontal_alignment="left",
        )
        pos_idx += 1
    vis_obj = vis.get_output()
    vis_obj.save(os.path.join(output_dir, f"{anatomy}_report_axial.png"))


def normalize_img(img: np.ndarray) -> np.ndarray:
    """Normalize the image to [0, 1].

    Args:
        img (np.ndarray): Input image.
    Returns:
        np.ndarray: Normalized image; a constant image maps to all zeros.
    """
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        # BUG FIX: previously 0/0 -> NaN array for constant input, which
        # propagated NaNs into every rendered image. A flat image is black.
        return np.zeros_like(img, dtype=float)
    return (img - lo) / span
class InferenceClass:
    """Common base for all pipeline inference steps.

    Subclasses override ``__call__`` to perform one stage of the pipeline;
    ``__repr__`` reports the class name for pipeline logging.
    """

    def __init__(self):
        # No shared state; subclasses call super().__init__() by convention.
        pass

    def __call__(self) -> Dict:
        # Abstract in spirit: concrete steps must provide their own __call__.
        raise NotImplementedError

    def __repr__(self):
        return type(self).__name__
class InferencePipeline(InferenceClass):
    """Inference pipeline.

    Runs a list of InferenceClass steps in order, threading each step's
    output dict into the next step's keyword arguments, with itself (or a
    supplied pipeline) passed as shared mutable state.
    """

    def __init__(
        self, inference_classes: List = None, config: Dict = None, args: Dict = None
    ):
        self.config = config
        self.args = args
        # assign values from config to attributes
        if self.config is not None:
            for key, value in self.config.items():
                setattr(self, key, value)

        self.inference_classes = inference_classes

    def __call__(self, inference_pipeline=None, **kwargs):
        """Run all steps in order and return the final step's output dict.

        Args:
            inference_pipeline: optional pipeline object to use as shared
                state instead of ``self`` (used for nested pipelines).
            **kwargs: initial attributes set on the shared state object.
        """
        # print out the class names for each inference class
        print("")
        print("Inference pipeline:")
        for i, inference_class in enumerate(self.inference_classes):
            print(f"({i + 1}) {inference_class.__repr__()}")
        print("")

        print("Starting inference pipeline for:\n")

        # Seed the shared-state object with the caller-provided attributes.
        if inference_pipeline:
            for key, value in kwargs.items():
                setattr(inference_pipeline, key, value)
        else:
            for key, value in kwargs.items():
                setattr(self, key, value)

        output = {}
        for inference_class in self.inference_classes:
            # Each step's __call__ signature must match the previous step's
            # output keys (plus the implicit inference_pipeline parameter).
            function_keys = set(inspect.signature(inference_class).parameters.keys())
            function_keys.remove("inference_pipeline")

            if "kwargs" in function_keys:
                function_keys.remove("kwargs")

            # NOTE(review): assert is stripped under `python -O`; an explicit
            # raise would make this validation unconditional — confirm intent.
            assert function_keys == set(
                output.keys()
            ), "Input to inference class, {}, does not have the correct parameters".format(
                inference_class.__repr__()
            )

            print(
                "Running {} with input keys {}".format(
                    inference_class.__repr__(),
                    inspect.signature(inference_class).parameters.keys(),
                )
            )

            if inference_pipeline:
                output = inference_class(
                    inference_pipeline=inference_pipeline, **output
                )
            else:
                output = inference_class(inference_pipeline=self, **output)

            # if not the last inference class, check that the output keys are correct
            if inference_class != self.inference_classes[-1]:
                print(
                    "Finished {} with output keys {}\n".format(
                        inference_class.__repr__(), output.keys()
                    )
                )

        print("Inference pipeline finished.\n")

        return output

    def saveArrToNifti(self, arr, path):
        """
        Saves an array to nifti using the CT as reference

        Args:
            arr (ndarray): input array.
            path (str, Path): full save path.

        Returns:
            None.

        """
        # Reuses the loaded CT's affine and header so the saved volume
        # aligns with the source scan.
        img = nib.Nifti1Image(
            arr, self.medical_volume.affine, self.medical_volume.header
        )
        nib.save(img, path)
class DicomFinder(InferenceClass):
    """Recursively collect ``*.dcm`` files under a directory."""

    def __init__(self, input_path: Union[str, Path]):
        """
        Args:
            input_path (Union[str, Path]): Root directory to search.
        """
        # Fix: the original declared ``-> Dict[str, Path]`` on __init__,
        # which is invalid — __init__ always returns None.
        super().__init__()
        self.input_path = Path(input_path)

    def __call__(self, inference_pipeline) -> Dict[str, Path]:
        """Find dicom files in a directory.

        Args:
            inference_pipeline (InferencePipeline): Inference pipeline.

        Returns:
            Dict[str, Path]: Empty dict; the file list is stored on
            ``inference_pipeline.dicom_file_paths``.
        """
        # Sort for deterministic ordering across platforms/filesystems
        # (Path.glob order is otherwise unspecified).
        dicom_files = sorted(self.input_path.glob("**/*.dcm"))
        inference_pipeline.dicom_file_paths = dicom_files
        return {}
def series_selector(dicom_path, pipeline_name=None):
    """Validate that a DICOM file belongs to a usable axial series.

    Args:
        dicom_path: Path to a single DICOM file from the series.
        pipeline_name: Name of the calling pipeline; the "aaa" pipeline
            skips all checks.

    Returns:
        pydicom.Dataset: The parsed dataset.

    Raises:
        ValueError: If the image is not PRIMARY, not ORIGINAL, or not
            axially oriented (checks skipped for the "aaa" pipeline).
    """
    ds = pydicom.filereader.dcmread(dicom_path)
    image_type_list = list(ds.ImageType)
    if pipeline_name != "aaa":
        if not any("primary" in s.lower() for s in image_type_list):
            raise ValueError("Not primary image type")
        if not any("original" in s.lower() for s in image_type_list):
            raise ValueError("Not original image type")
        # ImageOrientationPatient holds decimal strings; exact equality with
        # the integer list rejects values like 0.99999 that arise from
        # scanner rounding. Compare with a small tolerance instead.
        iop = [float(v) for v in ds.ImageOrientationPatient]
        if len(iop) != 6 or any(
            abs(a - b) > 1e-3 for a, b in zip(iop, [1, 0, 0, 0, 1, 0])
        ):
            raise ValueError("Image orientation is not axial")
    else:
        print(
            f"Skipping primary, original, and orientation image type check for the {pipeline_name} pipeline."
        )
    return ds
def find_dicom_files(input_path):
    """Collect DICOM file paths under *input_path*.

    If *input_path* is not a directory it is treated as a single file and
    returned as-is (absolute path, no extension filtering). Otherwise the
    tree is walked and every ``.dcm``/``.dicom`` file path is returned.
    """
    if not os.path.isdir(input_path):
        # Single-file input: pass through without extension filtering.
        return [str(os.path.abspath(input_path))]
    return [
        os.path.join(root, name)
        for root, _, files in os.walk(input_path)
        for name in files
        if name.endswith((".dcm", ".dicom"))
    ]
def get_dicom_or_nifti_paths_and_num(path):
    """Get all paths under a path that contain only dicom files or a nifti file.

    Args:
        path (str): A NIfTI file, a ``.txt`` file listing one path per line,
            or a directory to walk.

    Returns:
        list: List of ``(path, num_slices_or_files)`` tuples.
    """
    dicom_nifti_paths = []

    if path.endswith(".nii") or path.endswith(".nii.gz"):
        dicom_nifti_paths.append((path, getNumSlicesNifti(path)))
    elif path.endswith(".txt"):
        with open(path, "r") as f:
            for dicom_folder_path in f:
                dicom_folder_path = dicom_folder_path.strip()
                # Fix: skip blank lines — a trailing newline previously
                # produced '' and crashed os.listdir('').
                if not dicom_folder_path:
                    continue
                if dicom_folder_path.endswith((".nii", ".nii.gz")):
                    dicom_nifti_paths.append(
                        (dicom_folder_path, getNumSlicesNifti(dicom_folder_path))
                    )
                else:
                    dicom_nifti_paths.append(
                        (dicom_folder_path, len(os.listdir(dicom_folder_path)))
                    )
    else:
        # Any directory that directly contains files is treated as a series;
        # the per-extension check was intentionally disabled upstream.
        for root, dirs, files in os.walk(path):
            if len(files) > 0:
                dicom_nifti_paths.append((root, len(files)))

    return dicom_nifti_paths
class LiverSpleenPancreasSegmentation(InferenceClass):
    """Multi-organ segmentation step driven by TotalSegmentator v2.

    Reads ``<output_dir>/segmentations/converted_dcm.nii.gz`` (produced by an
    earlier pipeline stage) and stores the resulting label volume on
    ``inference_pipeline.segmentation``.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Run segmentation using the pipeline's output/model directories.

        Returns:
            dict: Empty; the segmentation is attached to the pipeline.
        """
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        if not os.path.exists(self.output_dir_segmentations):
            os.makedirs(self.output_dir_segmentations)

        self.model_dir = inference_pipeline.model_dir

        seg = self.organ_seg(
            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
            self.output_dir_segmentations + "organs.nii.gz",
            inference_pipeline.model_dir,
        )

        inference_pipeline.segmentation = seg

        return {}

    def organ_seg(
        self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir
    ):
        """Run organ segmentation.

        Args:
            input_path (Union[str, Path]): Input path.
            output_path (Union[str, Path]): Output path.
            model_dir: Directory exported as ``SCRATCH`` for model weights.

        Returns:
            The segmentation object returned by ``totalsegmentator``
            (with ``skip_saving=True``, kept in memory rather than on disk).
        """

        print("Segmenting organs...")
        st = time()
        # TotalSegmentator reads the SCRATCH env var to locate/cache weights.
        os.environ["SCRATCH"] = self.model_dir

        # NOTE(review): device="gpu" is hard-coded — presumably a CUDA device
        # is required here; confirm before running on CPU-only hosts.
        seg = totalsegmentator(
            input=input_path,
            output=output_path,
            task_ids=[291],
            ml=True,
            nr_thr_resamp=1,
            nr_thr_saving=6,
            fast=False,
            nora_tag="None",
            preview=False,
            task="total",
            roi_subset=None,
            statistics=False,
            radiomics=False,
            crop_path=None,
            body_seg=False,
            force_split=False,
            output_type="nifti",
            quiet=False,
            verbose=False,
            test=0,
            skip_saving=True,
            device="gpu",
            license_number=None,
            statistics_exclude_masks_at_border=True,
            no_derived_masks=False,
            v1_order=False,
        )

        end = time()

        # Log total time for organ segmentation
        print(f"Total time for organ segmentation: {end-st:.2f}s.")

        return seg
class LiverSpleenPancreasMetricsPrinter(InferenceClass):
    """Print the per-organ metrics table and write it to a CSV file."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Render ``inference_pipeline.organ_metrics`` to stdout and to
        ``<output_dir>/metrics/liver_spleen_pancreas_metrics.csv``.

        Returns:
            dict: Empty; this stage only reports.
        """
        results = inference_pipeline.organ_metrics
        organs = list(results.keys())

        # Column width for the organ-name column: longest name plus padding.
        name_dist = max([len(o) for o in organs])
        metrics = []
        for k in results[list(results.keys())[0]].keys():
            if k != "Organ":
                metrics.append(k)

        # NOTE(review): units assume the metric order (Volume, Mean, Median)
        # produced upstream — confirm if metrics are ever added/reordered.
        units = ["cm^3", "HU", "HU"]

        header = (
            "{:<" + str(name_dist + 4) + "}" + ("{:<" + str(15) + "}") * len(metrics)
        )
        header = header.format(
            "Organ", *[m + "(" + u + ")" for m, u in zip(metrics, units)]
        )

        # Same layout as the header, but numeric values rounded to integers.
        base_print = (
            "{:<" + str(name_dist + 4) + "}" + ("{:<" + str(15) + ".0f}") * len(metrics)
        )

        print("\n")
        print(header)

        for organ in results.values():
            line = base_print.format(*organ.values())
            print(line)

        print("\n")

        output_dir = inference_pipeline.output_dir
        self.output_dir_metrics_organs = os.path.join(output_dir, "metrics/")

        if not os.path.exists(self.output_dir_metrics_organs):
            os.makedirs(self.output_dir_metrics_organs)

        # CSV keeps full precision (str(v)), unlike the rounded console table.
        header = (
            ",".join(["Organ"] + [m + "(" + u + ")" for m, u in zip(metrics, units)])
            + "\n"
        )
        with open(
            os.path.join(
                self.output_dir_metrics_organs, "liver_spleen_pancreas_metrics.csv"
            ),
            "w",
        ) as f:
            f.write(header)

            for organ in results.values():
                line = ",".join([str(v) for v in organ.values()]) + "\n"
                f.write(line)

        return {}
def extract_axial_mid_slice(ct, mask, crop=True):
    """Pick the axial slice that best represents the organ in ``mask``.

    Chooses the slice with the largest area among slices where the organ is a
    single connected component; if no such slice exists, falls back to the
    largest-area slice overall.

    Args:
        ct: 3D CT array (x, y, z).
        mask: Binary/boolean 3D organ mask, same shape as ``ct``.
        crop: Unused; kept for interface compatibility.

    Returns:
        tuple: (ct_slice, mask_slice, slice_index) — 2D slices transposed to
        (y, x) and flipped on both axes for display.
    """
    axial_extent = np.where(mask.sum(axis=(0, 1)))[0]

    max_extent = 0
    max_extent_idx = 0

    for idx in axial_extent:
        label, num_features = scipy.ndimage.label(mask[:, :, idx])
        if num_features > 1:
            # Skip slices where the organ splits into several components.
            continue
        extent = label.sum()
        if extent > max_extent:
            max_extent = extent
            max_extent_idx = idx

    # Fix: previously, if every slice was multi-component (or the mask empty),
    # index 0 was silently returned. Fall back to the largest-area slice.
    if max_extent == 0 and axial_extent.size > 0:
        max_extent_idx = int(np.argmax(mask.sum(axis=(0, 1))))

    ct_slice_z = np.transpose(ct[:, :, max_extent_idx], axes=(1, 0))
    mask_slice_z = np.transpose(mask[:, :, max_extent_idx], axes=(1, 0))

    ct_slice_z = np.flip(ct_slice_z, axis=(0, 1))
    mask_slice_z = np.flip(mask_slice_z, axis=(0, 1))

    return ct_slice_z, mask_slice_z, max_extent_idx
def save_slice(
    ct_slice,
    mask_slice,
    path,
    figsize=(12, 12),
    corner_text=None,
    unit_dict=None,
    aspect=1,
    show=False,
    xy_placement=None,
    class_color=1,
    fontsize=14,
):
    """Render a CT slice with a translucent segmentation overlay.

    Args:
        ct_slice: 2D CT array, displayed with a -400..400 HU gray window.
        mask_slice: 2D mask; nonzero pixels are tinted via ``class_color``.
        path: Output image path (used when ``show`` is False).
        figsize: Matplotlib figure size.
        corner_text: Optional dict of label -> value rendered in a corner box.
        unit_dict: Units appended to numeric ``corner_text`` values.
        aspect: Pixel aspect ratio for imshow.
        show: If True, display interactively instead of saving.
        xy_placement: Axes-fraction (x, y) of the text box; auto-placed
            near the top-right when None.
        class_color: Index (1-9) into the tab10-based overlay colormap.
        fontsize: Text box font size.
    """
    # colormap for shown segmentations: index 0 transparent, 1-10 = tab10
    color_array = plt.get_cmap("tab10")(range(10))
    color_array = np.concatenate((np.array([[0, 0, 0, 0]]), color_array[:, :]), axis=0)
    map_object_seg = ListedColormap(name="segmentation_cmap", colors=color_array)

    fig, axx = plt.subplots(1, figsize=figsize, frameon=False)
    axx.imshow(
        ct_slice,
        cmap="gray",
        vmin=-400,
        vmax=400,
        interpolation="spline36",
        aspect=aspect,
        origin="lower",
    )
    axx.imshow(
        mask_slice * class_color,
        cmap=map_object_seg,
        vmin=0,
        vmax=9,
        alpha=0.2,
        interpolation="nearest",
        aspect=aspect,
        origin="lower",
    )

    plt.axis("off")
    axx.axes.get_xaxis().set_visible(False)
    axx.axes.get_yaxis().set_visible(False)

    y_size, x_size = ct_slice.shape

    if corner_text is not None:
        bbox_props = dict(boxstyle="round", facecolor="gray", alpha=0.5)

        texts = []
        for k, v in corner_text.items():
            if isinstance(v, str):
                texts.append("{:<9}{}".format(k + ":", v))
            else:
                unit = unit_dict[k] if k in unit_dict else ""
                texts.append("{:<9}{:.0f} {}".format(k + ":", v, unit))

        if xy_placement is None:
            # Draw the box once at a dummy position to measure its width,
            # remove it, then place it flush with the top-right corner.
            t = axx.text(
                0.5,
                0.5,
                "\n".join(texts),
                color="white",
                transform=axx.transAxes,
                fontsize=fontsize,
                family="monospace",
                bbox=bbox_props,
                va="top",
                ha="left",
            )
            xmin, xmax = t.get_window_extent().xmin, t.get_window_extent().xmax
            # NOTE(review): transform() expects an (x, y) point but is given
            # (xmin, xmax); the x result is used for width, the y component
            # discarded — verify the placement math is intentional.
            xmin, xmax = axx.transAxes.inverted().transform((xmin, xmax))

            xy_placement = [1 - (xmax - xmin) - (xmax - xmin) * 0.09, 0.975]
            t.remove()

        axx.text(
            xy_placement[0],
            xy_placement[1],
            "\n".join(texts),
            color="white",
            transform=axx.transAxes,
            fontsize=fontsize,
            family="monospace",
            bbox=bbox_props,
            va="top",
            ha="left",
        )

    if show:
        plt.show()
    else:
        fig.savefig(path, bbox_inches="tight", pad_inches=0)
        plt.close(fig)
def slicedDilationOrErosion(input_mask, num_iteration, operation):
    """Dilate or erode a 3D mask inside its (padded) bounding box only.

    Restricting the morphology to the organ's bounding box avoids running the
    (expensive) scipy operation over the full volume.

    Args:
        input_mask: 3D binary mask array.
        num_iteration: Structure iteration count; None is treated as 1.
        operation: Either "dilate" or "erode".

    Returns:
        ndarray: Mask of the same shape with the operation applied.

    Raises:
        ValueError: If ``operation`` is not "dilate" or "erode".
    """
    # Empty masks pass through unchanged — nothing to dilate/erode.
    if input_mask.sum() == 0:
        return input_mask

    # Fix: num_iteration=None previously crashed iterate_structure even
    # though the margin computation explicitly allowed None.
    iterations = 1 if num_iteration is None else num_iteration
    margin = iterations + 1

    x_size, y_size, z_size = input_mask.shape

    # Minimum axis-aligned volume enclosing the organ, padded by `margin`
    # so a dilation cannot spill past the cropped region.
    x_idx = np.where(input_mask.sum(axis=(1, 2)))[0]
    x_start, x_end = max(0, x_idx[0] - margin), min(x_idx[-1] + margin, x_size)

    y_idx = np.where(input_mask.sum(axis=(0, 2)))[0]
    y_start, y_end = max(0, y_idx[0] - margin), min(y_idx[-1] + margin, y_size)

    z_idx = np.where(input_mask.sum(axis=(0, 1)))[0]
    z_start, z_end = max(0, z_idx[0] - margin), min(z_idx[-1] + margin, z_size)

    struct = scipy.ndimage.generate_binary_structure(3, 1)
    struct = scipy.ndimage.iterate_structure(struct, iterations)

    region = input_mask[x_start:x_end, y_start:y_end, z_start:z_end]
    if operation == "dilate":
        mask_slice = scipy.ndimage.binary_dilation(region, structure=struct).astype(
            np.int8
        )
    elif operation == "erode":
        mask_slice = scipy.ndimage.binary_erosion(region, structure=struct).astype(
            np.int8
        )
    else:
        # Fix: an unknown operation previously raised NameError on mask_slice.
        raise ValueError(f"Unknown operation: {operation!r}")

    output_mask = input_mask.copy()
    output_mask[x_start:x_end, y_start:y_end, z_start:z_end] = mask_slice

    return output_mask
def extract_organ_metrics(
    ct, all_masks, class_num=None, vol_per_pixel=None, erode_mask=True
):
    """Compute volume and HU statistics for one organ label.

    Args:
        ct: 3D CT array.
        all_masks: 3D integer label volume.
        class_num: Label value of the organ to measure.
        vol_per_pixel: Volume per voxel in ml.
        erode_mask: If True, erode the mask before sampling HU values so
            boundary/partial-volume voxels are excluded.

    Returns:
        dict: Organ name plus Volume (ml), Mean and Median HU.
    """
    organ_mask = all_masks == class_num

    if erode_mask:
        eroded_mask = slicedDilationOrErosion(
            input_mask=organ_mask, num_iteration=3, operation="erode"
        )
        ct_organ_vals = ct[eroded_mask == 1]
        # Fix: a small organ can erode away entirely, which previously made
        # mean()/median() return NaN; fall back to the uneroded voxels.
        if ct_organ_vals.size == 0:
            ct_organ_vals = ct[organ_mask]
    else:
        ct_organ_vals = ct[organ_mask]

    # Volume is computed on the *uneroded* mask, in ml.
    organ_vol = organ_mask.sum() * vol_per_pixel

    return {
        "Organ": class_map_part_organs[class_num],
        "Volume": organ_vol,
        "Mean": ct_organ_vals.mean(),
        "Median": np.median(ct_organ_vals),
    }
def generate_liver_spleen_pancreas_report(root, organ_names):
    """Tile per-organ axial (top row) and coronal (bottom row) images into
    one report PNG saved as ``liver_spleen_pancreas_report.png`` in ``root``.
    """
    axial_imgs = [
        Image.open(os.path.join(root, organ + "_axial.png")) for organ in organ_names
    ]
    coronal_imgs = [
        Image.open(os.path.join(root, organ + "_coronal.png")) for organ in organ_names
    ]

    # Canvas wide enough for the wider row, tall enough for the tallest
    # axial+coronal column.
    canvas_w = max(
        sum(img.size[0] for img in axial_imgs),
        sum(img.size[0] for img in coronal_imgs),
    )
    canvas_h = max(a.size[1] + c.size[1] for a, c in zip(axial_imgs, coronal_imgs))

    result = Image.new("RGB", (canvas_w, canvas_h))

    x_offset = 0
    for top_img, bottom_img in zip(axial_imgs, coronal_imgs):
        top_w, top_h = top_img.size
        bottom_w, _ = bottom_img.size

        # Center the narrower coronal image under its axial counterpart.
        shift = (top_w - bottom_w) // 2 if top_w > bottom_w else 0

        result.paste(im=top_img, box=(x_offset, 0))
        result.paste(im=bottom_img, box=(x_offset + shift, top_h))

        x_offset += top_w

    result.save(os.path.join(root, "liver_spleen_pancreas_report.png"))
"lung_upper_lobe_right", 382 | 13: "lung_middle_lobe_right", 383 | 14: "lung_lower_lobe_right", 384 | 15: "esophagus", 385 | 16: "trachea", 386 | 17: "thyroid_gland", 387 | 18: "small_bowel", 388 | 19: "duodenum", 389 | 20: "colon", 390 | 21: "urinary_bladder", 391 | 22: "prostate", 392 | 23: "kidney_cyst_left", 393 | 24: "kidney_cyst_right", 394 | 25: "sacrum", 395 | 26: "vertebrae_S1", 396 | 27: "vertebrae_L5", 397 | 28: "vertebrae_L4", 398 | 29: "vertebrae_L3", 399 | 30: "vertebrae_L2", 400 | 31: "vertebrae_L1", 401 | 32: "vertebrae_T12", 402 | 33: "vertebrae_T11", 403 | 34: "vertebrae_T10", 404 | 35: "vertebrae_T9", 405 | 36: "vertebrae_T8", 406 | 37: "vertebrae_T7", 407 | 38: "vertebrae_T6", 408 | 39: "vertebrae_T5", 409 | 40: "vertebrae_T4", 410 | 41: "vertebrae_T3", 411 | 42: "vertebrae_T2", 412 | 43: "vertebrae_T1", 413 | 44: "vertebrae_C7", 414 | 45: "vertebrae_C6", 415 | 46: "vertebrae_C5", 416 | 47: "vertebrae_C4", 417 | 48: "vertebrae_C3", 418 | 49: "vertebrae_C2", 419 | 50: "vertebrae_C1", 420 | 51: "heart", 421 | 52: "aorta", 422 | 53: "pulmonary_vein", 423 | 54: "brachiocephalic_trunk", 424 | 55: "subclavian_artery_right", 425 | 56: "subclavian_artery_left", 426 | 57: "common_carotid_artery_right", 427 | 58: "common_carotid_artery_left", 428 | 59: "brachiocephalic_vein_left", 429 | 60: "brachiocephalic_vein_right", 430 | 61: "atrial_appendage_left", 431 | 62: "superior_vena_cava", 432 | 63: "inferior_vena_cava", 433 | 64: "portal_vein_and_splenic_vein", 434 | 65: "iliac_artery_left", 435 | 66: "iliac_artery_right", 436 | 67: "iliac_vena_left", 437 | 68: "iliac_vena_right", 438 | 69: "humerus_left", 439 | 70: "humerus_right", 440 | 71: "scapula_left", 441 | 72: "scapula_right", 442 | 73: "clavicula_left", 443 | 74: "clavicula_right", 444 | 75: "femur_left", 445 | 76: "femur_right", 446 | 77: "hip_left", 447 | 78: "hip_right", 448 | 79: "spinal_cord", 449 | 80: "gluteus_maximus_left", 450 | 81: "gluteus_maximus_right", 451 | 82: "gluteus_medius_left", 
def flatten_non_category_dims(
    xs: Union[np.ndarray, Sequence[np.ndarray]], category_dim: int = None
):
    """Collapse every non-category axis of each array into a single axis.

    Args:
        xs: A single ndarray or a sequence of ndarrays sharing the same
            category dimension.
        category_dim: Axis holding the categories (`C`). If None, each array
            is simply flattened.

    Returns:
        For a single array: an ndarray of shape (C, -1) if ``category_dim``
        is given, else shape (-1,). For a sequence: a generator yielding the
        reshaped arrays.
    """
    is_single = isinstance(xs, np.ndarray)
    arrays = [xs] if is_single else xs

    if category_dim is None:
        reshaped = (arr.flatten() for arr in arrays)
    else:
        target_shape = (arrays[0].shape[category_dim], -1)
        reshaped = (
            np.moveaxis(arr, category_dim, 0).reshape(target_shape) for arr in arrays
        )

    if is_single:
        return next(iter(reshaped))
    return reshaped
class HounsfieldUnits(Metric):
    """Mean Hounsfield unit of ``x`` over a (possibly multi-class) mask."""

    FULL_NAME = "Hounsfield Unit"

    def __init__(self, units="hu"):
        super().__init__(units)

    def __call__(self, mask, x, category_dim: int = None):
        """Return mean HU over the mask; per-class means if ``category_dim``
        is the last axis.

        Args:
            mask: Binary mask (optionally with a trailing class axis).
            x: Image array aligned with ``mask``.
            category_dim: Must be -1 (classes last) or None.
        """
        # Fix: `np.bool` was removed in NumPy 1.24; use the builtin `bool`.
        mask = mask.astype(bool)
        if category_dim is None:
            return np.mean(x[mask])

        assert category_dim == -1
        num_classes = mask.shape[-1]

        return np.array([np.mean(x[mask[..., c]]) for c in range(num_classes)])

    def name(self):
        return self.FULL_NAME


class CrossSectionalArea(Metric):
    """Cross-sectional area of a mask in cm^2 (given spacing in mm)."""

    def __call__(self, mask, spacing=None, category_dim: int = None):
        """Count nonzero mask pixels times pixel area, converted mm^2 -> cm^2.

        Args:
            mask: Binary mask (optionally with a class axis).
            spacing: Per-axis pixel spacing; pixel area defaults to 1.
            category_dim: Class axis, or None for a single total.
        """
        pixel_area = np.prod(spacing) if spacing else 1
        # Fix: `np.bool` was removed in NumPy 1.24; use the builtin `bool`.
        mask = mask.astype(bool)
        mask = flatten_non_category_dims(mask, category_dim)

        return pixel_area * np.count_nonzero(mask, -1) / 100.0

    def name(self):
        if self.units:
            return "Cross-sectional Area ({})".format(self.units)
        else:
            return "Cross-sectional Area"
class Models(enum.Enum):
    """Registry of segmentation models and their output layouts.

    Each member is a tuple of:
        value (int): Unique enum value.
        model_name (str): Identifier used for weight-file names and lookups.
        categories (Dict[str, int]): Category name -> output channel index.
        use_softmax (bool): Whether the model output is softmax-activated.
        windows (Sequence[str]): CT windowing presets the model expects.
    """

    ABCT_V_0_0_1 = (
        1,
        "abCT_v0.0.1",
        {"muscle": 0, "imat": 1, "vat": 2, "sat": 3},
        False,
        ("soft", "bone", "custom"),
    )

    STANFORD_V_0_0_1 = (
        2,
        "stanford_v0.0.1",
        # ("background", "muscle", "bone", "vat", "sat", "imat"),
        # Category name mapped to channel index
        {"muscle": 1, "vat": 3, "sat": 4, "imat": 5},
        True,
        ("soft", "bone", "custom"),
    )

    STANFORD_V_0_0_2 = (
        3,
        "stanford_v0.0.2",
        {"muscle": 4, "sat": 1, "vat": 2, "imat": 3},
        True,
        ("soft", "bone", "custom"),
    )
    TS_SPINE_FULL = (
        4,
        "ts_spine_full",
        # Category name mapped to channel index
        {
            "L5": 18,
            "L4": 19,
            "L3": 20,
            "L2": 21,
            "L1": 22,
            "T12": 23,
            "T11": 24,
            "T10": 25,
            "T9": 26,
            "T8": 27,
            "T7": 28,
            "T6": 29,
            "T5": 30,
            "T4": 31,
            "T3": 32,
            "T2": 33,
            "T1": 34,
            "C7": 35,
            "C6": 36,
            "C5": 37,
            "C4": 38,
            "C3": 39,
            "C2": 40,
            "C1": 41,
        },
        False,
        (),
    )
    TS_SPINE = (
        5,
        "ts_spine",
        # Category name mapped to channel index
        # {"L5": 18, "L4": 19, "L3": 20, "L2": 21, "L1": 22, "T12": 23},
        {"L5": 27, "L4": 28, "L3": 29, "L2": 30, "L1": 31, "T12": 32},
        False,
        (),
    )
    STANFORD_SPINE_V_0_0_1 = (
        6,
        "stanford_spine_v0.0.1",
        # Category name mapped to channel index
        {"L5": 24, "L4": 23, "L3": 22, "L2": 21, "L1": 20, "T12": 19},
        False,
        (),
    )
    TS_HIP = (
        7,
        "ts_hip",
        # Category name mapped to channel index
        {"femur_left": 88, "femur_right": 89},
        False,
        (),
    )

    def __new__(
        cls,
        value: int,
        model_name: str,
        categories: Dict[str, int],
        use_softmax: bool,
        windows: Sequence[str],
    ):
        """Unpack the member tuple onto named attributes."""
        obj = object.__new__(cls)
        obj._value_ = value

        obj.model_name = model_name
        obj.categories = categories
        obj.use_softmax = use_softmax
        obj.windows = windows
        return obj

    def load_model(self, model_dir):
        """Load the keras model, downloading the weights if missing.

        Args:
            model_dir (str): Directory containing (or to receive) the
                weights file.

        Returns:
            keras.models.Model: Loaded model.
        """
        try:
            filename = Models.find_model_weights(self.model_name, model_dir)
        except Exception:
            # Weights not present locally -- fetch them once and retry.
            print("Downloading muscle/fat model from hugging face")
            Path(model_dir).mkdir(parents=True, exist_ok=True)
            wget.download(
                f"https://huggingface.co/stanfordmimi/stanford_abct_v0.0.1/resolve/main/{self.model_name}.h5",
                out=os.path.join(model_dir, f"{self.model_name}.h5"),
            )
            filename = Models.find_model_weights(self.model_name, model_dir)
            # Newline after wget's progress bar.
            print("")

        print("Loading muscle/fat model from {}".format(filename))
        return load_model(filename)

    @staticmethod
    def model_from_name(model_name):
        """Get the model enum from the model name.

        Args:
            model_name (str): Model name.

        Returns:
            Models: Matching enum member, or None if no member matches.
        """
        for model in Models:
            if model.model_name == model_name:
                return model
        return None

    @staticmethod
    def find_model_weights(file_name, model_dir):
        """Recursively search `model_dir` for a file starting with `file_name`.

        Args:
            file_name (str): Weight-file prefix (typically the model name).
            model_dir (str): Root directory to walk.

        Returns:
            str: Path to the first matching file.

        Raises:
            FileNotFoundError: If no matching file exists under `model_dir`.
        """
        for root, _, files in os.walk(model_dir):
            for file in files:
                # NOTE: prefix match, so "ts_spine" would also match
                # "ts_spine_full*" files if both were present in model_dir.
                if file.startswith(file_name):
                    return os.path.join(root, file)
        # The original implementation fell off the end and returned None,
        # which surfaced later as a confusing keras error; fail loudly here
        # instead (still caught by load_model's download-and-retry path).
        raise FileNotFoundError(
            "No weights file starting with '{}' found in {}".format(
                file_name, model_dir
            )
        )
def _window(xs, bounds):
    """Apply windowing to an array of CT images.

    Args:
        xs (ndarray): NxHxW (or NxHxWx1) array of CT images.
        bounds (Sequence[Tuple[float, float]]): One (lower, upper) clip pair
            per output window, as produced by `parse_windows`.

    Returns:
        ndarray: Windowed images with one channel per window.
    """

    imgs = []
    for lb, ub in bounds:
        imgs.append(np.clip(xs, a_min=lb, a_max=ub))

    if len(imgs) == 1:
        return imgs[0]
    elif xs.shape[-1] == 1:
        # Input already has a trailing singleton channel; concatenate into it.
        return np.concatenate(imgs, axis=-1)
    else:
        # Otherwise create a new trailing channel axis.
        return np.stack(imgs, axis=-1)


class Dataset(k_utils.Sequence):
    """Keras sequence yielding batches of windowed DICOM slices.

    Attributes:
        _files: Paths to the DICOM files, one slice each.
        _batch_size: Number of slices per batch.
        windows: Optional window presets forwarded to `parse_windows`.
    """

    def __init__(self, files: List[str], batch_size: int = 16, windows=None):
        self._files = files
        self._batch_size = batch_size
        self.windows = windows

    def __len__(self):
        # Number of batches (ceiling division so the remainder forms a batch).
        return math.ceil(len(self._files) / self._batch_size)

    def __getitem__(self, idx):
        files = self._files[idx * self._batch_size : (idx + 1) * self._batch_size]
        # `pydicom.read_file` was deprecated and removed in pydicom 3.0;
        # `pydicom.dcmread` is the supported equivalent.
        dcms = [pydicom.dcmread(f, force=True) for f in files]

        # Shift stored pixel values by the rescale intercept to approximate HU.
        # NOTE(review): RescaleSlope is assumed to be 1 here -- confirm for
        # the scanners this pipeline supports.
        xs = [(x.pixel_array + int(x.RescaleIntercept)).astype("float32") for x in dcms]

        params = [
            {"spacing": header.PixelSpacing, "image": x} for header, x in zip(dcms, xs)
        ]

        # Preprocess xs via windowing.
        xs = np.stack(xs, axis=0)
        if self.windows:
            xs = _window(xs, parse_windows(self.windows))
        else:
            xs = xs[..., np.newaxis]

        return xs, params
def postprocess(xs: np.ndarray, ys: np.ndarray):
    """Built-in post-processing: append an all-zero channel to the masks.

    TODO: Make this configurable.

    Args:
        xs (ndarray): NxHxW input images. Currently unused; kept for the
            disabled muscle/imat swap below.
        ys (ndarray): NxHxWxC predicted masks.

    Returns:
        ndarray: NxHxWx(C+1) labels with an extra all-zero channel.
    """

    # Add another channel full of zeros to ys
    ys = np.concatenate([ys, np.zeros_like(ys[..., :1])], axis=-1)

    # If muscle hu is < -30, assume it is imat.
    # NOTE: intentionally disabled; re-enabling requires passing the
    # category->channel mapping into this function.

    """
    if "muscle" in categories and "imat" in categories:
        ys = _swap_muscle_imap(
            xs,
            ys,
            muscle_idx=categories["muscle"],
            imat_idx=categories["imat"],
        )
    """

    return ys


def predict(
    model,
    dataset: "Dataset",
    batch_size: int = 16,
    num_workers: int = 1,
    max_queue_size: int = 10,
    use_multiprocessing: bool = False,
):
    """Predict segmentation masks for a dataset.

    Args:
        model (keras.Model): Model to use for prediction.
        dataset (Dataset): Dataset to predict on.
        batch_size (int): Batch size.
        num_workers (int): Number of workers; 0 iterates the dataset inline.
        max_queue_size (int): Maximum queue size for the enqueuer.
        use_multiprocessing (bool): Use multiprocessing workers.

    Returns:
        Tuple[List, List, List]: Per-slice images, post-processed masks,
        and per-slice parameter dicts.
    """

    enqueuer = None
    if num_workers > 0:
        enqueuer = OrderedEnqueuer(
            dataset, use_multiprocessing=use_multiprocessing, shuffle=False
        )
        enqueuer.start(workers=num_workers, max_queue_size=max_queue_size)
        output_generator = enqueuer.get()
    else:
        output_generator = iter(dataset)

    # len(dataset) is the number of *batches* the sequence yields.
    num_batches = len(dataset)
    xs = []
    ys = []
    params = []
    try:
        for _ in tqdm(range(num_batches)):
            x, p_dicts = next(output_generator)
            y = model.predict(x, batch_size=batch_size)

            image = np.stack([out["image"] for out in p_dicts], axis=0)
            y = postprocess(image, y)

            params.extend(p_dicts)
            xs.extend([x[i, ...] for i in range(len(x))])
            ys.extend([y[i, ...] for i in range(len(y))])
    finally:
        # The original never stopped the enqueuer, leaking its worker
        # threads/processes; always shut it down, even on error.
        if enqueuer is not None:
            enqueuer.stop()

    return xs, ys, params
self._MUSCLE_FAT_TEXT_VERTICAL_SPACING = 14.0 40 | self._MUSCLE_FAT_TEXT_OFFSET_FROM_TOP = 22.0 41 | self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT = 181.0 42 | 43 | def __call__(self, inference_pipeline, images, results): 44 | self.output_dir = inference_pipeline.output_dir 45 | self.dicom_file_names = inference_pipeline.dicom_file_names 46 | # if spine is an attribute of the inference pipeline, use it 47 | if not hasattr(inference_pipeline, "spine"): 48 | spine = False 49 | else: 50 | spine = True 51 | self.spine_masks = inference_pipeline.spine_masks 52 | 53 | for i, (image, result) in enumerate(zip(images, results)): 54 | # now, result is a dict with keys for each tissue 55 | dicom_file_name = self.dicom_file_names[i] 56 | self.save_binary_segmentation_overlay(image, result, dicom_file_name, spine) 57 | # pass along for next class in pipeline 58 | return {"results": results} 59 | 60 | def save_binary_segmentation_overlay(self, image, result, dicom_file_name, spine): 61 | file_name = dicom_file_name + ".png" 62 | img_in = image 63 | assert img_in.shape == (512, 512), "Image shape is not 512 x 512" 64 | 65 | img_in = np.clip(img_in, -300, 1800) 66 | img_in = self.normalize_img(img_in) * 255.0 67 | 68 | # Create the folder to save the images 69 | images_base_path = Path(self.output_dir) / "images" 70 | images_base_path.mkdir(exist_ok=True) 71 | 72 | text_start_vertical_offset = self._MUSCLE_FAT_TEXT_OFFSET_FROM_TOP 73 | 74 | img_in = img_in.reshape((img_in.shape[0], img_in.shape[1], 1)) 75 | img_rgb = np.tile(img_in, (1, 1, 3)) 76 | 77 | vis = Visualizer(img_rgb) 78 | vis.draw_text( 79 | text="Density (HU)", 80 | position=( 81 | img_in.shape[1] - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT - 63, 82 | text_start_vertical_offset, 83 | ), 84 | color=[1, 1, 1], 85 | font_size=9, 86 | horizontal_alignment="left", 87 | ) 88 | vis.draw_text( 89 | text="Area (CM²)", 90 | position=( 91 | img_in.shape[1] - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT - 63, 92 | text_start_vertical_offset + 
self._MUSCLE_FAT_TEXT_VERTICAL_SPACING, 93 | ), 94 | color=[1, 1, 1], 95 | font_size=9, 96 | horizontal_alignment="left", 97 | ) 98 | 99 | if spine: 100 | spine_color = np.array(self._spine_colors[dicom_file_name]) / 255.0 101 | vis.draw_box( 102 | box_coord=(1, 1, img_in.shape[0] - 1, img_in.shape[1] - 1), 103 | alpha=1, 104 | edge_color=spine_color, 105 | ) 106 | # draw the level T12 - L5 in the upper left corner 107 | if dicom_file_name == "T12": 108 | position = (40, 15) 109 | else: 110 | position = (30, 15) 111 | vis.draw_text( 112 | text=dicom_file_name, position=position, color=spine_color, font_size=24 113 | ) 114 | vis.draw_binary_mask( 115 | self.spine_masks[dicom_file_name], 116 | color=spine_color, 117 | alpha=0.9, 118 | area_threshold=0, 119 | ) 120 | 121 | for idx, tissue in enumerate(result.keys()): 122 | alpha_val = 0.9 123 | color = np.array(self._muscle_fat_colors[tissue]) / 255.0 124 | edge_color = color 125 | mask = result[tissue]["mask"] 126 | 127 | vis.draw_binary_mask( 128 | mask, 129 | color=color, 130 | edge_color=edge_color, 131 | alpha=alpha_val, 132 | area_threshold=0, 133 | ) 134 | 135 | hu_val = round(result[tissue]["Hounsfield Unit"]) 136 | area_val = round(result[tissue]["Cross-sectional Area (cm^2)"]) 137 | 138 | vis.draw_text( 139 | text=tissue, 140 | position=( 141 | mask.shape[1] 142 | - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT 143 | + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING * (idx + 1), 144 | text_start_vertical_offset - self._MUSCLE_FAT_TEXT_VERTICAL_SPACING, 145 | ), 146 | color=color, 147 | font_size=9, 148 | horizontal_alignment="center", 149 | ) 150 | 151 | vis.draw_text( 152 | text=hu_val, 153 | position=( 154 | mask.shape[1] 155 | - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT 156 | + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING * (idx + 1), 157 | text_start_vertical_offset, 158 | ), 159 | color=color, 160 | font_size=9, 161 | horizontal_alignment="center", 162 | ) 163 | vis.draw_text( 164 | text=area_val, 165 | position=( 166 | 
mask.shape[1] 167 | - self._MUSCLE_FAT_TEXT_OFFSET_FROM_RIGHT 168 | + self._MUSCLE_FAT_TEXT_HORIZONTAL_SPACING * (idx + 1), 169 | text_start_vertical_offset + self._MUSCLE_FAT_TEXT_VERTICAL_SPACING, 170 | ), 171 | color=color, 172 | font_size=9, 173 | horizontal_alignment="center", 174 | ) 175 | 176 | vis_obj = vis.get_output() 177 | vis_obj.save(os.path.join(images_base_path, file_name)) 178 | 179 | def normalize_img(self, img: np.ndarray) -> np.ndarray: 180 | """Normalize the image. 181 | 182 | Args: 183 | img (np.ndarray): Input image. 184 | 185 | Returns: 186 | np.ndarray: Normalized image. 187 | """ 188 | return (img - img.min()) / (img.max() - img.min()) 189 | -------------------------------------------------------------------------------- /comp2comp/spine/spine_visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: louisblankemeier 3 | """ 4 | 5 | import os 6 | from pathlib import Path 7 | from typing import Union 8 | 9 | import numpy as np 10 | 11 | from comp2comp.visualization.detectron_visualizer import Visualizer 12 | 13 | 14 | def spine_binary_segmentation_overlay( 15 | img_in: Union[str, Path], 16 | mask: Union[str, Path], 17 | base_path: Union[str, Path], 18 | file_name: str, 19 | figure_text_key=None, 20 | spine_hus=None, 21 | seg_hus=None, 22 | spine=True, 23 | model_type=None, 24 | pixel_spacing=None, 25 | ): 26 | """Save binary segmentation overlay. 27 | Args: 28 | img_in (Union[str, Path]): Path to the input image. 29 | mask (Union[str, Path]): Path to the mask. 30 | base_path (Union[str, Path]): Path to the output directory. 31 | file_name (str): Output file name. 32 | centroids (list, optional): List of centroids. Defaults to None. 33 | figure_text_key (dict, optional): Figure text key. Defaults to None. 34 | spine_hus (list, optional): List of HU values. Defaults to None. 35 | spine (bool, optional): Spine flag. Defaults to True. 36 | model_type (Models): Model type. 
Defaults to None. 37 | """ 38 | _COLORS = ( 39 | np.array( 40 | [ 41 | 1.000, 42 | 0.000, 43 | 0.000, 44 | 0.000, 45 | 1.000, 46 | 0.000, 47 | 1.000, 48 | 1.000, 49 | 0.000, 50 | 1.000, 51 | 0.500, 52 | 0.000, 53 | 0.000, 54 | 1.000, 55 | 1.000, 56 | 1.000, 57 | 0.000, 58 | 1.000, 59 | ] 60 | ) 61 | .astype(np.float32) 62 | .reshape(-1, 3) 63 | ) 64 | 65 | label_map = {"L5": 0, "L4": 1, "L3": 2, "L2": 3, "L1": 4, "T12": 5} 66 | 67 | _ROI_COLOR = np.array([1.000, 0.340, 0.200]) 68 | 69 | _SPINE_TEXT_OFFSET_FROM_TOP = 10.0 70 | _SPINE_TEXT_OFFSET_FROM_RIGHT = 40.0 71 | _SPINE_TEXT_VERTICAL_SPACING = 14.0 72 | 73 | img_in = np.clip(img_in, -300, 1800) 74 | img_in = normalize_img(img_in) * 255.0 75 | images_base_path = Path(base_path) / "images" 76 | images_base_path.mkdir(exist_ok=True) 77 | 78 | img_in = img_in.reshape((img_in.shape[0], img_in.shape[1], 1)) 79 | img_rgb = np.tile(img_in, (1, 1, 3)) 80 | 81 | vis = Visualizer(img_rgb) 82 | 83 | levels = list(spine_hus.keys()) 84 | levels.reverse() 85 | num_levels = len(levels) 86 | 87 | # draw seg masks 88 | for i, level in enumerate(levels): 89 | color = _COLORS[label_map[level]] 90 | edge_color = None 91 | alpha_val = 0.2 92 | vis.draw_binary_mask( 93 | mask[:, :, i].astype(int), 94 | color=color, 95 | edge_color=edge_color, 96 | alpha=alpha_val, 97 | area_threshold=0, 98 | ) 99 | 100 | # draw rois 101 | for i, _ in enumerate(levels): 102 | color = _ROI_COLOR 103 | edge_color = color 104 | vis.draw_binary_mask( 105 | mask[:, :, num_levels + i].astype(int), 106 | color=color, 107 | edge_color=edge_color, 108 | alpha=alpha_val, 109 | area_threshold=0, 110 | ) 111 | 112 | vis.draw_text( 113 | text="ROI", 114 | position=( 115 | mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT - 35, 116 | _SPINE_TEXT_OFFSET_FROM_TOP, 117 | ), 118 | color=[1, 1, 1], 119 | font_size=9, 120 | horizontal_alignment="center", 121 | ) 122 | 123 | vis.draw_text( 124 | text="Seg", 125 | position=( 126 | mask.shape[1] - _SPINE_TEXT_OFFSET_FROM_RIGHT, 
def normalize_img(img: np.ndarray) -> np.ndarray:
    """Normalize the image to the [0, 1] range via min-max scaling.

    Args:
        img (np.ndarray): Input image.

    Returns:
        np.ndarray: Normalized image. A constant-valued input maps to all
            zeros instead of producing a 0/0 NaN array.
    """
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        # Guard: a constant image would otherwise divide by zero.
        return np.zeros_like(img, dtype=float)
    return (img - lo) / span
def colormap(rgb=False, maximum=255):
    """
    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1
    Returns:
        ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1]
    """
    assert maximum in [255, 1], maximum
    scaled = _COLORS * maximum
    # _COLORS is stored as RGB; reverse the channel axis for BGR output.
    return scaled if rgb else scaled[:, ::-1]


def random_color(rgb=False, maximum=255):
    """
    Args:
        rgb (bool): whether to return RGB colors or BGR colors.
        maximum (int): either 255 or 1
    Returns:
        ndarray: a vector of 3 numbers
    """
    choice = np.random.randint(0, len(_COLORS))
    picked = _COLORS[choice] * maximum
    # Same RGB->BGR reversal as colormap(), applied to a single color.
    return picked if rgb else picked[::-1]
def get_available_gpus(num_gpus: int = None):
    """Get gpu ids for gpus that are >95% free.

    Tensorflow does not support checking free memory on gpus.
    This is a crude method that relies on `nvidia-smi` to
    determine which gpus are occupied and which are free.

    Args:
        num_gpus: Number of requested gpus. If not specified,
            ids of all available gpu(s) are returned.

    Returns:
        List[int]: List of gpu ids that are free. Length
            will equal `num_gpus`, if specified. ``[-1]`` is returned
            when ``num_gpus == 0``; ``None`` is returned when
            ``nvidia-smi`` cannot be run.

    Raises:
        ValueError: If more gpus are requested than are currently free.
    """
    # Built-in tensorflow gpu id.
    assert isinstance(num_gpus, (type(None), int))
    if num_gpus == 0:
        return [-1]

    # Remember the caller's request; `num_gpus` is reused below to hold the
    # *detected* gpu count.
    num_requested_gpus = num_gpus
    try:
        # `nvidia-smi --list-gpus` prints one line per gpu; the split leaves a
        # trailing empty string, hence the -1.
        num_gpus = (
            len(
                subprocess.check_output("nvidia-smi --list-gpus", shell=True)
                .decode()
                .split("\n")
            )
            - 1
        )

        out_str = subprocess.check_output("nvidia-smi | grep MiB", shell=True).decode()
    except subprocess.CalledProcessError:
        # nvidia-smi is missing or failed: report "no gpu info available".
        return None
    # Tokens look like "123MiB"; strip the 3-char suffix to get the number.
    mem_str = [x for x in out_str.split() if "MiB" in x]
    # First 2 * num_gpu elements correspond to memory for gpus
    # Order: (occupied-0, total-0, occupied-1, total-1, ...)
    # NOTE(review): this assumes the first 2*num_gpus "MiB" tokens are
    # (used, total) pairs in gpu order -- confirm against the nvidia-smi
    # table layout for the driver versions in use.
    mems = [float(x[:-3]) for x in mem_str]
    gpu_percent_occupied_mem = [
        mems[2 * gpu_id] / mems[2 * gpu_id + 1] for gpu_id in range(num_gpus)
    ]

    # A gpu counts as "free" when less than 5% of its memory is in use.
    available_gpus = [
        gpu_id for gpu_id, mem in enumerate(gpu_percent_occupied_mem) if mem < 0.05
    ]
    if num_requested_gpus and num_requested_gpus > len(available_gpus):
        raise ValueError(
            "Requested {} gpus, only {} are free".format(
                num_requested_gpus, len(available_gpus)
            )
        )

    return available_gpus[:num_requested_gpus] if num_requested_gpus else available_gpus
75 | """ 76 | # return Model.__getattribute__(self, attrname) 77 | if "load" in attrname or "save" in attrname: 78 | return getattr(self._smodel, attrname) 79 | 80 | return super(ModelMGPU, self).__getattribute__(attrname) 81 | -------------------------------------------------------------------------------- /comp2comp/utils/env.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import importlib.util 3 | import os 4 | import sys 5 | 6 | __all__ = [] 7 | 8 | 9 | # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path # noqa 10 | def _import_file(module_name, file_path, make_importable=False): 11 | spec = importlib.util.spec_from_file_location(module_name, file_path) 12 | module = importlib.util.module_from_spec(spec) 13 | spec.loader.exec_module(module) 14 | if make_importable: 15 | sys.modules[module_name] = module 16 | return module 17 | 18 | 19 | def _configure_libraries(): 20 | """ 21 | Configurations for some libraries. 22 | """ 23 | # An environment option to disable `import cv2` globally, 24 | # in case it leads to negative performance impact 25 | disable_cv2 = int(os.environ.get("MEDSEGPY_DISABLE_CV2", False)) 26 | if disable_cv2: 27 | sys.modules["cv2"] = None 28 | else: 29 | # Disable opencl in opencv since its interaction with cuda often 30 | # has negative effects 31 | # This envvar is supported after OpenCV 3.4.0 32 | os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled" 33 | try: 34 | import cv2 35 | 36 | if int(cv2.__version__.split(".")[0]) >= 3: 37 | cv2.ocl.setUseOpenCL(False) 38 | except ImportError: 39 | pass 40 | 41 | 42 | _ENV_SETUP_DONE = False 43 | 44 | 45 | def setup_environment(): 46 | """Perform environment setup work. 
def setup_custom_environment(custom_module):
    """
    Load custom environment setup by importing a Python source file or a
    module, and run its ``setup_environment`` hook.
    """
    is_source_file = custom_module.endswith(".py")
    module = (
        _import_file("medsegpy.utils.env.custom_module", custom_module)
        if is_source_file
        else importlib.import_module(custom_module)
    )
    hook_ok = hasattr(module, "setup_environment") and callable(
        module.setup_environment
    )
    assert hook_ok, (
        "Custom environment module defined in {} does not have the "
        "required callable attribute 'setup_environment'."
    ).format(custom_module)
    module.setup_environment()
@functools.lru_cache()  # so that calling setup_logger multiple times won't add many handlers # noqa
def setup_logger(
    output=None,
    distributed_rank=0,
    *,
    color=True,
    name="Comp2Comp",
    abbrev_name=None,
):
    """Initialize the Comp2Comp logger and set its verbosity level to DEBUG.

    Args:
        output (str): a file name or a directory to save log. If None, will not
            save log file. If ends with ".txt" or ".log", assumed to be a file
            name. Otherwise, logs will be saved to `output/log.txt`.
        distributed_rank (int): rank of this worker; only rank 0 logs to stdout.
        color (bool): colorize console output via :class:`_ColorfulFormatter`.
        name (str): the root module name of this logger.
        abbrev_name (str): an abbreviation of the module, to avoid long names
            in logs. Set to "" to not log the root module in logs.

    Returns:
        logging.Logger: the configured logger.
    """
    log = logging.getLogger(name)
    log.setLevel(logging.DEBUG)
    log.propagate = False

    if abbrev_name is None:
        abbrev_name = name

    plain = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s",
        datefmt="%m/%d %H:%M:%S",
    )

    # stdout logging: master only
    if distributed_rank == 0:
        console = logging.StreamHandler(stream=sys.stdout)
        console.setLevel(logging.DEBUG)
        console.setFormatter(
            _ColorfulFormatter(
                colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
            if color
            else plain
        )
        log.addHandler(console)

    # file logging: all workers, each rank appending to its own suffixed file
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            log_file = output
        else:
            log_file = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            log_file = log_file + ".rank{}".format(distributed_rank)
        os.makedirs(os.path.dirname(log_file), exist_ok=True)

        file_handler = logging.StreamHandler(_cached_log_stream(log_file))
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(plain)
        log.addHandler(file_handler)

    return log


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    return open(filename, "a")
"""
Below are some other convenient logging methods.
They are mainly adopted from
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
"""


def _find_caller():
    """Walk up the stack past this logging module.

    Returns:
        str: module name of the caller
        tuple: a hashable key to be used to identify different callers
    """
    frame = sys._getframe(2)
    marker = os.path.join("utils", "logger.")
    while frame:
        code = frame.f_code
        if marker not in code.co_filename:
            module = frame.f_globals["__name__"]
            return (
                "detectron2" if module == "__main__" else module,
                (code.co_filename, frame.f_lineno, code.co_name),
            )
        frame = frame.f_back


# Shared state for the rate-limited helpers below. _LOG_TIMER is used by
# log_every_n_seconds (defined later in this file).
_LOG_COUNTER = Counter()
_LOG_TIMER = {}


def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
    """Log only for the first n times.

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by
            default.
        key (str or tuple[str]): the string(s) can be one of "caller" or
            "message", which defines how to identify duplicated logs.
            With ``key="caller"`` only the first call from each call site is
            logged; with ``key="message"`` each distinct message is logged at
            most once; with ``key=("caller", "message")`` suppression requires
            the same caller to repeat the same message.
    """
    keys = (key,) if isinstance(key, str) else key
    assert len(keys) > 0

    caller_module, caller_key = _find_caller()
    identity = ()
    if "caller" in keys:
        identity += caller_key
    if "message" in keys:
        identity += (msg,)

    _LOG_COUNTER[identity] += 1
    if _LOG_COUNTER[identity] <= n:
        logging.getLogger(name or caller_module).log(lvl, msg)


def log_every_n(lvl, msg, n=1, *, name=None):
    """Log once per n calls from the same call site.

    Args:
        lvl (int): the logging level
        msg (str):
        n (int):
        name (str): name of the logger to use. Will use the caller's module by
            default.
    """
    caller_module, site = _find_caller()
    _LOG_COUNTER[site] += 1
    if n == 1 or _LOG_COUNTER[site] % n == 1:
        logging.getLogger(name or caller_module).log(lvl, msg)
203 | """ 204 | caller_module, key = _find_caller() 205 | last_logged = _LOG_TIMER.get(key, None) 206 | current_time = time.time() 207 | if last_logged is None or current_time - last_logged >= n: 208 | logging.getLogger(name or caller_module).log(lvl, msg) 209 | _LOG_TIMER[key] = current_time 210 | -------------------------------------------------------------------------------- /comp2comp/utils/orientation.py: -------------------------------------------------------------------------------- 1 | import nibabel as nib 2 | 3 | from comp2comp.inference_class_base import InferenceClass 4 | 5 | 6 | class ToCanonical(InferenceClass): 7 | """Convert spine segmentation to canonical orientation.""" 8 | 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def __call__(self, inference_pipeline): 13 | """ 14 | First dim goes from L to R. 15 | Second dim goes from P to A. 16 | Third dim goes from I to S. 17 | """ 18 | canonical_segmentation = nib.as_closest_canonical( 19 | inference_pipeline.segmentation 20 | ) 21 | canonical_medical_volume = nib.as_closest_canonical( 22 | inference_pipeline.medical_volume 23 | ) 24 | 25 | inference_pipeline.segmentation = canonical_segmentation 26 | inference_pipeline.medical_volume = canonical_medical_volume 27 | inference_pipeline.pixel_spacing_list = ( 28 | canonical_medical_volume.header.get_zooms() 29 | ) 30 | return {} 31 | -------------------------------------------------------------------------------- /comp2comp/utils/process.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: louisblankemeier 3 | """ 4 | 5 | import os 6 | import shutil 7 | import sys 8 | import time 9 | import traceback 10 | from datetime import datetime 11 | from pathlib import Path 12 | 13 | from comp2comp.io import io_utils 14 | 15 | 16 | def find_common_root(paths): 17 | paths_with_sep = [path if path.endswith("/") else path + "/" for path in paths] 18 | 19 | # Find common prefix, ensuring it ends with a 
def process_2d(args, pipeline_builder):
    """Build and run a 2D pipeline once, writing under a timestamped folder.

    Args:
        args: parsed CLI namespace, forwarded verbatim to ``pipeline_builder``.
        pipeline_builder: callable mapping ``args`` to a runnable pipeline.
    """
    # outputs/<timestamp> relative to the repository root (two levels up).
    output_dir = Path(
        os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "../../outputs",
            datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )
    )
    if not os.path.exists(output_dir):
        output_dir.mkdir(parents=True)

    # Shared model cache directory at the repository root.
    model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../models")
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    pipeline = pipeline_builder(args)

    pipeline(output_dir=output_dir, model_dir=model_dir)


def process_3d(args, pipeline_builder):
    """Build and run a 3D pipeline for every scan under ``args.input_path``.

    Iterates the DICOM series / NIfTI files discovered by
    ``io_utils.get_dicom_or_nifti_paths_and_num``, builds one pipeline per
    scan, and mirrors the input layout under the output directory. Failures
    are logged per scan and do not stop the batch.

    Args:
        args: parsed CLI namespace; uses ``input_path``, ``output_path``,
            ``overwrite_outputs`` and ``save_segmentations``.
        pipeline_builder: callable mapping ``(path, args)`` to a runnable
            pipeline.
    """
    model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../models")
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)

    if args.output_path is not None:
        output_path = Path(args.output_path)
    else:
        output_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "../../outputs"
        )

    # Unless overwriting, isolate this run in a timestamped subfolder.
    if not args.overwrite_outputs:
        date_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        output_path = os.path.join(output_path, date_time)

    path_and_num = io_utils.get_dicom_or_nifti_paths_and_num(args.input_path)

    # in case input is a .txt file we need to find the common root of the files
    if args.input_path.endswith(".txt"):
        all_paths = [p[0] for p in path_and_num]
        common_root = find_common_root(all_paths)

    for path, num in path_and_num:

        try:
            st = time.time()

            if path.endswith(".nii") or path.endswith(".nii.gz"):
                print("Processing: ", path)
            else:
                print("Processing: ", path, " with ", num, " slices")

                # Skip DICOM series that are too short to be a usable scan.
                min_slices = 30
                if num < min_slices:
                    print(f"Number of slices is less than {min_slices}, skipping\n")
                    continue

            print("")

            try:
                sys.stdout.flush()
            except Exception:
                pass

            if path.endswith(".nii") or path.endswith(".nii.gz"):
                folder_name = Path(os.path.basename(os.path.normpath(path)))
                # remove .nii or .nii.gz
                # NOTE(review): str.replace strips ".gz"/".nii" wherever they
                # occur in the name, not only as a suffix — confirm intended.
                folder_name = os.path.normpath(
                    Path(str(folder_name).replace(".gz", "").replace(".nii", ""))
                )
                output_dir = Path(
                    os.path.join(
                        output_path,
                        folder_name,
                    )
                )

            else:
                if args.input_path.endswith(".txt"):
                    # Mirror each scan's path relative to the common root.
                    output_dir = Path(
                        os.path.join(
                            output_path,
                            os.path.relpath(os.path.normpath(path), common_root),
                        )
                    )
                else:
                    # Mirror the scan's path relative to the input directory.
                    output_dir = Path(
                        os.path.join(
                            output_path,
                            Path(os.path.basename(os.path.normpath(args.input_path))),
                            os.path.relpath(
                                os.path.normpath(path),
                                os.path.normpath(args.input_path),
                            ),
                        )
                    )

            if not os.path.exists(output_dir):
                output_dir.mkdir(parents=True)

            pipeline = pipeline_builder(path, args)

            pipeline(output_dir=output_dir, model_dir=model_dir)

            if not args.save_segmentations:
                # remove the segmentations folder
                segmentations_dir = os.path.join(output_dir, "segmentations")
                if os.path.exists(segmentations_dir):
                    shutil.rmtree(segmentations_dir)

            print(f"Finished processing {path} in {time.time() - st:.1f} seconds\n")
            print("Output was saved to:")
            print(output_dir)

        except Exception:
            # Best-effort cleanup so a failed scan leaves no partial outputs.
            # NOTE(review): if the exception fires before output_dir is
            # assigned this iteration, output_dir is unbound (first scan) or
            # stale from the previous iteration — confirm this cleanup cannot
            # delete a previously completed scan's output.
            print(f"ERROR PROCESSING {path}\n")
            traceback.print_exc()
            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
            # remove parent folder if empty
            if len(os.listdir(os.path.dirname(output_dir))) == 0:
                shutil.rmtree(os.path.dirname(output_dir))
            continue
logger = logging.getLogger(__name__)


def format_output_path(
    file_path,
    save_dir: str = None,
    base_dirs: Sequence[str] = None,
    file_name: Sequence[str] = None,
):
    """Format output path for a given file.

    Args:
        file_path (str): File path.
        save_dir (str, optional): Save directory. Defaults to None.
        base_dirs (Sequence[str], optional): Base directories. Defaults to None.
        file_name (Sequence[str], optional): File name. Defaults to None.

    Returns:
        str: Output path (always with an ``.h5`` extension).
    """
    out_dir = save_dir if save_dir else os.path.dirname(file_path)

    if save_dir and base_dirs:
        src_dir: str = os.path.dirname(file_path)
        # Path of the file relative to the first matching base directory.
        relative = [
            src_dir.split(base, 1)[1] for base in base_dirs if src_dir.startswith(base)
        ][0]
        # Trim path separator from the path
        relative = relative.lstrip(os.path.sep)
        out_dir = os.path.join(save_dir, relative)

    stem = (
        file_name
        if file_name is not None
        else os.path.splitext(os.path.basename(file_path))[0]
    )
    return os.path.join(out_dir, "{}.h5".format(stem))


def get_file_names(files):
    """Get file names (without their final extension) from a list of paths.

    Args:
        files (list): List of file paths.

    Returns:
        list: List of file names.
    """
    return [os.path.splitext(os.path.basename(f))[0] for f in files]


def find_files(
    root_dirs: Union[str, Sequence[str]],
    max_depth: int = None,
    exist_ok: bool = False,
    pattern: str = None,
):
    """Recursively search for files.

    To avoid recomputing experiments with results, set `exist_ok=False`.
    Results will be searched for in `PREFERENCES.OUTPUT_DIR` (if non-empty).

    Args:
        root_dirs (`str(s)`): Root folder(s) to search.
        max_depth (int, optional): Maximum depth to search.
        exist_ok (bool, optional): If `True`, recompute results for scans.
        pattern (str, optional): If specified, looks for files with names
            matching the pattern.

    Return:
        List[str]: Sorted, de-duplicated file paths to process.
    """

    def _walk(depth: int, folder: str):
        # Stop at missing/invalid folders and beyond the depth limit.
        if folder is None or not os.path.isdir(folder):
            return []
        if max_depth is not None and depth > max_depth:
            return []

        found = []
        for entry in os.listdir(folder):
            candidate = os.path.join(folder, entry)
            if os.path.isdir(candidate):
                found.extend(_walk(depth + 1, candidate))
                continue
            if not os.path.isfile(candidate):
                continue
            if pattern and not re.match(pattern, candidate):
                continue
            result_path = format_output_path(candidate)
            if not exist_ok and os.path.isfile(result_path):
                logger.info(
                    "Skipping {} - results exist at {}".format(candidate, result_path)
                )
                continue
            found.append(candidate)
        return found

    roots = [root_dirs] if isinstance(root_dirs, str) else root_dirs
    collected = []
    for root in roots:
        collected.extend(_walk(0, root))
    return sorted(set(collected))
-------------------------------------------------------------------------------- /comp2comp/visualization/dicom.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import pydicom 6 | from PIL import Image 7 | from pydicom.dataset import Dataset, FileMetaDataset 8 | from pydicom.uid import ExplicitVRLittleEndian 9 | 10 | 11 | def to_dicom(input, output_path, plane="axial"): 12 | """Converts a png image to a dicom image. Written with assistance from ChatGPT.""" 13 | if isinstance(input, str) or isinstance(input, Path): 14 | png_path = input 15 | dicom_path = os.path.join( 16 | output_path, os.path.basename(png_path).replace(".png", ".dcm") 17 | ) 18 | image = Image.open(png_path) 19 | image_array = np.array(image) 20 | image_array = image_array[:, :, :3] 21 | else: 22 | image_array = input 23 | dicom_path = output_path 24 | 25 | meta = FileMetaDataset() 26 | meta.MediaStorageSOPClassUID = "1.2.840.10008.5.1.4.1.1.7" 27 | meta.MediaStorageSOPInstanceUID = pydicom.uid.generate_uid() 28 | meta.TransferSyntaxUID = ExplicitVRLittleEndian 29 | meta.ImplementationClassUID = pydicom.uid.PYDICOM_IMPLEMENTATION_UID 30 | 31 | ds = Dataset() 32 | ds.file_meta = meta 33 | ds.is_little_endian = True 34 | ds.is_implicit_VR = False 35 | ds.SOPClassUID = "1.2.840.10008.5.1.4.1.1.7" 36 | ds.SOPInstanceUID = pydicom.uid.generate_uid() 37 | ds.PatientName = "John Doe" 38 | ds.PatientID = "123456" 39 | ds.Modality = "OT" 40 | ds.SeriesInstanceUID = pydicom.uid.generate_uid() 41 | ds.StudyInstanceUID = pydicom.uid.generate_uid() 42 | ds.FrameOfReferenceUID = pydicom.uid.generate_uid() 43 | ds.BitsAllocated = 8 44 | ds.BitsStored = 8 45 | ds.HighBit = 7 46 | ds.PhotometricInterpretation = "RGB" 47 | ds.PixelRepresentation = 0 48 | ds.Rows = image_array.shape[0] 49 | ds.Columns = image_array.shape[1] 50 | ds.SamplesPerPixel = 3 51 | ds.PlanarConfiguration = 0 52 | 53 | if plane.lower() 
== "axial": 54 | ds.ImageOrientationPatient = [1, 0, 0, 0, 1, 0] 55 | elif plane.lower() == "sagittal": 56 | ds.ImageOrientationPatient = [0, 1, 0, 0, 0, -1] 57 | elif plane.lower() == "coronal": 58 | ds.ImageOrientationPatient = [1, 0, 0, 0, 0, -1] 59 | else: 60 | raise ValueError( 61 | "Invalid plane value. Must be 'axial', 'sagittal', or 'coronal'." 62 | ) 63 | 64 | ds.PixelData = image_array.tobytes() 65 | pydicom.filewriter.write_file(dicom_path, ds, write_like_original=False) 66 | 67 | 68 | # Example usage 69 | if __name__ == "__main__": 70 | png_path = "../../figures/spine_example.png" 71 | output_path = "./" 72 | plane = "sagittal" 73 | to_dicom(png_path, output_path, plane) 74 | -------------------------------------------------------------------------------- /comp2comp/visualization/linear_planar_reformation.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: louisblankemeier 3 | """ 4 | 5 | import numpy as np 6 | 7 | 8 | def linear_planar_reformation( 9 | medical_volume: np.ndarray, segmentation: np.ndarray, centroids, dimension="axial" 10 | ): 11 | if dimension == "sagittal" or dimension == "coronal": 12 | centroids = sorted(centroids, key=lambda x: x[2]) 13 | elif dimension == "axial": 14 | centroids = sorted(centroids, key=lambda x: x[0]) 15 | 16 | centroids = [(int(x[0]), int(x[1]), int(x[2])) for x in centroids] 17 | sagittal_centroids = [centroids[i][0] for i in range(0, len(centroids))] 18 | coronal_centroids = [centroids[i][1] for i in range(0, len(centroids))] 19 | axial_centroids = [centroids[i][2] for i in range(0, len(centroids))] 20 | 21 | sagittal_vals, coronal_vals, axial_vals = [], [], [] 22 | 23 | if dimension == "sagittal": 24 | sagittal_vals = [sagittal_centroids[0]] * axial_centroids[0] 25 | 26 | if dimension == "coronal": 27 | coronal_vals = [coronal_centroids[0]] * axial_centroids[0] 28 | 29 | if dimension == "axial": 30 | axial_vals = [axial_centroids[0]] * sagittal_centroids[0] 
31 | 32 | for i in range(1, len(axial_centroids)): 33 | if dimension == "sagittal" or dimension == "coronal": 34 | num = axial_centroids[i] - axial_centroids[i - 1] 35 | elif dimension == "axial": 36 | num = sagittal_centroids[i] - sagittal_centroids[i - 1] 37 | 38 | if dimension == "sagittal": 39 | interp = list( 40 | np.linspace(sagittal_centroids[i - 1], sagittal_centroids[i], num=num) 41 | ) 42 | sagittal_vals.extend(interp) 43 | 44 | if dimension == "coronal": 45 | interp = list( 46 | np.linspace(coronal_centroids[i - 1], coronal_centroids[i], num=num) 47 | ) 48 | coronal_vals.extend(interp) 49 | 50 | if dimension == "axial": 51 | interp = list( 52 | np.linspace(axial_centroids[i - 1], axial_centroids[i], num=num) 53 | ) 54 | axial_vals.extend(interp) 55 | 56 | if dimension == "sagittal": 57 | sagittal_vals.extend( 58 | [sagittal_centroids[-1]] * (medical_volume.shape[2] - len(sagittal_vals)) 59 | ) 60 | sagittal_vals = np.array(sagittal_vals) 61 | sagittal_vals = sagittal_vals.astype(int) 62 | 63 | if dimension == "coronal": 64 | coronal_vals.extend( 65 | [coronal_centroids[-1]] * (medical_volume.shape[2] - len(coronal_vals)) 66 | ) 67 | coronal_vals = np.array(coronal_vals) 68 | coronal_vals = coronal_vals.astype(int) 69 | 70 | if dimension == "axial": 71 | axial_vals.extend( 72 | [axial_centroids[-1]] * (medical_volume.shape[0] - len(axial_vals)) 73 | ) 74 | axial_vals = np.array(axial_vals) 75 | axial_vals = axial_vals.astype(int) 76 | 77 | if dimension == "sagittal": 78 | sagittal_image = medical_volume[sagittal_vals, :, range(len(sagittal_vals))] 79 | sagittal_label = segmentation[sagittal_vals, :, range(len(sagittal_vals))] 80 | 81 | if dimension == "coronal": 82 | coronal_image = medical_volume[:, coronal_vals, range(len(coronal_vals))] 83 | coronal_label = segmentation[:, coronal_vals, range(len(coronal_vals))] 84 | 85 | if dimension == "axial": 86 | axial_image = medical_volume[range(len(axial_vals)), :, axial_vals] 87 | axial_label = 
segmentation[range(len(axial_vals)), :, axial_vals] 88 | 89 | if dimension == "sagittal": 90 | return sagittal_image, sagittal_label 91 | 92 | if dimension == "coronal": 93 | return coronal_image, coronal_label 94 | 95 | if dimension == "axial": 96 | return axial_image, axial_label 97 | -------------------------------------------------------------------------------- /docs/Local Implementation @ M1 arm64 Silicon.md: -------------------------------------------------------------------------------- 1 | # Local Implementation @ M1/arm64/AppleSilicon 2 | 3 | Due to dependencies and differences in architecture, the direct installation of *Comp2Comp* using install.sh or setup.py did not work on an local machine with arm64 / apple silicon running MacOS. This guide is mainly based on [issue #30](https://github.com/StanfordMIMI/Comp2Comp/issues/30). Most of the problems I encountered are caused by requiring TensorFlow and PyTorch in the same environment, which (especially for TensorFlow) is tricky at some times. Thus, this guide focuses more on the setup of the environment @arm64 / AppleSilicon, than *Comp2Comp* or *TotalSegmentator* itself. 4 | 5 | ## Installation 6 | Comp2Comp requires TensorFlow and TotalSegmentator requires PyTorch. Although (at the moment) neither *Comp2Comp* nor *TotalSegmentator* can make use of the M1 GPU. Thus, using the arm64-specific versions is necessary. 7 | 8 | ### TensorFlow 9 | For reference: 10 | - https://developer.apple.com/metal/tensorflow-plugin/ 11 | - https://developer.apple.com/forums/thread/683757 12 | - https://developer.apple.com/forums/thread/686926?page=2 13 | 14 | 1. Create an environment (python 3.8 or 3.9) using miniforge: https://github.com/conda-forge/miniforge. (TensorFlow did not work for others using anaconda; maybe you can get it running using -c apple and -c conda-forge for the further steps. However, I am not sure whether just the channel (and the retrieved packages) or anaconda's python itself is the problem.) 
15 | 16 | 2. Install TensorFlow and tensorflow-metal in these versions: 17 | ``` 18 | conda install -c apple tensorflow-deps=2.9.0 -y 19 | python -m pip install tensorflow-macos==2.9 20 | python -m pip install tensorflow-metal==0.5.0 21 | ``` 22 | If you use other methods to install tensorflow, version 2.11.0 might be the best option. Tensorflow version 2.12.0 has caused some problems. 23 | 24 | ### PyTorch 25 | For reference https://pytorch.org. The nightly build is (at least for -c conda-forge or -c pytorch) not needed, and the default already supports GPU acceleration on arm64. 26 | 27 | 3. Install Pytorch 28 | ``` 29 | conda install pytorch torchvision torchaudio -c pytorch 30 | ``` 31 | 32 | ### Other Dependencies (Numpy and scikit-learn) 33 | 4. Install other packages 34 | ``` 35 | conda install -c conda-forge numpy scikit-learn -y 36 | ``` 37 | 38 | ### TotalSegmentator 39 | Louis et al. modified the original *TotalSegmentator* (https://github.com/wasserth/TotalSegmentator) for the use with *Comp2Comp*. *Comp2Comp* does not work with the original version. With the current version of the modified *TotalSegmentator* (https://github.com/StanfordMIMI/TotalSegmentator), no adaptions are necessary. 40 | 41 | ### Comp2Comp 42 | For *Comp2Comp* on M1 however, it is important **not** to use bin/install.sh, as some of the predefined requirements won't work. Thus: 43 | 44 | 5. Clone *Comp2Comp* 45 | ``` 46 | git clone https://github.com/StanfordMIMI/Comp2Comp.git 47 | ``` 48 | 49 | 6. Modify setup.py by 50 | - remove `"numpy==1.23.5"` 51 | - remove `"tensorflow>=2.0.0"` 52 | 53 | (You have installed these manually before.) 54 | 55 | 7. Install *Comp2Comp* with 56 | ``` 57 | python -m pip install -e . 
58 | ``` 59 | 60 | ## Performance 61 | Using M1Max w/ 64GB RAM 62 | - `process 2d` (Comp2Comp in predefined slices): 250 slices in 14.2sec / 361 slices in 17.9sec 63 | - `process 3d` (segmentation of spine and identification of slices using TotalSegmentator, Comp2Comp in identified slices): high res, full body scan, 1367sec 64 | 65 | ## ToDos / Nice2Have / Future 66 | - Integration and use `--fast` and `--body_seg` for TotalSegmentator might be preferable 67 | - TotalSegmentator works only with CUDA compatible GPUs (!="mps"). I am not sure, about `torch.device("mps")` in the future, see also https://github.com/wasserth/TotalSegmentator/issues/39. Currently, only the CPU is used. 68 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

project = "comp2comp"
copyright = "2023, StanfordMIMI"
author = "StanfordMIMI"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

# Adapted from https://github.com/pyvoxel/pyvoxel

extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.intersphinx",
    "sphinx.ext.todo",
    "sphinx.ext.coverage",
    "sphinx.ext.mathjax",
    "sphinx.ext.ifconfig",
    "sphinx.ext.viewcode",
    "sphinx.ext.githubpages",  # fixed: was listed twice in this list
    "sphinx.ext.napoleon",
    "sphinxcontrib.bibtex",
    "sphinx_rtd_theme",
    "m2r2",
]

# Generate autosummary stub pages automatically, including imported members.
autosummary_generate = True
autosummary_imported_members = True

bibtex_bibfiles = ["references.bib"]

templates_path = ["_templates"]
exclude_patterns = []


pygments_style = "sphinx"
html_theme = "sphinx_rtd_theme"
htmlhelp_basename = "Comp2Compdoc"
html_static_path = ["_static"]

intersphinx_mapping = {"numpy": ("https://numpy.org/doc/stable/", None)}
html_theme_options = {"navigation_depth": 2}

# Build both reStructuredText and Markdown sources (via m2r2).
source_suffix = [".rst", ".md"]

todo_include_todos = True
napoleon_use_ivar = True
napoleon_google_docstring = True
html_show_sourcelink = False
comp2comp documentation master file, created by 2 | sphinx-quickstart on Sun Apr 9 21:28:41 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to comp2comp's documentation! 7 | ===================================== 8 | 9 | .. mdinclude:: ../../README.md 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :hidden: 14 | -------------------------------------------------------------------------------- /figures/aaa_diameter_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/figures/aaa_diameter_graph.png -------------------------------------------------------------------------------- /figures/aaa_segmentation_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/figures/aaa_segmentation_video.gif -------------------------------------------------------------------------------- /figures/aortic_aneurysm_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/figures/aortic_aneurysm_example.png -------------------------------------------------------------------------------- /figures/aortic_calcium_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/figures/aortic_calcium_overview.png -------------------------------------------------------------------------------- /figures/hip_example.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/figures/hip_example.png -------------------------------------------------------------------------------- /figures/liver_spleen_pancreas_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/figures/liver_spleen_pancreas_example.png -------------------------------------------------------------------------------- /figures/muscle_adipose_tissue_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/figures/muscle_adipose_tissue_example.png -------------------------------------------------------------------------------- /figures/spine_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/figures/spine_example.png -------------------------------------------------------------------------------- /figures/spine_muscle_adipose_tissue_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/figures/spine_muscle_adipose_tissue_example.png -------------------------------------------------------------------------------- /logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StanfordMIMI/Comp2Comp/2b000963e04d6140a40c79d7ecdefbb752a287e2/logo.png -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | multi_line_output=3 3 | 
include_trailing_comma=True 4 | force_grid_wrap=0 5 | use_parentheses=True 6 | ensure_newline_before_comments=True 7 | line_length=80 8 | 9 | [mypy] 10 | python_version=3.6 11 | ignore_missing_imports = True 12 | warn_unused_configs = True 13 | disallow_untyped_defs = True 14 | check_untyped_defs = True 15 | warn_unused_ignores = True 16 | warn_redundant_casts = True 17 | show_column_numbers = True 18 | follow_imports = silent 19 | allow_redefinition = True 20 | ; Require all functions to be annotated 21 | disallow_incomplete_defs = True -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import os 5 | from os import path 6 | 7 | from setuptools import find_packages, setup 8 | 9 | 10 | def get_version(): 11 | init_py_path = path.join( 12 | path.abspath(path.dirname(__file__)), "comp2comp", "__init__.py" 13 | ) 14 | init_py = open(init_py_path, "r").readlines() 15 | version_line = [line.strip() for line in init_py if line.startswith("__version__")][ 16 | 0 17 | ] 18 | version = version_line.split("=")[-1].strip().strip("'\"") 19 | 20 | # The following is used to build release packages. 21 | # Users should never use it. 
22 | suffix = os.getenv("ABCTSEG_VERSION_SUFFIX", "") 23 | version = version + suffix 24 | if os.getenv("BUILD_NIGHTLY", "0") == "1": 25 | from datetime import datetime 26 | 27 | date_str = datetime.today().strftime("%y%m%d") 28 | version = version + ".dev" + date_str 29 | 30 | new_init_py = [line for line in init_py if not line.startswith("__version__")] 31 | new_init_py.append('__version__ = "{}"\n'.format(version)) 32 | with open(init_py_path, "w") as f: 33 | f.write("".join(new_init_py)) 34 | return version 35 | 36 | 37 | setup( 38 | name="comp2comp", 39 | version=get_version(), 40 | author="StanfordMIMI", 41 | url="https://github.com/StanfordMIMI/Comp2Comp", 42 | description="Computed tomography to body composition.", 43 | packages=find_packages(exclude=("configs", "tests")), 44 | python_requires=">=3.9", 45 | install_requires=[ 46 | "pydicom", 47 | "moviepy", 48 | "numpy==1.23.5", 49 | "h5py", 50 | "tabulate", 51 | "tqdm", 52 | "silx", 53 | "yacs", 54 | "pandas", 55 | "dosma", 56 | "opencv-python", 57 | "huggingface_hub", 58 | "pycocotools", 59 | "wget", 60 | "tensorflow==2.12.0", 61 | "totalsegmentator @ git+https://github.com/StanfordMIMI/TotalSegmentator.git", 62 | "totalsegmentatorv2 @ git+https://github.com/StanfordMIMI/TotalSegmentatorV2.git", 63 | ], 64 | extras_require={ 65 | "all": ["shapely", "psutil"], 66 | "dev": [ 67 | # Formatting 68 | "flake8", 69 | "isort", 70 | "black==22.8.0", 71 | "flake8-bugbear", 72 | "flake8-comprehensions", 73 | # Docs 74 | "mock", 75 | "sphinx", 76 | "sphinx-rtd-theme", 77 | "recommonmark", 78 | "myst-parser", 79 | ], 80 | "contrast_phase": ["xgboost"], 81 | }, 82 | ) 83 | --------------------------------------------------------------------------------