├── .circleci
│   └── config.yml
├── .gitattributes
├── .gitignore
├── .vscode
│   ├── extensions.json
│   └── settings.json
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── conf.py
│   ├── index.rst
│   ├── make.bat
│   ├── mdai.rst
│   └── mdai.utils.rst
├── mdai
│   ├── .DS_Store
│   ├── __init__.py
│   ├── client.py
│   ├── inference.py
│   ├── preprocess.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── common_utils.py
│   │   ├── dicom_utils.py
│   │   ├── keras_utils.py
│   │   ├── sample_SR.dcm
│   │   ├── sample_dicom.dcm
│   │   ├── tensorflow_utils.py
│   │   └── transforms.py
│   └── visualize.py
├── notebooks
│   ├── MDai_Simple_API.ipynb
│   ├── hello-world-fastai.ipynb
│   ├── hello-world-keras.ipynb
│   ├── hello-world-pill.ipynb
│   └── hello-world-tfrecords-VGG16.ipynb
├── poetry.lock
├── pyproject.toml
└── tests
    ├── __init__.py
    ├── conftest.py
    ├── test_preprocess.py
    └── test_visualize.py

/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 |
3 | orbs:
4 |   slack: circleci/slack@3.4.2
5 |
6 | common_deploy_setup: &common_deploy_setup
7 |   working_directory: ~/mdai-client-py
8 |   docker:
9 |     - image: cimg/python:3.8.17
10 |
11 | step_restore_cache: &step_restore_cache
12 |   restore_cache:
13 |     keys:
14 |       - v1-{{ checksum "poetry.lock" }}
15 |
16 | step_install_dependencies: &step_install_dependencies
17 |   run:
18 |     name: Install dependencies
19 |     command: |
20 |       pip install -U poetry
21 |       poetry config virtualenvs.create true
22 |       poetry config virtualenvs.in-project true
23 |       poetry install --no-ansi
24 |
25 | step_save_cache: &step_save_cache
26 |   save_cache:
27 |     key: v1-{{ checksum "poetry.lock" }}
28 |     paths:
29 |       - ".venv"
30 |
31 | run_tests: &run_tests
32 |   run:
33 |     name: Run tests
34 |     command: |
35 |       poetry run pytest
36 |
37 | jobs:
38 |   test:
39 |     <<: *common_deploy_setup
40 |     resource_class: small
41 |     steps:
42 |       - checkout
43 |       - <<: *step_restore_cache
44 |       - <<: *step_install_dependencies
45 |       - <<: *step_save_cache
46 |       - <<: *run_tests
47 |       - slack/status
48 |
49 |   test_and_release_to_testpypi:
50 |     <<: *common_deploy_setup
51 |     resource_class: small
52 |     steps:
53 |       - checkout
54 |       - <<: *step_restore_cache
55 |       - <<: *step_install_dependencies
56 |       - <<: *step_save_cache
57 |       - <<: *run_tests
58 |       - run:
59 |           name: Release to TestPyPI
60 |           command: |
61 |             poetry config repositories.testpypi https://test.pypi.org/legacy/
62 |             poetry config http-basic.testpypi $TESTPYPI_USER $TESTPYPI_PASS
63 |             poetry publish --build --repository testpypi
64 |       - slack/status
65 |
66 |   test_and_release_to_pypi:
67 |     <<: *common_deploy_setup
68 |     resource_class: small
69 |     steps:
70 |       - checkout
71 |       - <<: *step_restore_cache
72 |       - <<: *step_install_dependencies
73 |       - <<: *step_save_cache
74 |       - <<: *run_tests
75 |       - run:
76 |           name: Release to PyPI
77 |           command: |
78 |             poetry config http-basic.pypi $PYPI_USER $PYPI_PASS
79 |             poetry publish --build
80 |       - slack/status
81 |
82 | workflows:
83 |   circleci_test:
84 |     jobs:
85 |       - test:
86 |           filters:
87 |             branches:
88 |               ignore: master
89 |           context:
90 |             - SLACK
91 |   circleci_test_and_release:
92 |     jobs:
93 |       - test_and_release_to_testpypi:
94 |           filters:
95 |             branches:
96 |               only: master
97 |           context:
98 |             - SLACK
99 |             - PYPI_CREDENTIALS
100 |       - test_and_release_to_pypi:
101 |           filters:
102 |             tags:
103 |               only: /v[0-9]+(\.[0-9]+)*/
104 |             branches:
105 |               ignore: /.*/
106 |           context:
107 |             - SLACK
108 |             - PYPI_CREDENTIALS
109 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | notebooks/* linguist-generated=true
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv/
2 | .cache/
3 | .pytest_cache/
4 | __pycache__/
5 | build/
6 | dist/
7 | *.egg-info/
8 | .ipynb_checkpoints/
9 | /tests/data/
10 | /docs/_build/
11 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 |   "recommendations": ["charliermarsh.ruff"]
3 | }
4 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |   // Show the path in the top window bar.
3 |   "window.title": "${rootName}${separator}${activeEditorMedium}",
4 |
5 |   // Formatting
6 |   "editor.formatOnSave": true,
7 |   "files.insertFinalNewline": true,
8 |   "files.trimTrailingWhitespace": true,
9 |   "python.formatting.provider": "black",
10 |
11 |   // Linting
12 |   "python.linting.enabled": true,
13 |
14 |   // Python-specific settings
15 |   "[python]": {
16 |     "editor.detectIndentation": false,
17 |     "editor.tabSize": 4
18 |   }
19 | }
20 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Development
2 |
3 | ## Python version
4 |
5 | [Pyenv](https://github.com/pyenv/pyenv) is recommended for managing Python versions. Currently, Python 3.8+ is used for development.
6 |
7 | ## Virtualenv
8 |
9 | This project uses [Poetry](https://poetry.eustace.io/docs/) for dependency and virtualenv management. For initial setup, run:
10 |
11 | ```sh
12 | # Install poetry (1.0+)
13 | pip install -U poetry
14 |
15 | # Configure poetry to install virtualenv in local directory
16 | poetry config virtualenvs.create true
17 | poetry config virtualenvs.in-project true
18 |
19 | # Install virtualenv in local directory
20 | poetry install
21 | ```
22 |
23 | VSCode will automatically load the virtualenv. [ruff](https://github.com/charliermarsh/ruff) (linting) and [black](https://github.com/ambv/black) (formatting) are installed as dev dependencies.
24 |
25 | To activate the local virtualenv:
26 |
27 | ```sh
28 | source .venv/bin/activate
29 | # or
30 | poetry shell
31 | ```
32 |
33 | ## Creating a new release
34 |
35 | Update the library version in the following file with every PR, according to [semver](https://semver.org/) guidelines:
36 |
37 | - [pyproject.toml](https://github.com/mdai/mdai-client-py/blob/master/pyproject.toml#L17)
38 |
39 | Add a new tag:
40 |
41 | ```sh
42 | git tag -a <tag_name> -m "tag message"
43 | ```
44 |
45 | Push the new tag:
46 |
47 | ```sh
48 | git push origin <tag_name>
49 | ```
50 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2023 MD.ai, Inc. 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 182 | You may obtain a copy of the License at 183 | 184 | http://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. 
191 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MD.ai Python Client Library
2 |
3 | The Python client library is designed to work with the datasets and annotations generated by the [MD.ai](https://www.md.ai/) Medical AI platform.
4 |
5 | You can download datasets consisting of images and annotations (i.e., a JSON file), create train/validation/test datasets, and integrate with various machine learning libraries (e.g., TensorFlow/Keras, Fast.ai) for developing machine learning algorithms.
6 |
7 | To get started, check out the examples in the [notebooks section](notebooks), or our [intro to deep learning for medical imaging lessons](https://github.com/mdai/ml-lessons/).
8 |
9 | ## Installation
10 |
11 | Requires Python 3.6+. Install and update using [pip](https://pip.pypa.io/en/stable/quickstart/):
12 |
13 | ```sh
14 | pip install --upgrade mdai
15 | ```
16 |
17 | ## Documentation
18 |
19 | Documentation is available at: https://docs.md.ai/annotator/python/installation/
20 |
21 | ## MD.ai Annotator
22 |
23 | The MD.ai Annotator is a powerful web-based application for storing and viewing anonymized medical images (e.g., DICOM) in the cloud, creating annotations collaboratively in real time, and exporting annotations, images, and labels for training. The MD.ai Python client library can be used to download images and annotations, prepare the datasets, and then train and evaluate deep learning models.
24 |
25 | - MD.ai Annotator Documentation and Videos: https://docs.md.ai/
26 | - MD.ai Annotator Example Project: https://public.md.ai/annotator/project/aGq4k6NW/workspace
27 |
28 | ![MD.ai Annotator](https://md.ai/images/product/annotator-feat-dicom.webp)
29 |
30 | ## MD.ai Annotation JSON Format
31 |
32 | For more detailed information regarding the annotation JSON export format, see: https://docs.md.ai/annotator/data/json/
33 |
34 | ## Example Notebooks
35 |
36 | - [HelloWorld Keras Notebook](notebooks/hello-world-keras.ipynb)
37 | - [HelloWorld TFRecords Notebook](notebooks/hello-world-tfrecords-VGG16.ipynb)
38 | - [HelloWorld Fast.ai](notebooks/hello-world-fastai.ipynb)
39 |
40 | ## Introductory lessons to Deep Learning for medical imaging by [MD.ai](https://www.md.ai)
41 |
42 | The following are several Jupyter notebooks covering the basics of downloading and parsing annotation data, and training and evaluating different deep learning models for classification, semantic and instance segmentation, and object detection problems in the medical imaging domain. The notebooks can be run on Google Colab with a GPU (see instructions below).
43 |
44 | - Lesson 1. Classification of chest vs. abdominal X-rays using TensorFlow/Keras [Github](https://github.com/mdai/ml-lessons/blob/master/lesson1-xray-images-classification.ipynb) | [Annotator](https://public.md.ai/annotator/project/PVq9raBJ)
45 | - Lesson 2. Lung X-Rays Semantic Segmentation using UNets. [Github](https://github.com/mdai/ml-lessons/blob/master/lesson2-lung-xrays-segmentation.ipynb) |
46 | [Annotator](https://public.md.ai/annotator/project/aGq4k6NW/workspace)
47 | - Lesson 3. RSNA Pneumonia detection using Kaggle data format [Github](https://github.com/mdai/ml-lessons/blob/master/lesson3-rsna-pneumonia-detection-kaggle.ipynb) | [Annotator](https://public.md.ai/annotator/project/LxR6zdR2/workspace)
48 | - Lesson 3. RSNA Pneumonia detection using MD.ai Python client library [Github](https://github.com/mdai/ml-lessons/blob/master/lesson3-rsna-pneumonia-detection-mdai-client-lib.ipynb) | [Annotator](https://public.md.ai/annotator/project/LxR6zdR2/workspace)
49 |
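The lessons above follow the same basic flow: authenticate a client, then download a project's images and annotations. A minimal, hypothetical sketch of that flow (the access token and output path are placeholders; `public.md.ai` and project `aGq4k6NW` are the example domain and project referenced above, and the calls mirror `mdai.Client` in `mdai/client.py`):

```python
import mdai

# Access tokens are created in the MD.ai Hub under Settings -> User Access Tokens.
mdai_client = mdai.Client(domain="public.md.ai", access_token="YOUR_ACCESS_TOKEN")

# Downloads images and annotations for the example project into ./data and
# returns an mdai.preprocess.Project for building datasets.
p = mdai_client.project("aGq4k6NW", path="./data")
```
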
50 | ## Contributing
51 |
52 | See the [contributing guidelines](CONTRIBUTING.md) to set up a development environment and learn how to make contributions to mdai.
53 |
54 | ## Running Jupyter notebooks on Colab
55 |
56 | It's easy to run a Jupyter notebook on Google Colab with free GPU use (time limited).
57 | For example, you can add the GitHub Jupyter notebook path to https://colab.research.google.com/notebook:
58 | Select the "GITHUB" tab, and add the Lesson 1 URL: https://github.com/mdai/ml-lessons/blob/master/lesson1-xray-images-classification.ipynb
59 |
60 | To use the GPU, in the notebook menu, go to Runtime -> Change runtime type -> switch to Python 3, and turn on GPU. See more [Colab tips](https://www.kdnuggets.com/2018/02/essential-google-colaboratory-tips-tricks.html).
61 |
62 | ## Advanced: How to run on Google Cloud Platform with Deep Learning Images
63 |
64 | You can also run the notebooks with powerful GPUs on the Google Cloud Platform. In this case, you need to authenticate to the Google Cloud Platform, create a private virtual machine instance running one of Google's Deep Learning images, and import the lessons. See the instructions below.
65 |
66 | [GCP Deep Learning Images How-To](running_on_gcp.md)
67 |
68 | ---
69 |
70 | © 2023 MD.ai, Inc.
71 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS    =
6 | SPHINXBUILD   = sphinx-build
7 | SPHINXPROJ    = mdai-client-py
8 | SOURCEDIR     = .
9 | BUILDDIR      = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #
3 | # Configuration file for the Sphinx documentation builder.
4 | #
5 | # This file does only contain a selection of the most common options. For a
6 | # full list see the documentation:
7 | # http://www.sphinx-doc.org/en/master/config
8 |
9 | # -- Path setup --------------------------------------------------------------
10 |
11 | # If extensions (or modules to document with autodoc) are in another directory,
12 | # add these directories to sys.path here. If the directory is relative to the
13 | # documentation root, use os.path.abspath to make it absolute, like shown here.
14 | # 15 | import os 16 | import sys 17 | 18 | sys.path.insert(0, os.path.abspath("..")) 19 | sys.path.insert(0, os.path.abspath("../mdai")) 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "mdai-client-py" 24 | copyright = "2023 MD.ai" 25 | author = "MD.ai" 26 | 27 | # The short X.Y version 28 | version = "" 29 | # The full version, including alpha/beta/rc tags 30 | release = "" 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # If your documentation needs a minimal Sphinx version, state it here. 36 | # 37 | # needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 43 | "sphinx.ext.autodoc", 44 | "sphinx.ext.mathjax", 45 | "sphinx.ext.viewcode", 46 | "sphinx.ext.githubpages", 47 | ] 48 | 49 | # Add any paths that contain templates here, relative to this directory. 50 | templates_path = ["_templates"] 51 | 52 | # The suffix(es) of source filenames. 53 | # You can specify multiple suffix as a list of string: 54 | # 55 | # source_suffix = ['.rst', '.md'] 56 | source_suffix = ".rst" 57 | 58 | # The master toctree document. 59 | master_doc = "index" 60 | 61 | # The language for content autogenerated by Sphinx. Refer to documentation 62 | # for a list of supported languages. 63 | # 64 | # This is also used if you do content translation via gettext catalogs. 65 | # Usually you set "language" from the command line for these cases. 66 | language = None 67 | 68 | # List of patterns, relative to source directory, that match files and 69 | # directories to ignore when looking for source files. 70 | # This pattern also affects html_static_path and html_extra_path . 71 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 72 | 73 | # The name of the Pygments (syntax highlighting) style to use. 74 | pygments_style = "sphinx" 75 | 76 | 77 | # -- Options for HTML output ------------------------------------------------- 78 | 79 | # The theme to use for HTML and HTML Help pages. See the documentation for 80 | # a list of builtin themes. 81 | # 82 | html_theme = "alabaster" 83 | 84 | # Theme options are theme-specific and customize the look and feel of a theme 85 | # further. For a list of options available for each theme, see the 86 | # documentation. 87 | # 88 | # html_theme_options = {} 89 | 90 | # Add any paths that contain custom static files (such as style sheets) here, 91 | # relative to this directory. They are copied after the builtin static files, 92 | # so a file named "default.css" will overwrite the builtin "default.css". 93 | html_static_path = ["_static"] 94 | 95 | # Custom sidebar templates, must be a dictionary that maps document names 96 | # to template names. 97 | # 98 | # The default sidebars (for documents that don't match any pattern) are 99 | # defined by theme itself. Builtin themes are using these templates by 100 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 101 | # 'searchbox.html']``. 102 | # 103 | # html_sidebars = {} 104 | 105 | 106 | # -- Options for HTMLHelp output --------------------------------------------- 107 | 108 | # Output file base name for HTML help builder. 
109 | htmlhelp_basename = "mdai-client-pydoc" 110 | 111 | 112 | # -- Options for LaTeX output ------------------------------------------------ 113 | 114 | latex_elements = { 115 | # The paper size ('letterpaper' or 'a4paper'). 116 | # 117 | # 'papersize': 'letterpaper', 118 | # The font size ('10pt', '11pt' or '12pt'). 119 | # 120 | # 'pointsize': '10pt', 121 | # Additional stuff for the LaTeX preamble. 122 | # 123 | # 'preamble': '', 124 | # Latex figure (float) alignment 125 | # 126 | # 'figure_align': 'htbp', 127 | } 128 | 129 | # Grouping the document tree into LaTeX files. List of tuples 130 | # (source start file, target name, title, 131 | # author, documentclass [howto, manual, or own class]). 132 | latex_documents = [ 133 | (master_doc, "mdai-client-py.tex", "mdai-client-py Documentation", "MD.ai", "manual") 134 | ] 135 | 136 | 137 | # -- Options for manual page output ------------------------------------------ 138 | 139 | # One entry per manual page. List of tuples 140 | # (source start file, name, description, authors, manual section). 141 | man_pages = [(master_doc, "mdai-client-py", "mdai-client-py Documentation", [author], 1)] 142 | 143 | 144 | # -- Options for Texinfo output ---------------------------------------------- 145 | 146 | # Grouping the document tree into Texinfo files. List of tuples 147 | # (source start file, target name, title, author, 148 | # dir menu entry, description, category) 149 | texinfo_documents = [ 150 | ( 151 | master_doc, 152 | "mdai-client-py", 153 | "mdai-client-py Documentation", 154 | author, 155 | "mdai-client-py", 156 | "One line description of project.", 157 | "Miscellaneous", 158 | ) 159 | ] 160 | 161 | 162 | # -- Extension configuration ------------------------------------------------- 163 | 164 | from recommonmark.parser import CommonMarkParser 165 | 166 | source_parsers = {".md": CommonMarkParser} 167 | 168 | source_suffix = [".rst", ".md"] 169 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. mdai-client-py documentation master file, created by 2 | sphinx-quickstart on Sun Aug 5 14:50:10 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to mdai-client-py's documentation! 7 | ========================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | 14 | 15 | Indices and tables 16 | ================== 17 | 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=mdai-client-py 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/mdai.rst: -------------------------------------------------------------------------------- 1 | mdai package 2 | ============ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | mdai.utils 10 | 11 | Submodules 12 | ---------- 13 | 14 | mdai.client module 15 | ------------------ 16 | 17 | .. automodule:: mdai.client 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | mdai.preprocess module 23 | ---------------------- 24 | 25 | .. automodule:: mdai.preprocess 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | mdai.visualize module 31 | --------------------- 32 | 33 | .. automodule:: mdai.visualize 34 | :members: 35 | :undoc-members: 36 | :show-inheritance: 37 | 38 | 39 | Module contents 40 | --------------- 41 | 42 | .. automodule:: mdai 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | -------------------------------------------------------------------------------- /docs/mdai.utils.rst: -------------------------------------------------------------------------------- 1 | mdai.utils package 2 | ================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | mdai.utils.common\_utils module 8 | ------------------------------- 9 | 10 | .. automodule:: mdai.utils.common_utils 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | mdai.utils.keras\_utils module 16 | ------------------------------ 17 | 18 | .. automodule:: mdai.utils.keras_utils 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | mdai.utils.tensorflow\_utils module 24 | ----------------------------------- 25 | 26 | .. automodule:: mdai.utils.tensorflow_utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. automodule:: mdai.utils 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /mdai/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdai/mdai-client-py/28ab35436f3db1fb056b1e8d3bc13d9ae4c5c555/mdai/.DS_Store -------------------------------------------------------------------------------- /mdai/__init__.py: -------------------------------------------------------------------------------- 1 | """MD.ai Python client library.""" 2 | 3 | try: 4 | from importlib import metadata 5 | except ImportError: # for Python<3.8 6 | import importlib_metadata as metadata 7 | __version__ = metadata.version("mdai") 8 | 9 | from . import preprocess 10 | from .client import Client 11 | from .utils import common_utils 12 | from .utils import transforms 13 | from .utils import dicom_utils 14 | from .inference import delete_env, run_inference, infer 15 | 16 | try: 17 | CAN_VISUALIZE = True 18 | from . 
import visualize
19 | except ImportError:
20 |     # matplotlib backend missing or cannot be loaded
21 |     CAN_VISUALIZE = False
22 |
--------------------------------------------------------------------------------
/mdai/client.py:
--------------------------------------------------------------------------------
1 | import os
2 | import threading
3 | import re
4 | import math
5 | import json
6 | import uuid
7 | import zipfile
8 | import requests
9 | import urllib3.exceptions
10 | from retrying import retry
11 | from tqdm import tqdm
12 | import arrow
13 | from pydicom.filereader import dcmread
14 | from .utils import dicom_utils
15 | from .preprocess import Project
16 | from . import __version__
17 |
18 |
19 | def retry_on_http_error(exception):
20 |     valid_exceptions = [
21 |         requests.exceptions.HTTPError,
22 |         requests.exceptions.ConnectionError,
23 |         urllib3.exceptions.HTTPError,
24 |     ]
25 |     return any(isinstance(exception, e) for e in valid_exceptions)
26 |
27 |
28 | ANNOTATIONS_IMPORT_DEFAULT_CHUNK_SIZE = 100000
29 |
30 |
31 | class Client:
32 |     """Client for communicating with the MD.ai backend API.
33 |     Communication is via user access tokens (in MD.ai Hub, Settings -> User Access Tokens).
34 |     """
35 |
36 |     def __init__(self, domain="public.md.ai", access_token=None):
37 |         domain_match = re.match(r"^\w+\.md\.ai$", domain)
38 |         dev_domain_match = re.match(r"^\w+\.mdai\.dev(:\d+)?$", domain)
39 |         if not domain_match and not dev_domain_match:
40 |             raise ValueError(f"domain {domain} is invalid: should be format *.md.ai")
41 |
42 |         self.domain = domain
43 |         self.access_token = access_token
44 |         self.session = requests.Session()
45 |         self._test_endpoint()
46 |         self.chat_completion = ChatCompletion(self.domain, self.session, self._create_headers())
47 |
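    # Hedged usage sketch for an annotations-only export (the project ID and
    # token are placeholders):
    #
    #     client = Client(domain="public.md.ai", access_token="...")
    #     client.project("PROJECT_ID", path="./data", annotations_only=True)
    #
    # With annotations_only=True, project() downloads only the annotations JSON
    # and returns None instead of an mdai.preprocess.Project (see below).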
48 |     def project(
49 |         self,
50 |         project_id,
51 |         dataset_id=None,
52 |         label_group_id=None,
53 |         path=".",
54 |         force_download=False,
55 |         annotations_only=False,
56 |         extract_images=True,
57 |     ):
58 |         """Initializes Project class given project id.
59 |
60 |         Arguments:
61 |             project_id: hash ID of project
62 |             dataset_id: hash ID of dataset to scope to (optional - default `None`)
63 |             label_group_id: hash ID of the label group to scope to (optional - default `None`)
64 |             path: directory used for data (optional - default `"."`)
65 |             force_download: if `True`, ignores possible existing data in `path` (optional - default `False`)
66 |             annotations_only: if `True`, downloads annotations only (optional - default `False`)
67 |             extract_images: if `True`, automatically extracts downloaded zip files for image exports (optional - default `True`)
68 |         """
69 |         if path == ".":
70 |             print("Using working directory for data.")
71 |         else:
72 |             os.makedirs(path, exist_ok=True)
73 |             print(f"Using path '{path}' for data.")
74 |
75 |         data_manager_kwargs = {
76 |             "domain": self.domain,
77 |             "project_id": project_id,
78 |             "dataset_id": dataset_id,
79 |             "label_group_id": label_group_id,
80 |             "path": path,
81 |             "session": self.session,
82 |             "headers": self._create_headers(),
83 |             "force_download": force_download,
84 |             "extract_images": extract_images,
85 |         }
86 |
87 |         annotations_data_manager = ProjectDataManager("annotations", **data_manager_kwargs)
88 |         annotations_data_manager.create_data_export_job()
89 |         if not annotations_only:
90 |             images_data_manager = ProjectDataManager("images", **data_manager_kwargs)
91 |             images_data_manager.create_data_export_job()
92 |
93 |         annotations_data_manager.wait_until_ready()
94 |         if not annotations_only:
95 |             images_data_manager.wait_until_ready()
96 |             p = Project(
97 |                 annotations_fp=annotations_data_manager.data_path,
98 |                 images_dir=images_data_manager.data_path,
99 |             )
100 |             return p
101 |         else:
102 |             print("No project created. Downloaded annotations only.")
103 |             return None
104 |
105 |     def download_model_outputs(
106 |         self, project_id, dataset_id=None, model_id=None, path=".", force_download=False
107 |     ):
108 |         """Downloads model outputs given project_id.
109 |
110 |         Arguments:
111 |             project_id: hash ID of project
112 |             dataset_id: hash ID of dataset (optional - default `None`)
113 |             model_id: hash ID of the model (optional - default `None`)
114 |             path: directory used for data (optional - default `"."`)
115 |             force_download: if `True`, ignores possible existing data in `path` (optional - default `False`)
116 |         """
117 |         if path == ".":
118 |             print("Using working directory for model outputs.")
119 |         else:
120 |             os.makedirs(path, exist_ok=True)
121 |             print(f"Using path '{path}' for data.")
122 |
123 |         data_manager_kwargs = {
124 |             "domain": self.domain,
125 |             "project_id": project_id,
126 |             "dataset_id": dataset_id,
127 |             "model_id": model_id,
128 |             "path": path,
129 |             "session": self.session,
130 |             "headers": self._create_headers(),
131 |             "force_download": force_download,
132 |         }
133 |
134 |         model_outputs_manager = ProjectDataManager("model-outputs", **data_manager_kwargs)
135 |         model_outputs_manager.create_data_export_job()
136 |         model_outputs_manager.wait_until_ready()
137 |         return None
138 |
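    # Hedged usage sketch for the two download helpers above and below (the
    # IDs are placeholders): fetch model outputs as JSON, then DICOM metadata
    # as CSV, into the working directory:
    #
    #     client.download_model_outputs("PROJECT_ID", model_id="MODEL_ID")
    #     client.download_dicom_metadata("PROJECT_ID", format="csv")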
139 |     def download_dicom_metadata(
140 |         self, project_id, dataset_id=None, format="json", path=".", force_download=False
141 |     ):
142 |         """Downloads dicom metadata given project_id, dataset_id and export format.
143 |
144 |         Arguments:
145 |             project_id: hash ID of project
146 |             dataset_id: hash ID of dataset (optional - default `None`)
147 |             format: export format for the metadata file, json or csv (optional - default `json`)
148 |             path: directory used for data (optional - default `"."`)
149 |             force_download: if `True`, ignores possible existing data in `path` (optional - default `False`)
150 |         """
151 |         if path == ".":
152 |             print("Using working directory for dicom metadata.")
153 |         else:
154 |             os.makedirs(path, exist_ok=True)
155 |             print(f"Using path '{path}' for data.")
156 |
157 |         if format not in ["json", "csv"]:
158 |             raise Exception(
159 |                 "Incorrect export format specified for dicom-metadata. Only json and csv formats are supported."
160 |             )
161 |
162 |         data_manager_kwargs = {
163 |             "domain": self.domain,
164 |             "project_id": project_id,
165 |             "dataset_id": dataset_id,
166 |             "format": format,
167 |             "path": path,
168 |             "session": self.session,
169 |             "headers": self._create_headers(),
170 |             "force_download": force_download,
171 |         }
172 |
173 |         dicom_metadata_manager = ProjectDataManager("dicom-metadata", **data_manager_kwargs)
174 |         dicom_metadata_manager.create_data_export_job()
175 |         dicom_metadata_manager.wait_until_ready()
176 |         return None
177 |
178 |     def load_model_annotations(self):
179 |         """Deprecated method: use `import_annotations` instead."""
180 |         print("Deprecated method: use `import_annotations` instead.")
181 |
182 |     def import_annotations(
183 |         self,
184 |         annotations,
185 |         project_id,
186 |         dataset_id,
187 |         chunk_size=ANNOTATIONS_IMPORT_DEFAULT_CHUNK_SIZE,
188 |     ):
189 |         """Import annotations into project.
190 |         For example, this method can be used to load machine learning model results into project as
191 |         annotations, or quickly populate metadata labels.
192 |
193 |         Arguments:
194 |             annotations: list of annotations to load.
195 |             project_id: hash ID of project.
196 |             dataset_id: hash ID of dataset.
197 |             chunk_size: number of annotations to load as a chunk.
198 |         """
199 |         if not annotations:
200 |             print("No annotations provided.")
201 |         if not project_id:
202 |             print("project_id is required.")
203 |         if not dataset_id:
204 |             print("dataset_id is required.")
205 |
206 |         num_chunks = math.ceil(len(annotations) / chunk_size)
207 |
208 |         if num_chunks > 1:
209 |             print(f"Importing {len(annotations)} total annotations in {num_chunks} chunks...")
210 |
211 |         failed_annotations = []
212 |
213 |         for i in range(num_chunks):
214 |             if num_chunks > 1:
215 |                 print(f"Chunk {i+1}...")
216 |
217 |             start = i * chunk_size
218 |             end = (i + 1) * chunk_size
219 |             annotations_chunk = annotations[start:end]
220 |
221 |             manager = AnnotationsImportManager(
222 |                 annotations=annotations_chunk,
223 |                 project_id=project_id,
224 |                 dataset_id=dataset_id,
225 |                 session=self.session,
226 |                 domain=self.domain,
227 |                 headers=self._create_headers(),
228 |             )
229 |             manager.create_job()
230 |             manager.wait_until_ready()
231 |
232 |             for failed_annotation in manager.failed_annotations:
233 |                 # add start index since returned index is for chunk
234 |                 failed_annotation["index"] += start
235 |                 failed_annotations.append(failed_annotation)
236 |
237 |         if num_chunks > 1:
238 |             num_failed = len(failed_annotations)
239 |             print(
240 |                 f"Successfully imported {len(annotations) - num_failed} / {len(annotations)}"
241 |                 + f" total annotations into project {project_id}."
242 |             )
243 |
244 |         return failed_annotations
245 |
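    # Hedged sketch of the annotation shape expected by import_annotations()
    # above, mirroring what import_sr() below constructs (the IDs/UIDs are
    # placeholders; other fields follow the MD.ai annotation JSON format):
    #
    #     annotations = [
    #         {"labelId": "L_xxxxxx", "StudyInstanceUID": "1.2.840...", "note": "..."},
    #     ]
    #     failed = client.import_annotations(annotations, "PROJECT_ID", "DATASET_ID")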
246 |     def import_sr(self, sr_files, project_id="", dataset_id="", label_id=""):
247 |         """
248 |         Arguments:
249 |             `sr_files` - List of DICOM-SR files to load.
250 |             `project_id` & `dataset_id` & `label_id` - Project information necessary to output SR content to an annotation note. All three must be present.
251 |
252 |         Created by Dyllan Hofflich and MD.ai
253 |         """
254 |         if not project_id:
255 |             raise ValueError('Please add in the "project_id" argument')
256 |         if not dataset_id:
257 |             raise ValueError('Please add in the "dataset_id" argument')
258 |         if not label_id:
259 |             raise ValueError('Please add in the "label_id" argument')
260 |         annotations = []
261 |
262 |         for file_path in sr_files:
263 |             ds = dcmread(file_path)
264 |
265 |             # Get the referenced DICOM files
266 |             referenced_dicoms = []
267 |             for study_seq in ds.CurrentRequestedProcedureEvidenceSequence:
268 |                 referenced_study = {}
269 |                 study_UID = study_seq.StudyInstanceUID
270 |                 referenced_study["Study UID"] = study_UID
271 |                 referenced_dicoms.append(referenced_study)
272 |
273 |             content_seq_list = list(ds.ContentSequence)
274 |
275 |             content = []
276 |             dicom_utils.iterate_content_seq(content, content_seq_list)
277 |
278 |             final_content = []
279 |             for annot in content:
280 |                 annot = list(filter(None, annot))
281 |                 final_content.append(" - ".join(annot))
282 |             for dicom_dict in referenced_dicoms:
283 |                 study_uid = dicom_dict["Study UID"]
284 |                 note = "\n".join(final_content)
285 |                 annot_dict = {"labelId": label_id, "StudyInstanceUID": study_uid, "note": note}
286 |                 annotations.append(annot_dict)
287 |
288 |         self.import_annotations(annotations, project_id, dataset_id)
289 |
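    # Hedged usage sketch for import_sr() above (all IDs are placeholders):
    #
    #     client.import_sr(["report.dcm"], project_id="PROJECT_ID",
    #                      dataset_id="DATASET_ID", label_id="LABEL_ID")
    #
    # Each SR file yields one note annotation per referenced StudyInstanceUID.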
290 |     def _create_headers(self):
291 |         headers = {}
292 |         if self.access_token:
293 |             headers["x-access-token"] = self.access_token
294 |         return headers
295 |
296 |     def _test_endpoint(self):
297 |         """Checks endpoint for validity and authorization."""
298 |         test_endpoint = f"https://{self.domain}/api/test"
299 |         r = self.session.get(test_endpoint, headers=self._create_headers())
300 |         if r.status_code == 200:
301 |             print(f"Successfully authenticated to {self.domain}.")
302 |         else:
303 |             raise Exception("Authorization error. Make sure your access token is valid.")
304 |
305 |     @retry(
306 |         retry_on_exception=retry_on_http_error,
307 |         wait_exponential_multiplier=100,
308 |         wait_exponential_max=1000,
309 |         stop_max_attempt_number=10,
310 |     )
311 |     def _gql(self, query, variables=None):
312 |         """Executes GraphQL query."""
313 |         gql_endpoint = f"https://{self.domain}/api/graphql"
314 |         headers = self._create_headers()
315 |         headers["Accept"] = "application/json"
316 |         headers["Content-Type"] = "application/json"
317 |         headers["apollographql-client-name"] = "mdai-client-py"
318 |         headers["apollographql-client-version"] = __version__
319 |
320 |         data = {"query": query, "variables": variables}
321 |         r = self.session.post(gql_endpoint, headers=headers, json=data)
322 |         if r.status_code != 200:
323 |             r.raise_for_status()
324 |
325 |         body = r.json()
326 |         data = body["data"] if "data" in body else None
327 |         errors = body["errors"] if "errors" in body else None
328 |
329 |         return data, errors
330 |
331 |
332 | class ProjectDataManager:
333 |     """Manager for project data exports and downloads."""
334 |
335 |     def __init__(
336 |         self,
337 |         data_type,
338 |         domain=None,
339 |         project_id=None,
340 |         dataset_id=None,
341 |         label_group_id=None,
342 |         model_id=None,
343 |         format=None,
344 |         path=".",
345 |         session=None,
346 |         headers=None,
347 |         force_download=False,
348 |         extract_images=True,
349 |     ):
350 |         if data_type not in ["images", "annotations", "model-outputs", "dicom-metadata"]:
351 |             raise ValueError(
352 |                 "data_type must be 'images', 'annotations', 'model-outputs' or 'dicom-metadata'."
353 |             )
354 |         if not domain:
355 |             raise ValueError("domain is not specified.")
356 |         if not project_id:
357 |             raise ValueError("project_id is not specified.")
358 |         if not os.path.exists(path):
359 |             raise OSError(f"Path '{path}' does not exist.")
360 |
361 |         self.data_type = data_type
362 |         self.force_download = force_download
363 |         self.extract_images = extract_images
364 |
365 |         self.domain = domain
366 |         self.project_id = project_id
367 |         self.dataset_id = dataset_id
368 |         self.label_group_id = label_group_id
369 |         self.format = format
370 |         self.model_id = model_id
371 |         self.path = path
372 |         if session and isinstance(session, requests.Session):
373 |             self.session = session
374 |         else:
375 |             self.session = requests.Session()
376 |         self.headers = headers
377 |
378 |         # path for downloaded data
379 |         self.data_path = None
380 |         # ready threading event
381 |         self._ready = threading.Event()
382 |
383 |     def create_data_export_job(self):
384 |         """Create data export job through MD.ai API.
385 |         This is an async operation. Status code of 202 indicates successful creation of job.
386 |         """
387 |         endpoint = f"https://{self.domain}/api/data-export/{self.data_type}"
388 |         params = self._get_data_export_params()
389 |         r = self.session.post(endpoint, json=params, headers=self.headers)
390 |         if r.status_code == 202:
391 |             msg = f"Preparing {self.data_type} export for project {self.project_id}..."
392 |             print(msg.ljust(100))
393 |             self._check_data_export_job_progress()
394 |         else:
395 |             if r.status_code == 401:
396 |                 msg = (
397 |                     f"Project {self.project_id} at domain {self.domain}"
398 |                     + " does not exist or you do not have sufficient permissions for access."
399 | ) 400 | print(msg) 401 | self._on_data_export_job_error() 402 | 403 | def wait_until_ready(self): 404 | self._ready.wait() 405 | 406 | def _get_data_export_params(self): 407 | if self.data_type == "images": 408 | params = { 409 | "projectHashId": self.project_id, 410 | "datasetHashId": self.dataset_id, 411 | "exportFormat": "zip", 412 | } 413 | elif self.data_type == "annotations": 414 | # TODO: restrict to assigned labelgroup 415 | params = { 416 | "projectHashId": self.project_id, 417 | "datasetHashId": self.dataset_id, 418 | "labelGroupHashId": self.label_group_id, 419 | "exportFormat": "json", 420 | } 421 | elif self.data_type == "model-outputs": 422 | params = { 423 | "projectHashId": self.project_id, 424 | "datasetHashId": self.dataset_id, 425 | "modelHashId": self.model_id, 426 | "exportFormat": "json", 427 | } 428 | elif self.data_type == "dicom-metadata": 429 | params = { 430 | "projectHashId": self.project_id, 431 | "datasetHashId": self.dataset_id, 432 | "exportFormat": self.format, 433 | } 434 | return params 435 | 436 | @retry( 437 | retry_on_exception=retry_on_http_error, 438 | wait_exponential_multiplier=100, 439 | wait_exponential_max=1000, 440 | stop_max_attempt_number=10, 441 | ) 442 | def _check_data_export_job_progress(self): 443 | """Poll for data export job progress.""" 444 | endpoint = f"https://{self.domain}/api/data-export/{self.data_type}/progress" 445 | params = self._get_data_export_params() 446 | r = self.session.post(endpoint, json=params, headers=self.headers) 447 | if r.status_code != 200: 448 | r.raise_for_status() 449 | 450 | try: 451 | body = r.json() 452 | status = body["status"] 453 | except (TypeError, KeyError): 454 | self._on_data_export_job_error() 455 | return 456 | 457 | if status == "done": 458 | self._on_data_export_job_done() 459 | 460 | elif status == "error": 461 | self._on_data_export_job_error() 462 | 463 | elif status == "running": 464 | try: 465 | progress = int(body["progress"]) 466 | except (TypeError, ValueError): 467 | progress = 0 468 | try: 469 | time_remaining = int(body["timeRemaining"]) 470 | except (TypeError, ValueError): 471 | time_remaining = 0 472 | 473 | # print formatted progress info 474 | if time_remaining > 45: 475 | time_remaining_fmt = ( 476 | arrow.now().shift(seconds=time_remaining).humanize(only_distance=True) 477 | ) 478 | else: 479 | # arrow humanizes <= 45 to 'in seconds' or 'just now', 480 | # so we will opt to be explicit instead. 481 | time_remaining_fmt = f"{time_remaining} seconds" 482 | end_char = "\r" if progress < 100 else "\n" 483 | msg = ( 484 | f"Exporting {self.data_type} for project {self.project_id}..." 485 | + f"{progress}% (time remaining: {time_remaining_fmt})." 
486 | ) 487 | print(msg.ljust(100), end=end_char, flush=True) 488 | 489 | # run progress check at 1s intervals so long as status == 'running' 490 | t = threading.Timer(1.0, self._check_data_export_job_progress) 491 | t.start() 492 | 493 | @retry( 494 | retry_on_exception=retry_on_http_error, 495 | wait_exponential_multiplier=100, 496 | wait_exponential_max=1000, 497 | stop_max_attempt_number=10, 498 | ) 499 | def _on_data_export_job_done(self): 500 | endpoint = f"https://{self.domain}/api/data-export/{self.data_type}/done" 501 | params = self._get_data_export_params() 502 | r = self.session.post(endpoint, json=params, headers=self.headers) 503 | if r.status_code != 200: 504 | r.raise_for_status() 505 | 506 | try: 507 | file_keys = r.json()["fileKeys"] 508 | 509 | if file_keys: 510 | data_path = self._get_data_path(file_keys) 511 | if self.force_download or not os.path.exists(data_path): 512 | # download in separate thread 513 | t = threading.Thread(target=self._download_files, args=(file_keys,)) 514 | t.start() 515 | else: 516 | # use existing data 517 | self.data_path = data_path 518 | print(f"Using cached {self.data_type} data for project {self.project_id}.") 519 | # fire ready threading.Event 520 | self._ready.set() 521 | except (TypeError, KeyError): 522 | self._on_data_export_job_error() 523 | 524 | @retry( 525 | retry_on_exception=retry_on_http_error, 526 | wait_exponential_multiplier=100, 527 | wait_exponential_max=1000, 528 | stop_max_attempt_number=10, 529 | ) 530 | def _on_data_export_job_error(self): 531 | endpoint = f"https://{self.domain}/api/data-export/{self.data_type}/error" 532 | params = self._get_data_export_params() 533 | r = self.session.post(endpoint, json=params, headers=self.headers) 534 | if r.status_code != 200: 535 | r.raise_for_status() 536 | print(f"Error exporting {self.data_type} for project {self.project_id}.") 537 | # fire ready threading.Event 538 | self._ready.set() 539 | 540 | def _get_data_path(self, file_keys): 541 | if self.data_type == "images": 542 | # should be folder for zip file: 543 | # xxxx.zip -> xxxx/ 544 | # xxxx_part1of3.zip -> xxxx/ 545 | images_dir = re.sub(r"(_part\d+of\d+)?\.\S+$", "", file_keys[0]) 546 | return os.path.join(self.path, images_dir) 547 | elif self.data_type == "annotations": 548 | # annotations export will be single file 549 | annotations_fp = file_keys[0] 550 | return os.path.join(self.path, annotations_fp) 551 | elif self.data_type == "model-outputs": 552 | # model outputs export will be single file 553 | model_outputs_fp = file_keys[0] 554 | return os.path.join(self.path, model_outputs_fp) 555 | elif self.data_type == "dicom-metadata": 556 | # dicom metadata export will be single file 557 | dicom_metadata_fp = file_keys[0] 558 | return os.path.join(self.path, dicom_metadata_fp) 559 | 560 | def _download_files(self, file_keys): 561 | """Downloads exported files.""" 562 | try: 563 | for file_key in file_keys: 564 | print(f"Downloading file: {file_key}") 565 | filepath = os.path.join(self.path, file_key) 566 | 567 | key = requests.utils.quote(file_key) 568 | dl_session_id = str(uuid.uuid4()) 569 | 570 | # request download token 571 | url = f"https://{self.domain}/api/data-export/download-request" 572 | data = {"key": key, "sessionId": dl_session_id} 573 | r = requests.post(url, json=data, headers=self.headers) 574 | dl_token = r.json().get("token") 575 | 576 | # download file 577 | # stream response so we can display progress bar 578 | url = f"https://{self.domain}/api/data-export/download/{key}" 579 | data = {"token": 
dl_token, "sessionId": dl_session_id} 580 | r = requests.post(url, json=data, stream=True) 581 | # fallback to GET if POST not available 582 | if r.status_code == 405: 583 | r = requests.get(url, params=data, stream=True) 584 | 585 | # total size in bytes 586 | total_size = int(r.headers.get("content-length", 0)) 587 | block_size = 32 * 1024 588 | wrote = 0 589 | with open(filepath, "wb") as f: 590 | with tqdm( 591 | total=total_size, unit="B", unit_scale=True, unit_divisor=1024 592 | ) as pbar: 593 | for chunk in r.iter_content(block_size): 594 | f.write(chunk) 595 | wrote = wrote + len(chunk) 596 | pbar.update(block_size) 597 | if total_size != 0 and wrote != total_size: 598 | raise IOError(f"Error downloading file {file_key}.") 599 | 600 | if self.data_type == "images" and self.extract_images: 601 | # unzip archive 602 | print(f"Extracting archive: {file_key}") 603 | with zipfile.ZipFile(filepath, "r") as f: 604 | f.extractall(self.path) 605 | 606 | self.data_path = self._get_data_path(file_keys) 607 | 608 | print(f"Success: {self.data_type} data for project {self.project_id} ready.") 609 | except Exception: 610 | print(f"Error downloading {self.data_type} data for project {self.project_id}.") 611 | 612 | # fire ready threading.Event 613 | self._ready.set() 614 | 615 | 616 | class AnnotationsImportManager: 617 | """Manager for importing annotations.""" 618 | 619 | def __init__( 620 | self, 621 | annotations=None, 622 | project_id=None, 623 | dataset_id=None, 624 | session=None, 625 | domain=None, 626 | headers=None, 627 | ): 628 | if not domain: 629 | raise ValueError("domain is not specified.") 630 | if not project_id: 631 | raise ValueError("project_id is not specified.") 632 | 633 | self.annotations = annotations 634 | self.project_id = project_id 635 | self.dataset_id = dataset_id 636 | if session and isinstance(session, requests.Session): 637 | self.session = session 638 | else: 639 | self.session = requests.Session() 640 | self.domain = domain 641 | self.headers = headers 642 | 643 | self.job_id = None 644 | 645 | # list of failed annotation imports 646 | self.failed_annotations = [] 647 | 648 | # ready threading event 649 | self._ready = threading.Event() 650 | 651 | def create_job(self): 652 | """Create annotations import job through MD.ai API. 653 | This is an async operation. Status code of 202 indicates successful creation of job. 654 | """ 655 | endpoint = f"https://{self.domain}/api/data-import/annotations" 656 | params = { 657 | "projectHashId": self.project_id, 658 | "datasetHashId": self.dataset_id, 659 | "annotations": self.annotations, 660 | } 661 | 662 | # reset list of failed annotation imports 663 | self.failed_annotations = [] 664 | 665 | r = self.session.post(endpoint, json=params, headers=self.headers) 666 | 667 | if r.status_code == 202: 668 | self.job_id = r.json()["jobId"] 669 | msg = f"Importing {len(self.annotations)} annotations into " 670 | msg += f"project {self.project_id}, " 671 | msg += f"dataset {self.dataset_id}..." 672 | print(msg.ljust(100)) 673 | self._check_job_progress() 674 | else: 675 | print(r.status_code) 676 | if r.status_code in (400, 401): 677 | msg = "Provided IDs are invalid, or you do not have sufficient permissions." 
678 | print(msg) 679 | 680 | def wait_until_ready(self): 681 | self._ready.wait() 682 | 683 | @retry( 684 | retry_on_exception=retry_on_http_error, 685 | wait_exponential_multiplier=100, 686 | wait_exponential_max=1000, 687 | stop_max_attempt_number=10, 688 | ) 689 | def _check_job_progress(self): 690 | """Poll for annotations import job progress.""" 691 | endpoint = f"https://{self.domain}/api/data-import/annotations/progress" 692 | params = {"projectHashId": self.project_id, "jobId": self.job_id} 693 | r = self.session.post(endpoint, json=params, headers=self.headers) 694 | 695 | try: 696 | body = r.json() 697 | status = body["status"] 698 | except (TypeError, KeyError): 699 | return 700 | 701 | if status == "done": 702 | self._on_job_done() 703 | 704 | elif status == "error": 705 | self._on_job_error() 706 | 707 | elif status == "running": 708 | try: 709 | progress = int(body["progress"]) 710 | except (TypeError, ValueError): 711 | progress = 0 712 | try: 713 | time_remaining = int(body["timeRemaining"]) 714 | except (TypeError, ValueError): 715 | time_remaining = 0 716 | 717 | # print formatted progress info 718 | if time_remaining > 45: 719 | time_remaining_fmt = ( 720 | arrow.now().shift(seconds=time_remaining).humanize(only_distance=True) 721 | ) 722 | else: 723 | # arrow humanizes <= 45 to 'in seconds' or 'just now', 724 | # so we will opt to be explicit instead. 725 | time_remaining_fmt = f"{time_remaining} seconds" 726 | end_char = "\r" if progress < 100 else "\n" 727 | msg = ( 728 | f"Annotations import for project {self.project_id}..." 729 | + f"{progress}% (time remaining: {time_remaining_fmt})." 730 | ) 731 | print(msg.ljust(100), end=end_char, flush=True) 732 | 733 | # run progress check at 1s intervals so long as status == 'running' 734 | t = threading.Timer(1.0, self._check_job_progress) 735 | t.start() 736 | 737 | @retry( 738 | retry_on_exception=retry_on_http_error, 739 | wait_exponential_multiplier=100, 740 | wait_exponential_max=1000, 741 | stop_max_attempt_number=10, 742 | ) 743 | def _on_job_done(self): 744 | endpoint = f"https://{self.domain}/api/data-import/annotations/done" 745 | params = {"projectHashId": self.project_id, "jobId": self.job_id} 746 | r = self.session.post(endpoint, json=params, headers=self.headers) 747 | 748 | try: 749 | body = r.json() 750 | self.failed_annotations = body["failed"] 751 | except (TypeError, KeyError): 752 | return 753 | 754 | num_failed = len(self.failed_annotations) 755 | print( 756 | f"Successfully imported {len(self.annotations) - num_failed} / {len(self.annotations)}" 757 | + f" annotations into project {self.project_id}" 758 | + f", dataset {self.dataset_id}." 759 | ) 760 | # fire ready threading.Event 761 | self._ready.set() 762 | 763 | @retry( 764 | retry_on_exception=retry_on_http_error, 765 | wait_exponential_multiplier=100, 766 | wait_exponential_max=1000, 767 | stop_max_attempt_number=10, 768 | ) 769 | def _on_job_error(self): 770 | endpoint = f"https://{self.domain}/api/data-import/annotations/error" 771 | params = {"projectHashId": self.project_id, "jobId": self.job_id} 772 | r = self.session.post(endpoint, json=params, headers=self.headers) 773 | print( 774 | f"Error importing annotations into project {self.project_id}" 775 | + f", dataset {self.dataset_id}." 
776 |         )
777 |         # fire ready threading.Event
778 |         self._ready.set()
779 | 
780 | 
781 | class ChatCompletion:
782 |     def __init__(self, domain="public.md.ai", session=None, headers=None):
783 |         self.domain = domain
784 |         if session and isinstance(session, requests.Session):
785 |             self.session = session
786 |         else:
787 |             self.session = requests.Session()
788 |         self.headers = headers
789 |         self.models = {"gpt-3.5-turbo": ["gpt-3.5-turbo-1106", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0301"],
790 |                        "gpt-4": ["gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4-0613", "gpt-4-0314"],
791 |                        "llama2": ["togethercomputer/llama-2-70b-chat"],
792 |                        "mixtral": ["mistralai/Mixtral-8x7B-Instruct-v0.1"]}
793 | 
794 |     def list_models(self, model_name=None):
795 |         if model_name in self.models:
796 |             print(f"{model_name} available model list: {self.models[model_name]}")
797 |         else:
798 |             for model in self.models:
799 |                 print(f"{model} available model list: {self.models[model]}")
800 | 
801 |     def create(
802 |         self,
803 |         messages,
804 |         model="gpt-3.5-turbo",
805 |         functions=None,
806 |         function_call=None,
807 |         temperature=0,
808 |         top_p=None,
809 |         n=1,
810 |         stop=None,
811 |         max_tokens=None,
812 |         presence_penalty=None,
813 |         frequency_penalty=None,
814 |         logit_bias=None,
815 |     ):
816 |         """Creates a chat completion API call through the MD.ai client."""
817 |         headers = dict(self.headers or {})  # copy so we don't mutate the shared session headers
818 |         headers["Content-Type"] = "application/json"
819 |         data = {
820 |             "model": model,
821 |             "messages": messages,
822 |             "functions": functions,
823 |             "function_call": function_call,
824 |             "temperature": temperature,
825 |             "top_p": top_p,
826 |             "n": n,
827 |             "stop": stop,
828 |             "max_tokens": max_tokens,
829 |             "presence_penalty": presence_penalty,
830 |             "frequency_penalty": frequency_penalty,
831 |             "logit_bias": logit_bias,
832 |         }
833 |         try:
834 |             response = self.session.post(
835 |                 f"https://{self.domain}/api/openai/chat/completions", json=data, headers=headers
836 |             )
837 |             response_json = response.json()
838 |             if response_json.get("error"):
839 |                 raise TypeError(f'{response_json.get("error")}')
840 |             response_content = response_json["response"]
841 |             return response_content
842 |         except Exception as error:
843 |             print(
844 |                 f"{error}. Error calling the chat completion API. Please check all the parameters and try again."
845 |             )
846 | 
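For quick reference, a minimal usage sketch of `ChatCompletion`. The access-token header below is an assumption (the class simply forwards whatever `headers` it is given), and `messages` follows the OpenAI chat format that the proxied endpoint expects:

```python
# A minimal sketch, not the documented client API; the auth header is hypothetical.
from mdai.client import ChatCompletion

chat = ChatCompletion(
    domain="public.md.ai",
    headers={"x-access-token": "YOUR_TOKEN"},  # hypothetical header format
)
chat.list_models("gpt-4")  # print the available gpt-4 model versions

reply = chat.create(
    messages=[{"role": "user", "content": "Summarize this radiology report: ..."}],
    model="gpt-4",
    temperature=0,
)
print(reply)
```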
--------------------------------------------------------------------------------
/mdai/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import sys
4 | import shlex
5 | import yaml
6 | import zipfile
7 | import pydicom
8 | import pathlib
9 | import subprocess
10 | import pandas as pd
11 | from tqdm import tqdm
12 | 
13 | 
14 | def extract_zip(version_path, model_version):
15 |     """Extracts the exported model zip file"""
16 |     if not os.path.isdir(version_path):
17 |         zip_file = version_path + ".zip"
18 |         if os.path.isfile(zip_file):
19 |             print(f"Extracting {model_version}.zip")
20 |             with zipfile.ZipFile(zip_file, "r") as f:
21 |                 f.extractall(version_path)
22 |         else:
23 |             raise Exception(f"{model_version}.zip does not exist, please try a different version.")
24 | 
25 | 
26 | def load_json(json_file):
27 |     """Loads a JSON file"""
28 |     with open(json_file) as f:
29 |         data = json.load(f)
30 |     return data
31 | 
32 | 
33 | def parse_model_json(model_json):
34 |     """Parses the model.json schema file. Returns model scope and output labels"""
35 |     model = model_json["model"]
36 |     model_scope = model["scope"]
37 |     labels = {i["id"]: [i["name"], i["scope"]] for i in model_json["labels"]}
38 |     labels = {i["classIndex"]: labels[i["labelId"]] for i in model["labelClasses"]}
39 |     return model_scope, labels
40 | 
41 | 
42 | def get_file_paths(root):
43 |     """Yields all file paths recursively from root path, filtering on DICOM extension."""
44 |     if os.path.isfile(root):
45 |         yield root
46 |     else:
47 |         for item in os.scandir(root):
48 |             if item.is_file():
49 |                 if os.path.splitext(item.path)[1] == ".dcm":
50 |                     yield item.path
51 |             elif item.is_dir():
52 |                 yield from get_file_paths(item.path)
53 | 
54 | 
55 | def process_file(path):
56 |     """Returns each instance in raw bytes format"""
57 |     instance = {}
58 |     with open(path, "rb") as f:
59 |         instance["content"] = f.read()
60 |     instance["content_type"] = "application/dicom"
61 |     return instance
62 | 
63 | 
64 | def get_scope_files(file_paths, scope):
65 |     """Returns aggregated list of files based on input scope of the model"""
66 |     scope_map = {"STUDY": "StudyInstanceUID", "SERIES": "SeriesInstanceUID"}
67 |     vals = {}
68 |     for path in file_paths:
69 |         ds = pydicom.dcmread(path, stop_before_pixels=True)
70 |         uid = ds.get(scope_map[scope])
71 |         if uid not in vals:
72 |             vals[uid] = {"files": [], "annotations": [], "args": {}}
73 |         vals[uid]["files"].append(process_file(path))
74 |         del ds
75 |     return list(vals.values())
76 | 
77 | 
78 | def process_data(path, model_scope):
79 |     """Returns processed data in the correct input format for models"""
80 |     file_paths = list(get_file_paths(path))
81 |     if model_scope in ("STUDY", "SERIES"):
82 |         return get_scope_files(file_paths, model_scope)
83 |     else:
84 |         data = []
85 |         for path in file_paths:
86 |             val = {"files": [], "annotations": [], "args": {}}
87 |             val["files"].append(process_file(path))
88 |             data.append(val)
89 |         return data
90 | 
91 | 
92 | def run_model(data_path, model_path, model_scope):
93 |     """Prepares inputs and runs the MDAI model"""
94 |     print("Preparing inputs", flush=True)
95 |     input_data = process_data(data_path, model_scope)
96 | 
97 |     sys.path.insert(0, os.path.join(model_path, ".mdai"))
98 |     from mdai_deploy import MDAIModel
99 | 
100 |     model = MDAIModel()
101 | 
102 |     outputs = []
103 |     for data in tqdm(input_data, desc="Running inference"):
104 |         outputs.append(model.predict(data))
105 |     outputs = [val for output in outputs for val in output]
106 |     return outputs
107 | 
108 | 
109 | def env_exists(env_name):
110 |     """Checks if conda env already exists to prevent duplicate builds"""
111 |     command = shlex.split(f"/bin/bash -c 'conda env list | grep {env_name}'")
112 |     try:
113 |         subprocess.run(command, capture_output=True, text=True, check=True)
114 |     except Exception:
115 |         return False
116 |     return True
117 | 
118 | 
119 | def is_py37(version_path):
120 |     """Checks if base_image is py37 in config.yaml"""
121 |     config_path = os.path.join(version_path, "model", ".mdai", "config.yaml")
122 |     with open(config_path, "r") as f:
123 |         config_file = yaml.safe_load(f)
124 | 
125 |     if config_file["base_image"] == "py37":
126 |         return True
127 |     return False
128 | 
129 | 
130 | def infer(model_path, data_path, model_version):
131 |     """Helper function for processing inputs and running the model"""
132 |     file_name = os.path.splitext(data_path)[0].split("/")[-1]
133 |     version_path = os.path.join(model_path, "source", model_version)
134 | 
135 |     model_json = load_json(os.path.join(model_path, "model.json"))
136 |     model_scope, 
labels = parse_model_json(model_json) 137 | model_inference_path = os.path.join(version_path, "model") 138 | 139 | outputs = run_model(data_path, model_inference_path, model_scope) 140 | columns = [ 141 | "StudyInstanceUID", 142 | "SeriesInstanceUID", 143 | "SOPInstanceUID", 144 | "Label", 145 | "Probability", 146 | "Data", 147 | "Scope", 148 | ] 149 | df = pd.DataFrame(columns=columns) 150 | 151 | for output in outputs: 152 | if output.get("type") == "ANNOTATION": 153 | label_details = labels[output.get("class_index")] 154 | row = { 155 | "StudyInstanceUID": output.get("study_uid"), 156 | "SeriesInstanceUID": output.get("series_uid"), 157 | "SOPInstanceUID": output.get("instance_uid"), 158 | "Label": label_details[0], 159 | "Probability": output.get("probability"), 160 | "Scope": label_details[1], 161 | "Data": [output.get("data")], 162 | } 163 | df = pd.concat([df, pd.DataFrame(row, index=[0])], ignore_index=True, axis=0) 164 | df.to_csv(os.path.join(model_path, f"outputs_{file_name}.csv"), index=False) 165 | print("Done!", flush=True) 166 | 167 | 168 | def delete_env(model_path): 169 | """ 170 | Delete the conda env created by previous model runs. 171 | 172 | Args: 173 | model_path: Path to the exported MDAI `model` folder 174 | """ 175 | model_json = load_json(os.path.join(model_path, "model.json")) 176 | model_id = model_json["model"]["id"] 177 | env_name = f"mdai_{model_id}" 178 | 179 | print(f"Deleting conda env {env_name}") 180 | subprocess.run(shlex.split(f'/bin/bash -c "conda env remove -n {env_name}"')) 181 | 182 | 183 | def run_inference(model_path, data_path, model_version="v1"): 184 | """ 185 | Run exported MDAI models locally. Returns a csv of model outputs. 186 | 187 | Args: 188 | model_path: Path to the exported and extracted MDAI `model` folder 189 | data_path: Path to the input DICOM files 190 | model_version: Version of the downloaded model to run. Default 'v1' 191 | 192 | """ 193 | model_path = pathlib.Path(model_path) 194 | data_path = pathlib.Path(data_path) 195 | version_path = os.path.join(model_path, "source", model_version) 196 | 197 | model_json = load_json(os.path.join(model_path, "model.json")) 198 | model_id = model_json["model"]["id"] 199 | env_name = f"mdai_{model_id}" 200 | 201 | if not os.path.exists(model_path): 202 | raise Exception(" Path for extracted model does not exist.") 203 | 204 | if not os.path.exists(data_path): 205 | raise Exception("Path for input data does not exist.") 206 | 207 | extract_zip(version_path, model_version) 208 | 209 | if not is_py37(version_path): 210 | raise Exception( 211 | "Custom Dockerfiles and NVIDIA base images are not currently supported for local inference." 
212 |         )
213 | 
214 |     if env_exists(env_name):
215 |         print(f"Loading conda env {env_name}")
216 |         command = shlex.split(
217 |             r'''/bin/bash -c "source activate {} && \
218 |             python -c 'import mdai; mdai.infer(\"{}\", \"{}\", \"{}\")'"'''.format(
219 |                 env_name, model_path, data_path, model_version,
220 |             )
221 |         )
222 |     else:
223 |         command = shlex.split(
224 |             r'''/bin/bash -c "conda create -n {} python=3.7 pip -y && \
225 |             source activate {} && \
226 |             pip install numpy tqdm pandas mdai ipykernel pyyaml pydicom==2.1.2 h5py==2.10.0 && \
227 |             pip install -r {} && \
228 |             python -c 'import mdai; mdai.infer(\"{}\", \"{}\", \"{}\")'"'''.format(
229 |                 env_name,
230 |                 env_name,
231 |                 os.path.join(version_path, "model", ".mdai", "requirements.txt"),
232 |                 model_path,
233 |                 data_path,
234 |                 model_version,
235 |             )
236 |         )
237 |     subprocess.run(command)
--------------------------------------------------------------------------------
/mdai/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import collections
4 | import glob
5 | 
6 | 
7 | class Project:
8 |     """A project consists of label groups and datasets.
9 | 
10 |     Args:
11 |         annotations_fp (str):
12 |             File path to the exported JSON annotation file.
13 |         images_dir (str):
14 |             File path to the DICOM images directory.
15 |     """
16 | 
17 |     def __init__(self, annotations_fp=None, images_dir=None):
18 |         """
19 |         See the class docstring for parameter descriptions.
20 |         """
21 |         self.annotations_fp = None
22 |         self.images_dir = None
23 |         self.label_groups = []
24 |         self.datasets = []
25 | 
26 |         if annotations_fp is not None and images_dir is not None:
27 |             self.annotations_fp = annotations_fp
28 |             self.images_dir = images_dir
29 | 
30 |             with open(self.annotations_fp, "r") as f:
31 |                 self.data = json.load(f)
32 | 
33 |             for dataset in self.data["datasets"]:
34 |                 self.datasets.append(Dataset(dataset, images_dir))
35 | 
36 |             for label_group in self.data["labelGroups"]:
37 |                 self.label_groups.append(LabelGroup(label_group))
38 |         else:
39 |             print("Error: Missing annotations file path or images directory!")
40 | 
41 |     def get_label_groups(self):
42 |         return self.label_groups
43 | 
44 |     def show_label_groups(self):
45 |         for label_group in self.label_groups:
46 |             print("Label Group, Id: %s, Name: %s" % (label_group.id, label_group.name))
47 |             label_group.show_labels("\t")
48 | 
49 |     def get_label_group_by_name(self, label_group_name):
50 |         for label_group in self.label_groups:
51 |             if label_group.name == label_group_name:
52 |                 return label_group
53 |         return None
54 | 
55 |     def get_label_group_by_id(self, label_group_id):
56 |         for label_group in self.label_groups:
57 |             if label_group.id == label_group_id:
58 |                 return label_group
59 |         return None
60 | 
61 |     def get_datasets(self):
62 |         """Get the list of Dataset objects"""
63 |         return self.datasets
64 | 
65 |     def show_datasets(self):
66 |         print("Datasets:")
67 |         for dataset in self.datasets:
68 |             print("Id: %s, Name: %s" % (dataset.id, dataset.name))
69 |         print("")
70 | 
71 |     def get_dataset_by_name(self, dataset_name):
72 |         for dataset in self.datasets:
73 |             if dataset.name == dataset_name:
74 |                 return dataset
75 |         raise ValueError(f"Dataset name {dataset_name} does not exist.")
76 | 
77 |     def get_dataset_by_id(self, dataset_id):
78 |         for dataset in self.datasets:
79 |             if dataset.id == dataset_id:
80 |                 return dataset
81 |         raise ValueError(f"Dataset id {dataset_id} does not exist.")
82 | 
83 |     def set_labels_dict(self, labels_dict):
84 | 
85 |         self.classes_dict = self._create_classes_dict(labels_dict)
86 | 
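        # propagate the class mapping to every dataset so that Dataset.prepare()
        # can filter annotations by the chosen label ids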
87 |         for dataset in self.datasets:
88 |             dataset.classes_dict = self.classes_dict
89 | 
90 |     def get_label_id_annotation_mode(self, label_id):
91 |         "Return label id's annotation mode."
92 |         for label_group in self.label_groups:
93 |             labels_data = label_group.get_data()["labels"]
94 |             for label in labels_data:
95 |                 if label["id"] == label_id:
96 |                     return label["annotationMode"]
97 |         raise ValueError(f"Label id {label_id} does not exist.")
98 | 
99 |     def get_label_id_type(self, label_id):
100 |         "Return label id's type."
101 |         for label_group in self.label_groups:
102 |             labels_data = label_group.get_data()["labels"]
103 |             for label in labels_data:
104 |                 if label["id"] == label_id:
105 |                     return label["type"]
106 |         raise ValueError(f"Label id {label_id} does not exist.")
107 | 
108 |     def get_label_id_scope(self, label_id):
109 |         "Return label id's scope."
110 |         for label_group in self.label_groups:
111 |             labels_data = label_group.get_data()["labels"]
112 |             for label in labels_data:
113 |                 if label["id"] == label_id:
114 |                     return label["scope"]
115 |         raise ValueError(f"Label id {label_id} does not exist.")
116 | 
117 |     def _create_classes_dict(self, labels_dict):
118 |         """Create a dict with label id as key, and a nested dict of class_id and class_text as \
119 |         values, e.g., {'L_v8n': {'class_id': 1, 'class_text': 'Lung Opacity'}}, where L_v8n is \
120 |         the label id, with a class_id of 1 and class text of 'Lung Opacity'.
121 | 
122 |         Args:
123 |             labels_dict:
124 |                 dictionary mapping label ids to (user-defined) class ids
125 | 
126 |         Returns:
127 |             classes dict
128 |         """
129 |         classes_dict = {}
130 | 
131 |         for label_id, class_id in labels_dict.items():
132 |             for label_group in self.label_groups:
133 |                 labels_data = label_group.get_data()["labels"]
134 |                 for label in labels_data:
135 |                     if label["id"] == label_id:
136 |                         classes_dict[label_id] = {
137 |                             "class_id": class_id,
138 |                             "class_text": label["name"],
139 |                             "class_annotation_mode": label["annotationMode"],
140 |                             "scope": label["scope"],
141 |                             "type": label["type"],
142 |                         }
143 | 
144 |         if classes_dict.keys() != labels_dict.keys():
145 |             in_labels = labels_dict.keys()
146 |             out_labels = classes_dict.keys()
147 |             diff = set(in_labels).symmetric_difference(out_labels)
148 |             raise ValueError(f"Labels {diff} are not valid for this dataset.")
149 | 
150 |         return classes_dict
151 | 
152 | 
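A minimal end-to-end sketch of the `Project`/`Dataset` workflow defined in this module (the label id, dataset name, and file paths are hypothetical placeholders):

```python
# Sketch only; "L_v8n", "Dataset", and the paths below are placeholders.
from mdai.preprocess import Project

p = Project(annotations_fp="annotations.json", images_dir="images/")
p.show_label_groups()            # inspect label group and label ids
p.set_labels_dict({"L_v8n": 1})  # map a label id to a user-defined class id
dataset = p.get_dataset_by_name("Dataset")
dataset.prepare()                # filters annotations and builds image ids
image_ids = dataset.get_image_ids()
anns = dataset.get_annotations_by_image_id(image_ids[0])
```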
164 | """ 165 | 166 | def __init__(self, label_group_data): 167 | """ 168 | Args: 169 | label_group (object: json) JSON data for label group 170 | """ 171 | self.label_group_data = label_group_data 172 | self.name = self.label_group_data["name"] 173 | self.id = self.label_group_data["id"] 174 | 175 | def get_data(self): 176 | return self.label_group_data 177 | 178 | def get_labels(self): 179 | """Get label ids and names """ 180 | return [(label["id"], label["name"]) for label in self.label_group_data["labels"]] 181 | 182 | def show_labels(self, print_offset=""): 183 | """Show labels info""" 184 | print(f"{print_offset}Labels:") 185 | for label in self.label_group_data["labels"]: 186 | print(f"{print_offset}Id: {label['id']}, Name: {label['name']}") 187 | print("") 188 | 189 | 190 | class Dataset: 191 | """A dataset consists of DICOM images and annotations. 192 | 193 | Args: 194 | dataset_data: 195 | Dataset json data. 196 | images_dir: 197 | DICOM images directory. 198 | """ 199 | 200 | def __init__(self, dataset_data, images_dir): 201 | 202 | self.dataset_data = dataset_data 203 | self.images_dir = images_dir 204 | 205 | self.id = dataset_data["id"] 206 | self.name = dataset_data["name"] 207 | self.all_annotations = dataset_data["annotations"] 208 | 209 | self.image_ids = None 210 | self.classes_dict = None 211 | self.imgs_anns_dict = None 212 | 213 | # all image ids 214 | self.all_image_ids = glob.glob(os.path.join(self.images_dir, "**/*.dcm"), recursive=True) 215 | 216 | def prepare(self): 217 | if self.classes_dict is None: 218 | raise Exception("Use `Project.set_labels_dict()` to set labels.") 219 | 220 | label_ids = self.classes_dict.keys() 221 | 222 | # filter annotations by label ids 223 | ann_filtered = self.get_annotations(label_ids) 224 | if not ann_filtered: 225 | raise Exception(f"No annotations exist for dataset '{self.name}'.") 226 | 227 | self.imgs_anns_dict = self._associate_images_and_annotations(ann_filtered) 228 | 229 | def get_annotations(self, label_ids=None, verbose=False): 230 | """Returns annotations, filtered by label ids. 231 | 232 | Args: 233 | label_ids (optional): 234 | Filter returned annotations by matching label ids. 235 | 236 | verbose (optional: 237 | Print debug messages. 238 | """ 239 | if label_ids is None: 240 | if verbose: 241 | print("Dataset contains %d annotations." % len(self.all_annotations)) 242 | return self.all_annotations 243 | 244 | ann_filtered = [a for a in self.all_annotations if a["labelId"] in label_ids] 245 | 246 | if verbose: 247 | print( 248 | f"Dataset contains {len(ann_filtered)} annotations" 249 | + f", filtered by label ids {label_ids}." 250 | ) 251 | return ann_filtered 252 | 253 | def _generate_uid(self, ann): 254 | """Generate an unique image identifier based on the DICOM file structure. 255 | 256 | Args: 257 | ann (list): 258 | List of annotations. 259 | 260 | Returns: 261 | A unique image identifier based on the DICOM file structure. 
262 | """ 263 | 264 | uid = None 265 | 266 | if "StudyInstanceUID" and "SeriesInstanceUID" and "SOPInstanceUID" in ann: 267 | # SOPInstanceUID aka image level 268 | uid = os.path.join( 269 | self.images_dir, 270 | ann["StudyInstanceUID"], 271 | ann["SeriesInstanceUID"], 272 | ann["SOPInstanceUID"] + ".dcm", 273 | ) 274 | return uid 275 | elif "StudyInstanceUID" and "SeriesInstanceUID" in ann: 276 | prefix = os.path.join( 277 | self.images_dir, ann["StudyInstanceUID"], ann["SeriesInstanceUID"] 278 | ) 279 | uid = [image_id for image_id in self.all_image_ids if image_id.startswith(prefix)] 280 | return uid 281 | elif "StudyInstanceUID" in ann: 282 | prefix = os.path.join(self.images_dir, ann["StudyInstanceUID"]) 283 | uid = [image_id for image_id in self.all_image_ids if image_id.startswith(prefix)] 284 | return uid 285 | else: 286 | raise ValueError(f"Unable to create UID from {ann}") 287 | 288 | def get_image_ids(self, verbose=False): 289 | """Returns image ids. Must call prepare() method first in order to generate image ids. 290 | 291 | Args: 292 | verbose (Optional): 293 | Print debug message. 294 | """ 295 | if not self.image_ids: 296 | raise Exception("Call project.prepare() first.") 297 | 298 | if verbose: 299 | print( 300 | f"Dataset contains {len(self.image_ids)} images" 301 | + f", filtered by label ids {self.classes_dict.keys()}." 302 | ) 303 | return self.image_ids 304 | 305 | def _generate_image_ids(self, anns): 306 | """Get images ids for annotations. 307 | 308 | Args: 309 | ann (list): 310 | List of image ids. 311 | 312 | Returns: 313 | A list of image ids. 314 | """ 315 | image_ids = set() 316 | for ann in anns: 317 | uid = self._generate_uid(ann) 318 | 319 | if uid: 320 | if isinstance(uid, list): 321 | for one_uid in uid: 322 | image_ids.add(one_uid) 323 | else: 324 | image_ids.add(uid) 325 | 326 | # image_ids = glob.glob(os.path.join(self.images_dir, "**/*.dcm"), recursive=True) 327 | return sorted(list(image_ids)) 328 | 329 | def get_annotations_by_image_id(self, image_id): 330 | if image_id not in self.image_ids: 331 | raise ValueError(f"Image id {image_id} is not found in dataset {self.name}.") 332 | 333 | return self.imgs_anns_dict[image_id] 334 | 335 | def _associate_images_and_annotations(self, anns): 336 | """Generate image ids to annotations mapping. 337 | Each image can have zero or more annotations. 338 | 339 | Args: 340 | anns (list): 341 | List of annotations. 342 | 343 | Returns: 344 | A dictionary with image ids as keys and annotations as values. 
345 | """ 346 | self.image_ids = self._generate_image_ids(anns) 347 | 348 | # empty dictionary with image ids as keys 349 | imgs_anns_dict = collections.OrderedDict() 350 | imgs_anns_dict = {fp: [] for fp in self.image_ids} 351 | 352 | for ann in anns: 353 | uid = self._generate_uid(ann) 354 | if uid: 355 | if isinstance(uid, list): 356 | for one_uid in uid: 357 | imgs_anns_dict[one_uid].append(ann) 358 | else: 359 | imgs_anns_dict[uid].append(ann) 360 | 361 | return imgs_anns_dict 362 | 363 | def class_id_to_class_text(self, class_id): 364 | for k, v in self.classes_dict.items(): 365 | if v["class_id"] == class_id: 366 | return v["class_text"] 367 | 368 | raise Exception(f"class_id {class_id} is invalid.") 369 | 370 | def class_text_to_class_id(self, class_text): 371 | for k, v in self.classes_dict.items(): 372 | if v["class_text"] == class_text: 373 | return v["class_id"] 374 | raise Exception(f"class_text {class_text} is invalid.") 375 | 376 | def label_id_to_class_id(self, label_id): 377 | for k, v in self.classes_dict.items(): 378 | if k == label_id: 379 | return v["class_id"] 380 | raise Exception(f"label_id {label_id} is invalid.") 381 | 382 | def label_id_to_class_annotation_mode(self, label_id): 383 | for k, v in self.classes_dict.items(): 384 | if k == label_id: 385 | return v["class_annotation_mode"] 386 | raise Exception(f"label_id {label_id} is invalid.") 387 | 388 | def show_classes(self): 389 | for k, v in self.classes_dict.items(): 390 | print(f"Label id: {k}, Class id: {v['class_id']}, Class text: {v['class_text']}") 391 | -------------------------------------------------------------------------------- /mdai/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdai/mdai-client-py/28ab35436f3db1fb056b1e8d3bc13d9ae4c5c555/mdai/utils/__init__.py -------------------------------------------------------------------------------- /mdai/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import copy 3 | import json 4 | import pandas as pd 5 | 6 | import os 7 | import uuid 8 | from functools import partial 9 | import multiprocessing 10 | import pydicom 11 | import numpy as np 12 | import nibabel as nib 13 | from tqdm import tqdm 14 | import cv2 15 | 16 | 17 | def hex2rgb(h): 18 | """Convert Hex color encoding to RGB color""" 19 | h = h.lstrip("#") 20 | return tuple(int(h[i : i + 2], 16) for i in (0, 2, 4)) 21 | 22 | 23 | def train_test_split(dataset, shuffle=True, validation_split=0.1): 24 | """ 25 | Split image ids into training and validation sets. 26 | TODO: Need to update dataset.dataset_data! 27 | TODO: What to set for images_dir for combined dataset? 
28 | """ 29 | if validation_split < 0.0 or validation_split > 1.0: 30 | raise ValueError(f"{validation_split} is not a valid split ratio.") 31 | 32 | image_ids_list = dataset.get_image_ids() 33 | if shuffle: 34 | sorted(image_ids_list) 35 | random.seed(42) 36 | random.shuffle(image_ids_list) 37 | 38 | split_index = int((1 - validation_split) * len(image_ids_list)) 39 | train_image_ids = image_ids_list[:split_index] 40 | valid_image_ids = image_ids_list[split_index:] 41 | 42 | def filter_by_ids(ids, imgs_anns_dict): 43 | return {x: imgs_anns_dict[x] for x in ids} 44 | 45 | train_dataset = copy.deepcopy(dataset) 46 | train_dataset.id = dataset.id + "-TRAIN" 47 | train_dataset.name = dataset.name + "-TRAIN" 48 | 49 | valid_dataset = copy.deepcopy(dataset) 50 | valid_dataset.id = dataset.id + "-VALID" 51 | valid_dataset.name = dataset.name + "-VALID" 52 | 53 | imgs_anns_dict = dataset.imgs_anns_dict 54 | 55 | train_imgs_anns_dict = filter_by_ids(train_image_ids, imgs_anns_dict) 56 | valid_imgs_anns_dict = filter_by_ids(valid_image_ids, imgs_anns_dict) 57 | 58 | train_dataset.image_ids = train_image_ids 59 | valid_dataset.image_ids = valid_image_ids 60 | 61 | train_dataset.imgs_anns_dict = train_imgs_anns_dict 62 | valid_dataset.imgs_anns_dict = valid_imgs_anns_dict 63 | 64 | all_train_annotations = [] 65 | for _, annotations in train_dataset.imgs_anns_dict.items(): 66 | all_train_annotations += annotations 67 | train_dataset.all_annotations = all_train_annotations 68 | 69 | all_val_annotations = [] 70 | for _, annotations in valid_dataset.imgs_anns_dict.items(): 71 | all_val_annotations += annotations 72 | valid_dataset.all_annotations = all_val_annotations 73 | 74 | print( 75 | "Num of instances for training set: %d, validation set: %d" 76 | % (len(train_image_ids), len(valid_image_ids)) 77 | ) 78 | return train_dataset, valid_dataset 79 | 80 | 81 | def json_to_dataframe(json_file, datasets=[]): 82 | with open(json_file, "r", encoding="utf-8") as f: 83 | data = json.load(f) 84 | 85 | a = pd.DataFrame([]) 86 | studies = pd.DataFrame([]) 87 | labels = None 88 | 89 | # Gets annotations for all datasets 90 | for d in data["datasets"]: 91 | if d["id"] in datasets or len(datasets) == 0: 92 | study = pd.DataFrame(d["studies"]) 93 | study["dataset"] = d["name"] 94 | study["datasetId"] = d["id"] 95 | studies = pd.concat([studies, study], ignore_index=True, sort=False) 96 | 97 | annots = pd.DataFrame(d["annotations"]) 98 | annots["dataset"] = d["name"] 99 | a = pd.concat([a, annots], ignore_index=True, sort=False) 100 | 101 | if len(studies) > 0: 102 | studies = studies[["StudyInstanceUID", "dataset", "datasetId", "number"]] 103 | g = pd.DataFrame(data["labelGroups"]) 104 | # unpack arrays 105 | result = pd.DataFrame([(d, tup.id, tup.name) for tup in g.itertuples() for d in tup.labels]) 106 | if len(result) > 0: 107 | result.columns = ["labels", "labelGroupId", "labelGroupName"] 108 | 109 | def unpack_dictionary(df, column): 110 | ret = None 111 | ret = pd.concat( 112 | [df, pd.DataFrame((d for idx, d in df[column].items()))], axis=1, sort=False 113 | ) 114 | del ret[column] 115 | return ret 116 | 117 | labels = unpack_dictionary(result, "labels") 118 | if "parentId" in labels.columns: 119 | labels = labels[ 120 | [ 121 | "labelGroupId", 122 | "labelGroupName", 123 | "annotationMode", 124 | "color", 125 | "description", 126 | "id", 127 | "name", 128 | "radlexTagIds", 129 | "scope", 130 | "parentId", 131 | ] 132 | ] 133 | labels.columns = [ 134 | "labelGroupId", 135 | "labelGroupName", 136 | 
"annotationMode", 137 | "color", 138 | "description", 139 | "labelId", 140 | "labelName", 141 | "radlexTagIdsLabel", 142 | "scope", 143 | "parentLabelId", 144 | ] 145 | else: 146 | labels = labels[ 147 | [ 148 | "labelGroupId", 149 | "labelGroupName", 150 | "annotationMode", 151 | "color", 152 | "description", 153 | "id", 154 | "name", 155 | "radlexTagIds", 156 | "scope", 157 | ] 158 | ] 159 | labels.columns = [ 160 | "labelGroupId", 161 | "labelGroupName", 162 | "annotationMode", 163 | "color", 164 | "description", 165 | "labelId", 166 | "labelName", 167 | "radlexTagIdsLabel", 168 | "scope", 169 | ] 170 | 171 | if len(a) > 0: 172 | a = a.merge(labels, on=["labelId"], sort=False) 173 | if len(studies) > 0 and len(a) > 0: 174 | a = a.merge(studies, on=["StudyInstanceUID", "dataset"], sort=False) 175 | # Format data 176 | studies.number = studies.number.astype(int) 177 | a.number = a.number.astype(int) 178 | a.loc.createdAt = pd.to_datetime(a.createdAt) 179 | a.loc.updatedAt = pd.to_datetime(a.updatedAt) 180 | return {"annotations": a, "studies": studies, "labels": labels} 181 | 182 | 183 | def convert_mask_annotation_to_array(row): 184 | """ 185 | Converts a dataframe row containing a mask annotation from our internal complex polygon data representation to a numpy array. 186 | """ 187 | mask = np.zeros((int(row.width), int(row.height))) 188 | if row.data["foreground"]: 189 | for i in row.data["foreground"]: 190 | mask = cv2.fillPoly(mask, [np.array(i, dtype=np.int32)], 1) 191 | if row.data["background"]: 192 | for i in row.data["background"]: 193 | mask = cv2.fillPoly(mask, [np.array(i, dtype=np.int32)], 0) 194 | return mask 195 | 196 | 197 | def convert_mask_data(data): 198 | """ 199 | Converts a numpy array mask to our internal complex polygon data representation. 200 | """ 201 | mask = np.uint8(np.array(data) > 0) 202 | contours, hierarchy = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 203 | contours = [contours[i].reshape(-1, 2) for i in range(len(contours))] 204 | 205 | # Separate contours based on foreground / background polygons 206 | output_data = { 207 | "foreground": [], 208 | "background": [], 209 | } 210 | 211 | counts = [0] * len(contours) 212 | for i in range(len(contours)): 213 | parent = hierarchy[0][i][-1] 214 | if parent != -1: 215 | counts[i] = counts[parent] + 1 216 | 217 | if counts[i] % 2: 218 | output_data["background"].append(contours[i].tolist()) 219 | else: 220 | output_data["foreground"].append(contours[i].tolist()) 221 | return output_data 222 | 223 | 224 | """Converts NIFTI format to DICOM for CT exams. MR to come... 225 | 226 | """ 227 | 228 | 229 | def convert_ct( 230 | input_dir=None, 231 | output_dir=None, 232 | input_ext=".nii.gz", 233 | plane="axial", 234 | sample_dicom_fp=os.path.join(os.path.dirname(""), "./sample_dicom.dcm"), 235 | window_center=40, 236 | window_width=350, 237 | ): 238 | if not os.path.exists(input_dir): 239 | raise IOError("{:s} does not exist.".format(input_dir)) 240 | if not os.path.exists(sample_dicom_fp): 241 | raise IOError("{:s} does not exist.".format(sample_dicom_fp)) 242 | if plane not in ["axial", "sagittal", "coronal"]: 243 | raise ValueError("`plane` must be one of axial, sagittal, or coronal.") 244 | 245 | # make output dir if doesn't already exist 246 | if not os.path.exists(output_dir): 247 | os.makedirs(output_dir) 248 | 249 | found_filepaths = list(_get_files(input_dir, ext=input_ext)) 250 | print(f"{len(found_filepaths)} *{input_ext} files found. 
Processing...") 251 | 252 | n_procs = multiprocessing.cpu_count() - 1 253 | with multiprocessing.Pool(n_procs) as p: 254 | kwargs = { 255 | "input_dir": input_dir, 256 | "output_dir": output_dir, 257 | "input_ext": input_ext, 258 | "plane": plane, 259 | "sample_dicom_fp": sample_dicom_fp, 260 | "window_center": window_center, 261 | "window_width": window_width, 262 | } 263 | # need to iterate since Pool.imap is lazy 264 | for n in tqdm( 265 | p.imap_unordered(partial(_convert_nii_file_ct, **kwargs), found_filepaths), 266 | total=len(found_filepaths), 267 | ): 268 | pass 269 | 270 | 271 | def _get_files(root, ext=None): 272 | """Yields all file paths recursively from root path, optionally filtering on extension. 273 | """ 274 | for item in os.scandir(root): 275 | if item.is_file(): 276 | if not ext: 277 | yield item.path 278 | elif os.path.splitext(item.path)[1] == ext or item.path.endswith(ext): 279 | yield item.path 280 | elif item.is_dir(): 281 | yield from _get_files(item.path) 282 | 283 | 284 | def _get_datatype(headers): 285 | dt = str(headers.get_data_dtype()) 286 | return np.int16 287 | 288 | 289 | # header datatype not reliable 290 | 291 | # if dt == 'int8': 292 | # return np.int8 293 | # elif dt == 'int16': 294 | # return np.int16 295 | # elif dt == 'int32': 296 | # return np.int32 297 | # elif dt == 'int64': 298 | # return np.int64 299 | # elif dt == 'float32': 300 | # return np.float32 301 | # elif dt == 'float64': 302 | # return np.float64 303 | # return np.int16 304 | 305 | 306 | def _convert_nii_file_ct( 307 | filepath, 308 | input_dir=None, 309 | output_dir=None, 310 | input_ext=".nii.gz", 311 | plane="axial", 312 | sample_dicom_fp=os.path.join(os.path.dirname(""), "./sample_dicom.dcm"), 313 | window_center=40, 314 | window_width=350, 315 | ): 316 | dataobj = nib.load(filepath) 317 | headers = dataobj.header 318 | voxel_arr = dataobj.get_fdata() 319 | pixdim = headers["pixdim"][1:4].tolist() 320 | 321 | # NIFTI (RAS) -> DICOM (LPI) coordinates 322 | # i, Left/Right = sagittal plane 323 | # j, Anterior/Posterior = coronal plane 324 | # k, Superior/Inferior = axial plane 325 | voxel_arr = np.flip(voxel_arr, 0) 326 | voxel_arr = np.flip(voxel_arr, 1) 327 | voxel_arr = np.flip(voxel_arr, 2) 328 | 329 | # Image coordinates -> World coordinates 330 | if plane == "axial": 331 | slice_axis = 2 332 | plane_axes = [0, 1] 333 | elif plane == "coronal": 334 | slice_axis = 1 335 | plane_axes = [0, 2] 336 | elif plane == "sagittal": 337 | slice_axis = 0 338 | plane_axes = [1, 2] 339 | thickness = pixdim[slice_axis] 340 | spacing = [pixdim[plane_axes[1]], pixdim[plane_axes[0]]] 341 | voxel_arr = np.swapaxes(voxel_arr, *plane_axes) 342 | 343 | # generate DICOM UIDs (StudyInstanceUID and SeriesInstanceUID) 344 | study_uid = pydicom.uid.generate_uid(prefix=None) 345 | series_uid = pydicom.uid.generate_uid(prefix=None) 346 | 347 | # randomized patient ID 348 | patient_id = str(uuid.uuid4()) 349 | patient_name = patient_id 350 | 351 | try: 352 | scale_slope = str(int(headers["scl_slope"])) 353 | except ValueError: # handle NaN 354 | scale_slope = "1" 355 | try: 356 | scale_intercept = str(int(headers["scl_inter"])) 357 | except ValueError: # handle NaN 358 | scale_intercept = "0" 359 | 360 | for slice_index in range(voxel_arr.shape[slice_axis]): 361 | # generate SOPInstanceUID 362 | instance_uid = pydicom.uid.generate_uid(prefix=None) 363 | 364 | loc = slice_index * thickness 365 | 366 | ds = pydicom.dcmread(sample_dicom_fp) 367 | 368 | # delete tags 369 | del ds[0x00200052] # Frame of 
Reference UID 370 | del ds[0x00201040] # Position Reference Indicator 371 | 372 | # slice and set PixelData tag 373 | axes = [slice(None)] * 3 374 | axes[slice_axis] = slice_index 375 | arr = voxel_arr[tuple(axes)].astype(_get_datatype(headers)) 376 | ds[0x7FE00010].value = arr.tobytes() 377 | 378 | # modify tags 379 | # - UIDs are created by pydicom.uid.generate_uid at each level above 380 | # - image position is calculated by combination of slice index and slice thickness 381 | # - slice location is set to the value of image position along z-axis 382 | # - Rows/Columns determined by array shape 383 | # - we set slope/intercept to 1/0 since we're directly converting from PNG pixel values 384 | ds[0x00080018].value = instance_uid # SOPInstanceUID 385 | ds[0x00100010].value = patient_name 386 | ds[0x00100020].value = patient_id 387 | ds[0x0020000D].value = study_uid # StudyInstanceUID 388 | ds[0x0020000E].value = series_uid # SeriesInstanceUID 389 | ds[0x0008103E].value = "" # Series Description 390 | ds[0x00200011].value = "1" # Series Number 391 | ds[0x00200012].value = str(slice_index + 1) # Acquisition Number 392 | ds[0x00200013].value = str(slice_index + 1) # Instance Number 393 | ds[0x00201041].value = str(loc) # Slice Location 394 | ds[0x00280010].value = arr.shape[0] # Rows 395 | ds[0x00280011].value = arr.shape[1] # Columns 396 | ds[0x00280030].value = spacing # Pixel Spacing 397 | ds[0x00281050].value = str(window_center) # Window Center 398 | ds[0x00281051].value = str(window_width) # Window Width 399 | ds[0x00281052].value = str(scale_intercept) # Rescale Intercept 400 | ds[0x00281053].value = str(scale_slope) # Rescale Slope 401 | 402 | # Image Position (Patient) 403 | # Image Orientation (Patient) 404 | if plane == "axial": 405 | ds[0x00200032].value = ["0", "0", str(loc)] 406 | ds[0x00200037].value = ["1", "0", "0", "0", "1", "0"] 407 | elif plane == "coronal": 408 | ds[0x00200032].value = ["0", str(loc), "0"] 409 | ds[0x00200037].value = ["1", "0", "0", "0", "0", "1"] 410 | elif plane == "sagittal": 411 | ds[0x00200032].value = [str(loc), "0", "0"] 412 | ds[0x00200037].value = ["0", "1", "0", "0", "0", "1"] 413 | 414 | # add new tags 415 | # see tag info e.g., from https://dicom.innolitics.com/ciods/nm-image/nm-reconstruction/00180050 416 | # Slice Thickness 417 | ds[0x00180050] = pydicom.dataelem.DataElement(0x00180050, "DS", str(thickness)) 418 | 419 | # Output DICOM filepath 420 | # For root directory of data/, then: 421 | # e.g., 'data/x/y/z.nii.gz' becomes '{output_dir}/data/x/y/z/{001-999}.dcm' 422 | dicom_fp = os.path.join( 423 | output_dir, 424 | os.path.dirname(filepath).strip("/"), # remove leading and trailing slashes 425 | os.path.basename(filepath).replace(input_ext, ""), 426 | "{:03}.dcm".format(slice_index + 1), 427 | ) 428 | 429 | # create directory 430 | if not os.path.exists(os.path.dirname(dicom_fp)): 431 | os.makedirs(os.path.dirname(dicom_fp)) 432 | 433 | # write DICOM to file 434 | pydicom.dcmwrite(dicom_fp, ds) 435 | -------------------------------------------------------------------------------- /mdai/utils/dicom_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from .common_utils import json_to_dataframe 3 | from datetime import datetime 4 | import os 5 | import requests 6 | import cv2 7 | import numpy as np 8 | 9 | import pydicom 10 | from pydicom.filereader import dcmread 11 | from pydicom.dataset import Dataset, FileMetaDataset 12 | from pydicom.pixel_data_handlers.numpy_handler 
import pack_bits 13 | from pydicom.sequence import Sequence 14 | import warnings 15 | 16 | warnings.filterwarnings("ignore", module="pydicom") 17 | 18 | 19 | # Imports and Parse Some of the DICOM Standard Files 20 | # ----------------------------------------------- 21 | class DicomExport: 22 | """ 23 | Used to convert md.ai annotations to DICOM SR/SEG format for easier data processing. 24 | 25 | Inputs: 26 | `output_format` - determines if the output should be in DICOM SR/SEG (accepted inputs are "SR" or "SEG") 27 | `annotation_json` & `metadata_json` - are the exported annotation and metadata json paths from the md.ai project 28 | MAKE SURE THE DATASETS MATCH UP 29 | `combine_label_groups` - If `True` then each SEG file includes the annotations from every label group for that series 30 | If `False` then a different SEG file will be created for each different label group annotation 31 | (only applies to SEG output. Will be ignored for SR) 32 | `output_dir` - Specifies where the files should be downloaded 33 | If None then files will be placed in a "SR(or SEG)_OUTPUT" folder in your cwd. 34 | 35 | Outputs: 36 | There will be a folder in your cwd, specified by the `output_dir` parameter, containing your SR/SEG exports. 37 | 38 | Created by Dyllan Hofflich and MD.ai. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | output_format, 44 | annotation_json, 45 | metadata_json, 46 | combine_label_groups=True, 47 | output_dir=None, 48 | ): 49 | self.Annotation_Json = annotation_json 50 | self.Metadata_Json = metadata_json 51 | self.output_format = output_format 52 | self.combine = combine_label_groups 53 | self.output_dir = output_dir 54 | 55 | if output_format != "SR" and output_format != "SEG": 56 | raise Exception('Invalid output format. Must be either "SR" or "SEG"') 57 | 58 | self.dicom_standards_setups() 59 | self.dicom_tags_setup() 60 | 61 | def dicom_standards_setups(self): 62 | """ 63 | Searches the DICOM standard to gather which dicom tags are relevant to SR/Segmentation and which tags should be copied over from the original images. 64 | Gathers the DICOM standard data from https://github.com/innolitics/dicom-standard which is used in the Standard DICOM Browser website. 
65 | """ 66 | 67 | ctm_json = requests.get( 68 | "https://raw.githubusercontent.com/innolitics/dicom-standard/master/standard/ciod_to_modules.json" 69 | ).text 70 | mta_json = requests.get( 71 | "https://raw.githubusercontent.com/innolitics/dicom-standard/master/standard/module_to_attributes.json" 72 | ).text 73 | attributes_json = requests.get( 74 | "https://raw.githubusercontent.com/innolitics/dicom-standard/master/standard/attributes.json" 75 | ).text 76 | # ciod_to_modules_dataframe 77 | ctm_df = pd.read_json(ctm_json) 78 | # module_to_attributes_dataframe 79 | mta_df = pd.read_json(mta_json) 80 | # attributes_dataframe 81 | attributes_df = pd.read_json(attributes_json) 82 | 83 | # Select basic-text-sr/SEG modules 84 | if self.output_format == "SR": 85 | SR_modules_df = ctm_df[ctm_df["ciodId"] == "basic-text-sr"] 86 | else: 87 | SR_modules_df = ctm_df[ctm_df["ciodId"] == "segmentation"] 88 | # Select all basic-text-sr/SEG attributes 89 | SR_attributes_df = mta_df[mta_df["moduleId"].isin(SR_modules_df["moduleId"])] 90 | 91 | attribute_to_keyword_map = dict(zip(attributes_df["tag"], attributes_df["keyword"])) 92 | self.keyword_to_VR_map = dict( 93 | zip(attributes_df["keyword"], attributes_df["valueRepresentation"]) 94 | ) 95 | attribute_to_type_map = dict(zip(SR_attributes_df["tag"], SR_attributes_df["type"])) 96 | 97 | self.keyword_to_type_map = {} 98 | for attribute in attribute_to_type_map: 99 | self.keyword_to_type_map[attribute_to_keyword_map[attribute]] = attribute_to_type_map[ 100 | attribute 101 | ] 102 | 103 | # Create dicom heirarchy for SR/SEG document (modeled after the Standard DICOM Browser) 104 | # --------------------------------------------------- 105 | SR_attributes_df.sort_values("path") 106 | self.dicom_tag_heirarchy = {} 107 | for _, row in SR_attributes_df.iterrows(): 108 | if row["path"].count(":") == 1: 109 | self.dicom_tag_heirarchy[attribute_to_keyword_map[row["tag"]]] = {} 110 | else: 111 | paths = row["path"].split(":") 112 | # convert all tags in path to tag format 113 | parents = [] 114 | for parent in paths[1:-1]: 115 | parent = f"({parent[:4]},{parent[4:]})".upper() 116 | parent = attribute_to_keyword_map[parent] 117 | parents.append(parent) 118 | 119 | child = paths[-1] 120 | child = f"({child[:4]},{child[4:]})".upper() 121 | child = attribute_to_keyword_map[child] 122 | 123 | # get to last tag sequence 124 | current_sequence = self.dicom_tag_heirarchy[parents[0]] 125 | for parent in parents[1:]: 126 | current_sequence = current_sequence[parent] 127 | current_sequence[child] = {} 128 | 129 | # Dictionary of VR and their corresponding types 130 | self.typos = { 131 | "AE": str, 132 | "AS": str, 133 | "AT": pydicom.tag.BaseTag, 134 | "CS": str, 135 | "DA": str, 136 | "DS": pydicom.valuerep.DSfloat, 137 | "DT": str, 138 | "FL": float, 139 | "FD": float, 140 | "IS": pydicom.valuerep.IS, 141 | "LO": str, 142 | "LT": str, 143 | "OB": bytes, 144 | "OB or OW": bytes, 145 | "OD": bytes, 146 | "OF": bytes, 147 | "OL": bytes, 148 | "OV": bytes, 149 | "OW": bytes, 150 | "PN": pydicom.valuerep.PersonName, 151 | "SH": str, 152 | "SL": int, 153 | "SQ": pydicom.sequence.Sequence, 154 | "SS": int, 155 | "ST": str, 156 | "SV": int, 157 | "TM": str, 158 | "UC": str, 159 | "UI": pydicom.uid.UID, 160 | "UL": int, 161 | "UN": bytes, 162 | "UR": str, 163 | "US": int, 164 | "US or SS": int, 165 | "UT": str, 166 | "UV": int, 167 | } 168 | 169 | def dicom_tags_setup(self): 170 | """ 171 | Organizes the dicom tags and study, series, and image level information into a more parsable 
structure. 172 | """ 173 | 174 | # Read Imported JSONs 175 | results = json_to_dataframe(os.getcwd() + "/" + self.Annotation_Json) 176 | self.metadata = pd.read_json(os.getcwd() + "/" + self.Metadata_Json) 177 | 178 | # Annotations dataframe 179 | self.annots_df = results["annotations"] 180 | labels = results["labels"] 181 | self.label_name_map = dict(zip(labels.labelId, labels.labelName)) 182 | self.label_scope_map = dict(zip(labels.labelId, labels.scope)) 183 | 184 | # Images DICOM Tags dataframe 185 | tags = [] 186 | for dataset in self.metadata["datasets"]: 187 | tags.extend(dataset["dicomMetadata"]) 188 | 189 | # Create organization of study, series, instance UID & dicom tags 190 | # ---------------------------------------------------------- 191 | self.studies = self.annots_df.StudyInstanceUID.unique() 192 | self.tags_df = pd.DataFrame.from_dict( 193 | tags 194 | ) # dataframe of study, series, instance UID & dicom tags 195 | self.dicom_hierarchy = {} 196 | for tag in tags: 197 | study_uid = tag["StudyInstanceUID"] 198 | series_uid = tag["SeriesInstanceUID"] 199 | sop_uid = tag["SOPInstanceUID"] 200 | 201 | # Check if already seen study_uid yet (avoids key error) 202 | if study_uid not in self.dicom_hierarchy: # Using study_uid bc rn it's exam level 203 | self.dicom_hierarchy[study_uid] = [] 204 | 205 | # Dicom_heirarchy is a dictionary with study_uid as keys and a list as value 206 | # each list contains a dictionary with the series_uid as a key and a list of sop_uids as value 207 | if not any(series_uid in d for d in self.dicom_hierarchy[study_uid]): 208 | self.dicom_hierarchy[study_uid].append({series_uid: []}) 209 | for d in self.dicom_hierarchy[ 210 | study_uid 211 | ]: # loops through item in dicom_heriarchy list (just the series_uid dict) 212 | if series_uid in d: 213 | d[series_uid].append(sop_uid) 214 | 215 | # Helper functions to place DICOM tags into SR/SEG document Template 216 | # --------------------------------------------------- 217 | """ 218 | > Iterates through a given sequence of tags from the standard DICOM heirarchy 219 | > Checks if the tag exists in the current DICOM file's headers 220 | >> If it does then it adds the tag to the SR document dataset 221 | > Recursively calls itself to add tags in sequences and 222 | >> Checks if a sequence contains all its required tags and adds them if so 223 | > Returns the SR document dataset with all tags added 224 | > If there were no tags added then returns False 225 | """ 226 | 227 | def place_tags(self, dicom_tags, curr_dataset, curr_seq, need_to_check_required=True): 228 | sequences = {} 229 | added = False 230 | # Iterate through sequence to add tags and find sequences 231 | for keyword in curr_seq: 232 | if keyword in dicom_tags: 233 | curr_dataset = self.add_to_dataset(curr_dataset, keyword, dicom_tags[keyword], True) 234 | added = True 235 | if self.keyword_to_VR_map[keyword] == "SQ": 236 | sequences[keyword] = curr_seq[keyword] 237 | 238 | # Iterate through sequences to add tags and recursively search within sequences for tags 239 | for keyword in sequences: 240 | if ( 241 | self.output_format == "SR" and keyword == "ContentSequence" 242 | ): # Skips ContentSequence since it's meant to contain the annotations data 243 | continue 244 | seq = sequences[keyword] 245 | new_dataset = Dataset() 246 | new_dataset = self.place_tags(dicom_tags, new_dataset, seq, need_to_check_required) 247 | if new_dataset: 248 | if self.keyword_to_VR_map[keyword] == "SQ": 249 | new_dataset = [new_dataset] # Pydicom requires sequences to be 
in a list 250 | if not need_to_check_required or self.check_required(new_dataset, seq): 251 | added = True 252 | curr_dataset = self.add_to_dataset(curr_dataset, keyword, new_dataset, True) 253 | 254 | if added: 255 | return curr_dataset 256 | 257 | return False 258 | 259 | # Checks if a sequence contains all its required tags 260 | def check_required(self, curr_dataset, curr_seq): 261 | for keyword in curr_seq: 262 | tag_type = self.keyword_to_type_map[keyword] 263 | if keyword not in curr_dataset and "1" == tag_type: 264 | return False 265 | return True 266 | 267 | # Adds tag to dataset and if the tag already exists then 268 | # Replaces tag if replace=True if not then does nothing 269 | def add_to_dataset(self, dataset, keyword, value, replace): 270 | VR = self.keyword_to_VR_map[keyword] 271 | 272 | # If the tag is a sequence then the value in dicom_tags will be a list containing dictionary so need to convert to sequence format 273 | if type(value) == list and VR == "SQ": 274 | if type(value[0]) == dict: 275 | value = self.dict_to_sequence(value) 276 | 277 | # If the tag is a byte encoding then need to switch it to so from string 278 | if self.typos[VR] == bytes and value != None: 279 | value = value[2:-1].encode("UTF-8") # removes b' and ' 280 | 281 | # If the tag is an int/float encoding then need to switch it to so from string 282 | if self.typos[VR] == int or self.typos[VR] == float: 283 | if value != None: 284 | value = self.typos[VR](value) 285 | 286 | # check if tag already in dataset 287 | if keyword in dataset: 288 | if not replace: 289 | return dataset 290 | dataset[keyword].value = value 291 | return dataset 292 | 293 | if ( 294 | "or SS" in VR and type(value) == int 295 | ): # Fix bug when VR == 'US or SS' and the value is negative (it always defaults to US) 296 | if value < 0: 297 | VR = "SS" 298 | 299 | dataset.add_new(keyword, VR, value) 300 | return dataset 301 | 302 | # Creates a sequence from a list of dictionaries 303 | def dict_to_sequence(self, dict_seq_list): 304 | sequences = [] 305 | for dict_seq in dict_seq_list: 306 | seq = Dataset() 307 | for keyword in dict_seq: 308 | if self.keyword_to_VR_map[keyword] == "SQ": 309 | inner_seq = self.dict_to_sequence(dict_seq[keyword]) 310 | seq = self.add_to_dataset(seq, keyword, inner_seq, True) 311 | else: 312 | seq = self.add_to_dataset(seq, keyword, dict_seq[keyword], True) 313 | sequences.append(seq) 314 | return sequences 315 | 316 | 317 | class SrExport(DicomExport): 318 | def __init__( 319 | self, 320 | annotation_json, 321 | metadata_json, 322 | combine_label_groups=True, 323 | output_dir=None, 324 | ): 325 | DicomExport.__init__( 326 | self, 327 | "SR", 328 | annotation_json, 329 | metadata_json, 330 | combine_label_groups, 331 | output_dir, 332 | ) 333 | self.create_sr_exports() 334 | 335 | def create_sr_exports(self): 336 | # Iterate through each study and create SR document for each annotator in each study 337 | # Save output to Output folder 338 | # --------------------------------------------------- 339 | try: 340 | if self.output_dir == None: 341 | out_dir = "SR_Output" 342 | os.mkdir("SR_Output") 343 | else: 344 | out_dir = self.output_dir 345 | os.mkdir(self.output_dir) 346 | except: 347 | pass 348 | 349 | from io import BytesIO 350 | 351 | document_file = os.path.join(os.path.dirname(__file__), "./sample_SR.dcm") 352 | for dataset_id in self.annots_df["datasetId"].unique(): 353 | self.dataset_annots = self.annots_df[self.annots_df.datasetId == dataset_id] 354 | for study_uid in self.studies: 355 | # load 
file template 356 | ds = dcmread(document_file) 357 | 358 | self.dicom_tags = self.tags_df[ 359 | self.tags_df.StudyInstanceUID == study_uid 360 | ].dicomTags.values[0] 361 | annotations = self.dataset_annots[self.dataset_annots.StudyInstanceUID == study_uid] 362 | 363 | annotators = annotations.createdById.unique() 364 | series_uid = pydicom.uid.generate_uid(prefix=None) 365 | instance_uid = pydicom.uid.generate_uid(prefix=None) 366 | date = datetime.now().strftime("%Y%m%d") 367 | time = datetime.now().strftime("%H%M%S") 368 | 369 | # Place all the tags from the dicom into the SR document 370 | ds = self.place_tags(self.dicom_tags, ds, self.dicom_tag_heirarchy) 371 | 372 | # modify file metadata 373 | ds.file_meta.MediaStorageSOPInstanceUID = ( 374 | instance_uid # Media Storage SOP Instance UID 375 | ) 376 | ds.file_meta.ImplementationClassUID = str( 377 | pydicom.uid.PYDICOM_IMPLEMENTATION_UID 378 | ) # Implementation Class UID 379 | ds.file_meta.ImplementationVersionName = str( 380 | pydicom.__version__ 381 | ) # Implementation Version Name 382 | 383 | # delete tags 384 | del ds[0x00080012] # Instance Creation Date 385 | del ds[0x00080013] # Instance Creation Time 386 | del ds[0x00080014] # Instance Creator UID 387 | # del ds[0x00100030] # Patient's Birth Date 388 | 389 | # modify tags 390 | # ------------------------- 391 | 392 | ds[ 393 | "SOPClassUID" 394 | ].value = "1.2.840.10008.5.1.4.1.1.88.22" # SOP Class UID = enhanced SR storage 395 | ds[0x00080018].value = instance_uid # SOPInstanceUID 396 | ds[0x0008103E].value = str(self.metadata["name"].values[0]) # Series Description 397 | ds[0x00080021].value = str(date) # Series Date 398 | ds[0x00080023].value = str(date) # Content Date 399 | ds[0x00080031].value = str(time) # Series Time 400 | ds[0x00080033].value = str(time) # Content Time 401 | 402 | ds[0x00181020].value = "" # Software Versions 403 | 404 | ds[0x0020000D].value = str(study_uid) # Study Instance UID 405 | ds[0x0020000E].value = str(series_uid) # Series Instance UID 406 | ds[0x00200011].value = str(1) # Series Number 407 | 408 | ds.Modality = "SR" 409 | 410 | # create dicom hierarchy 411 | dicom_hier = self.dicom_hierarchy[study_uid] 412 | series_sequence = [] 413 | for series in dicom_hier: 414 | for key in series: 415 | sops = series[key] 416 | series_hier = Dataset() 417 | sop_sequence = [] 418 | for sop in sops: 419 | sop_data = Dataset() 420 | if "SOPClassUID" in self.dicom_tags: 421 | sop_data.ReferencedSOPClassUID = self.dicom_tags["SOPClassUID"] 422 | sop_data.ReferencedSOPInstanceUID = sop 423 | sop_sequence.append(sop_data) 424 | series_hier.ReferencedSOPSequence = sop_sequence 425 | series_hier.SeriesInstanceUID = key 426 | series_sequence.append(series_hier) 427 | 428 | ds[0x0040A375][0].ReferencedSeriesSequence = series_sequence 429 | ds[0x0040A375][0].StudyInstanceUID = study_uid 430 | 431 | # add tags 432 | ds[0x00080005] = pydicom.dataelem.DataElement( 433 | 0x00080005, "CS", "ISO_IR 192" 434 | ) # Specific Character Set 435 | 436 | # create content for each annotator 437 | for i in range(len(annotators)): 438 | instance_number = i + 1 439 | ds[0x00200013] = pydicom.dataelem.DataElement( 440 | 0x00200013, "IS", str(instance_number) 441 | ) # Instance Number 442 | ds[0x0040A730][0][0x0040A123].value = f"Annotator{instance_number}" 443 | ds[0x0040A078][0][0x0040A123].value = f"Annotator{instance_number}" 444 | anns = annotations[annotations.createdById == annotators[i]] 445 | 446 | anns_map = {} 447 | 448 | def annotator_iteration(row): 449 | annotation 
= []
450 | label_id = row["labelId"]
451 | parent_id = row["parentLabelId"]
452 | annotation.extend(
453 | [
454 | parent_id,
455 | row["scope"],
456 | row["SOPInstanceUID"],
457 | row["SeriesInstanceUID"],
458 | ]
459 | )
460 | if "SOPClassUID" in self.dicom_tags:
461 | annotation.append(self.dicom_tags["SOPClassUID"])
462 |
463 | if label_id not in anns_map:
464 | anns_map[label_id] = []
465 | anns_map[label_id].append(annotation)
466 |
467 | anns.apply(annotator_iteration, axis=1)
468 |
469 | # annotator_iteration creates two entries for a child label (one for the child, one for its parent), so remove the redundant parent entries
470 | for label_id in anns_map:
471 | for annot in anns_map[label_id]:
472 | if annot[0] != None:
473 | if (
474 | annot[0] not in anns_map
475 | ): # Fixes edge case where a child label appears with no parent label for that annotator
476 | continue # Occurs when another annotator adds a child label to a different annotator's label
477 |
478 | for j in range(
479 | len(anns_map[annot[0]]) - 1, -1, -1
480 | ): # iterate backwards so we can delete while iterating
481 | parent_annot = anns_map[annot[0]][j]
482 | if (
483 | (
484 | type(parent_annot[2]) == type(annot[2])
485 | and type(annot[2]) == float
486 | )
487 | and (
488 | type(parent_annot[3]) == type(annot[3])
489 | and type(annot[3]) == float
490 | )
491 | ) or (
492 | (parent_annot[2] == annot[2])
493 | and (parent_annot[3] == annot[3])
494 | ): # same series/SOP UID, or both NaN (exam scope)
495 | del anns_map[annot[0]][j]
496 |
497 | content_sequence = []
498 | code_number = 43770 # arbitrary starting value for generated 99MDAI code values
499 |
500 | # Create a list of labelIds ordered from exam to series to image
501 | ordered_labels = []
502 | j = 0
503 | for label_id in anns_map:
504 | if self.label_scope_map[label_id] == "EXAM":
505 | ordered_labels.insert(0, label_id)
506 | j += 1
507 | elif self.label_scope_map[label_id] == "INSTANCE":
508 | ordered_labels.append(label_id)
509 | else:
510 | ordered_labels.insert(j, label_id)
511 |
512 | for label_id in ordered_labels:
513 | for a in anns_map[label_id]:
514 | # Add 'Referenced Segment' if label is in IMAGE scope
515 | if a[1] == "INSTANCE":
516 | content = Dataset()
517 | content.ValueType = "IMAGE"
518 | referenced_sequence_ds = Dataset()
519 | if len(a) > 4:
520 | referenced_sequence_ds.ReferencedSOPClassUID = a[4]
521 | referenced_sequence_ds.ReferencedSOPInstanceUID = a[2]
522 | content.ReferencedSOPSequence = [referenced_sequence_ds]
523 |
524 | code_sequence_ds = Dataset()
525 | code_sequence_ds.CodeValue = str(code_number)
526 | code_sequence_ds.CodingSchemeDesignator = "99MDAI"
527 | code_sequence_ds.CodeMeaning = "Referenced Image"
528 | code_sequence = [code_sequence_ds]
529 | content.ConceptNameCodeSequence = code_sequence
530 | code_number += 1
531 | content_sequence.append(content)
532 |
533 | # Add parent label to text value
534 | content = Dataset()
535 | code_sequence_ds = Dataset()
536 | if a[0] != None:
537 | code_name = self.label_name_map[a[0]]
538 | else:
539 | code_name = self.label_name_map[label_id]
540 | code_sequence_ds.CodeValue = str(hash(code_name))[1:6]
541 | code_sequence_ds.CodingSchemeDesignator = "99MDAI"
542 | code_sequence_ds.CodeMeaning = code_name
543 | code_sequence = [code_sequence_ds]
544 | content.ConceptNameCodeSequence = code_sequence
545 |
546 | # Add child label text
547 | text_value = ""
548 | if a[0] != None:
549 | text_value = self.label_name_map[label_id]
552 | text_value += "\n"
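# text_value now holds the child label name; for SERIES-scope annotations a
# "Series UID" line is appended below before the value is assigned.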
content.TextValue = text_value 554 | # Add 'Series UID:' 555 | if a[1] == "SERIES": 556 | text_value += f"Series UID: {series_uid}" 557 | content.TextValue = text_value 558 | if text_value != "": 559 | content.ValueType = "TEXT" 560 | else: 561 | content.ValueType = "CONTAINER" 562 | content_sequence.append(content) 563 | 564 | ds[0x0040A730][1][0x0040A730][0].ContentSequence = content_sequence 565 | 566 | ds.save_as( 567 | f"{os.getcwd()}/{out_dir}/DICOM_SR_{dataset_id}_{study_uid}_annotator_{instance_number}.dcm" 568 | ) 569 | print(f"Successfully exported DICOM SR files into {out_dir}") 570 | 571 | 572 | class SegExport(DicomExport): 573 | def __init__( 574 | self, 575 | annotation_json, 576 | metadata_json, 577 | combine_label_groups=True, 578 | output_dir=None, 579 | ): 580 | DicomExport.__init__( 581 | self, 582 | "SEG", 583 | annotation_json, 584 | metadata_json, 585 | combine_label_groups, 586 | output_dir, 587 | ) 588 | self.create_seg_exports() 589 | 590 | # Annotation dataframe has a separate row for a parent label. This function drops that row 591 | def drop_dupes(self, row): 592 | if row["parentLabelId"] != None: 593 | if ( 594 | row["parentLabelId"] not in self.annots_df["labelId"].unique() 595 | ): # Fixes edge case where a child label appears with no parent label for that annotator 596 | return # Occurs when another annotator adds a child label to a different annotator's label 597 | parents = self.annots_df[self.annots_df["labelId"] == row["parentLabelId"]] 598 | study_parents = parents[parents["StudyInstanceUID"] == row["StudyInstanceUID"]] 599 | series_parents = study_parents[ 600 | study_parents["SeriesInstanceUID"] == row["SeriesInstanceUID"] 601 | ] 602 | sop_parents = series_parents[series_parents["SOPInstanceUID"] == row["SOPInstanceUID"]] 603 | 604 | if len(sop_parents.index) > 0: 605 | self.annots_df.drop(sop_parents.index[0], inplace=True) 606 | elif len(series_parents.index) > 0: 607 | self.annots_df.drop(series_parents.index[0], inplace=True) 608 | elif len(study_parents.index) > 0: 609 | self.annots_df.drop(study_parents.index[0], inplace=True) 610 | 611 | # Gets imgs from annotations and creates segment sequence 612 | 613 | def img_insert(self, row, ds): 614 | data = self.load_mask_instance(row) 615 | if not np.isscalar(data): 616 | if self.prev_annot is not None and ( 617 | self.prev_annot["labelId"] == row["labelId"] 618 | and self.prev_annot["labelGroupName"] == row["labelGroupName"] 619 | and self.prev_annot["instanceNumber"] == row["instanceNumber"] 620 | ): 621 | mask2 = self.load_mask_instance(row) 622 | self.imgs[-1] = np.ma.mask_or(self.imgs[-1], mask2) 623 | else: 624 | self.imgs.append(self.load_mask_instance(row)) 625 | self.included_sops.append((len(self.seen_labels) + 1, row["SOPInstanceUID"])) 626 | self.unique_sops.add(row["SOPInstanceUID"]) 627 | self.name_number_map[len(self.seen_labels) + 1] = row["labelName"] 628 | self.prev_annot = row 629 | 630 | if row["labelId"] not in self.seen_labels: 631 | if row["parentLabelId"] == None: 632 | parent_label_name = self.label_name_map[row["labelId"]] 633 | else: 634 | parent_label_name = self.label_name_map[row["parentLabelId"]] 635 | child_label_name = self.label_name_map[row["labelId"]] 636 | 637 | segment_sequence = ds.SegmentSequence 638 | 639 | segment1 = Dataset() 640 | segment_sequence.append(segment1) 641 | 642 | # Segmented Property Category Code Sequence 643 | segmented_property_category_code_sequence = Sequence() 644 | segment1.SegmentedPropertyCategoryCodeSequence = ( 645 | 
segmented_property_category_code_sequence
646 | )
647 |
648 | # Segmented Property Category Code Sequence: Segmented Property Category Code 1
649 | segmented_property_category_code1 = Dataset()
650 | segmented_property_category_code_sequence.append(segmented_property_category_code1)
651 | segmented_property_category_code1.CodeValue = str(hash(parent_label_name))[1:6]
652 | segmented_property_category_code1.CodingSchemeDesignator = "99MDAI"
653 | segmented_property_category_code1.CodeMeaning = (
654 | f'{parent_label_name} from Label Group {row["labelGroupName"]}'
655 | )
656 |
657 | segment1.SegmentNumber = len(self.seen_labels) + 1 # (number of labels)
658 | segment1.SegmentLabel = child_label_name
659 | segment1.SegmentAlgorithmType = "MANUAL" # Maybe change based on how it was created
660 |
661 | # Segmented Property Type Code Sequence
662 | segmented_property_type_code_sequence = Sequence()
663 | segment1.SegmentedPropertyTypeCodeSequence = segmented_property_type_code_sequence
664 |
665 | # Segmented Property Type Code Sequence: Segmented Property Type Code 1
666 | segmented_property_type_code1 = Dataset()
667 | segmented_property_type_code_sequence.append(segmented_property_type_code1)
668 | segmented_property_type_code1.CodeValue = str(hash(child_label_name))[1:6]
669 | segmented_property_type_code1.CodingSchemeDesignator = "99MDAI"
670 | segmented_property_type_code1.CodeMeaning = child_label_name
671 |
672 | self.seen_labels.add(row["labelId"])
673 |
674 | def load_mask_instance(self, row):
675 | """Load instance masks for the given annotation row. Masks can be different types,
676 | mask is a binary true/false map of the same size as the image.
677 | """
678 |
679 | if row.data is None:
680 | return 404 # sentinel: no data found (callers check via np.isscalar)
681 |
682 | mask = np.zeros((int(row.height), int(row.width)), dtype=np.uint8)
683 |
684 | annotation_mode = row.annotationMode
685 |
686 | if annotation_mode == "bbox":
687 | # Bounding Box
688 | x = int(row.data["x"])
689 | y = int(row.data["y"])
690 | w = int(row.data["width"])
691 | h = int(row.data["height"])
692 | mask_instance = mask[:, :].copy()
693 | cv2.rectangle(mask_instance, (x, y), (x + w, y + h), 255, -1)
694 | mask[:, :] = mask_instance
695 |
696 | # FreeForm or Polygon
697 | elif annotation_mode == "freeform" or annotation_mode == "polygon":
698 | vertices = np.array(row.data["vertices"])
699 | vertices = vertices.reshape((-1, 2))
700 | mask_instance = mask[:, :].copy()
701 | cv2.fillPoly(mask_instance, np.int32([vertices]), (255, 255, 255))
702 | mask[:, :] = mask_instance
703 |
704 | # Line
705 | elif annotation_mode == "line":
706 | vertices = np.array(row.data["vertices"])
707 | vertices = vertices.reshape((-1, 2))
708 | mask_instance = mask[:, :].copy()
709 | cv2.polylines(mask_instance, np.int32([vertices]), False, (255, 255, 255), 12)
710 | mask[:, :] = mask_instance
711 |
712 | elif annotation_mode == "location":
713 | # Point location, drawn as a small filled circle
714 | x = int(row.data["x"])
715 | y = int(row.data["y"])
716 | mask_instance = mask[:, :].copy()
717 | cv2.circle(mask_instance, (x, y), 7, (255, 255, 255), -1)
718 | mask[:, :] = mask_instance
719 |
720 | elif annotation_mode == "ellipse":
721 | cx = int(row.data["cx"])
722 | cy = int(row.data["cy"])
723 | rx = int(row.data["rx"])
724 | ry = int(row.data["ry"])
725 | mask_instance = mask[:, :].copy()
726 | cv2.ellipse(mask_instance, (cx, cy), (rx, ry), 0, 0, 360, (255, 255, 255), 12)
727 | mask[:, :] = mask_instance
728 |
729 | elif annotation_mode == "mask":
730 | mask_instance = mask[:, :].copy()
731 | if
row.data["foreground"]: 732 | for i in row.data["foreground"]: 733 | mask_instance = cv2.fillPoly( 734 | mask_instance, [np.array(i, dtype=np.int32)], (255, 255, 255) 735 | ) 736 | if row.data["background"]: 737 | for i in row.data["background"]: 738 | mask_instance = cv2.fillPoly( 739 | mask_instance, [np.array(i, dtype=np.int32)], (0, 0, 0) 740 | ) 741 | mask[:, :] = mask_instance 742 | 743 | return mask.astype(bool) 744 | 745 | def create_seg_exports(self): 746 | """ 747 | Creates a template SEG File and adds in necessary SEG information 748 | Instead of working from a template, this function creates a segmentation file from scratch using pydicom 749 | """ 750 | try: 751 | if self.output_dir == None: 752 | out_dir = "SEG_Output" 753 | os.mkdir("SEG_Output") 754 | else: 755 | out_dir = self.output_dir 756 | os.mkdir(self.output_dir) 757 | except: 758 | pass 759 | 760 | self.annots_df.apply(self.drop_dupes, axis=1) 761 | 762 | for dataset_id in self.annots_df["datasetId"].unique(): 763 | self.dataset_annots = self.annots_df[self.annots_df.datasetId == dataset_id] 764 | for study_uid in self.studies: 765 | dicom_hier = self.dicom_hierarchy[study_uid] 766 | for series_dict in dicom_hier: 767 | for series_uid in series_dict: 768 | sops = series_dict[series_uid] 769 | 770 | annotations = self.dataset_annots[ 771 | self.dataset_annots.SeriesInstanceUID == series_uid 772 | ] 773 | annotations = annotations[annotations["scope"] == "INSTANCE"] 774 | if annotations.empty: 775 | continue 776 | 777 | self.dicom_tags = self.tags_df[ 778 | self.tags_df.SeriesInstanceUID == series_uid 779 | ].dicomTags.values[0] 780 | annotators = annotations.createdById.unique() 781 | instance_uid = pydicom.uid.generate_uid(prefix=None) 782 | date = datetime.now().strftime("%Y%m%d") 783 | time = datetime.now().strftime("%H%M%S") 784 | 785 | sop_instance_num_map = {} 786 | for sop in sops: 787 | sop_dicom_tags = self.tags_df[ 788 | self.tags_df.SOPInstanceUID == sop 789 | ].dicomTags.values[0] 790 | if "InstanceNumber" in sop_dicom_tags: 791 | sop_instance_num_map[sop] = sop_dicom_tags["InstanceNumber"] 792 | else: 793 | sop_instance_num_map[sop] = "1" 794 | 795 | def create_instance_number(row): 796 | return sop_instance_num_map[row["SOPInstanceUID"]] 797 | 798 | try: 799 | annotations["instanceNumber"] = annotations.apply( 800 | create_instance_number, axis=1 801 | ) 802 | except: 803 | continue 804 | annotations = annotations.sort_values( 805 | ["labelGroupName", "labelId", "instanceNumber"], ignore_index=True 806 | ) # sort by label group then annotation then appearance in series 807 | 808 | # File meta info data elements 809 | file_meta = FileMetaDataset() 810 | file_meta.FileMetaInformationVersion = b"\x00\x01" 811 | file_meta.TransferSyntaxUID = "1.2.840.10008.1.2.1" 812 | file_meta.MediaStorageSOPInstanceUID = instance_uid # Create Instance UID # Media Storage SOP Instance UID 813 | file_meta.ImplementationClassUID = str( 814 | pydicom.uid.PYDICOM_IMPLEMENTATION_UID 815 | ) # Implementation Class UID 816 | file_meta.ImplementationVersionName = str( 817 | pydicom.__version__ 818 | ) # Implementation Version Name 819 | file_meta.SourceApplicationEntityTitle = "POSDA" 820 | 821 | # Main data elements 822 | ds = Dataset() 823 | 824 | ds = self.place_tags(self.dicom_tags, ds, self.dicom_tag_heirarchy, True) 825 | 826 | ds.SpecificCharacterSet = "ISO_IR 192" 827 | ds.SOPClassUID = "1.2.840.10008.5.1.4.1.1.66.4" 828 | ds.SOPInstanceUID = instance_uid 829 | ds.SeriesDate = str(date) # Series Date 830 | ds.ContentDate = 
str(date) # Content Date
831 | ds.SeriesTime = str(time) # Series Time
832 | ds.ContentTime = str(time) # Content Time
833 | ds.Manufacturer = "MDAI"
834 | ds.Modality = "SEG"
835 |
836 | # Referenced Series Sequence
837 | refd_series_sequence = Sequence()
838 | ds.ReferencedSeriesSequence = refd_series_sequence
839 |
840 | # Referenced Series Sequence: Referenced Series 1
841 | refd_series1 = Dataset()
842 | refd_series_sequence.append(refd_series1)
847 | refd_series1.SeriesInstanceUID = series_uid
848 |
849 | # Referenced Instance Sequence
850 | refd_instance_sequence = Sequence()
851 | refd_series1.ReferencedInstanceSequence = refd_instance_sequence
852 |
853 | ds.SegmentationType = "BINARY"
854 |
855 | for annotator_id in annotators:
856 | annotator_annots = annotations[annotations.createdById == annotator_id]
857 |
858 | if self.combine:
859 | label_group_sets = [annotator_annots.labelGroupName.unique()]
860 | else:
861 | label_group_sets = [
862 | [group] for group in annotator_annots.labelGroupName.unique()
863 | ]
864 |
865 | ds.SamplesPerPixel = 1
866 | ds.PhotometricInterpretation = "MONOCHROME2"
867 | ds.BitsAllocated = 1
868 | ds.BitsStored = 1
869 | ds.HighBit = 0
870 | ds.PixelRepresentation = 0
871 | ds.LossyImageCompression = "00"
872 |
873 | for label_group_set in label_group_sets:
874 | label_group_annots = annotator_annots[
875 | annotator_annots.labelGroupName.isin(label_group_set)
876 | ]
877 |
878 | # Segment Sequence
879 | segment_sequence = Sequence()
880 | ds.SegmentSequence = segment_sequence
881 |
882 | self.imgs = []
883 | self.seen_labels = set()
884 | self.name_number_map = {}
885 | self.included_sops = []
886 | self.unique_sops = set()
887 | self.label_groups = list(annotations.labelGroupName.unique())
888 | self.prev_annot = None
889 | label_group_annots.apply(self.img_insert, args=(ds,), axis=1)
890 |
891 | ds.NumberOfFrames = len(
892 | self.imgs
893 | ) # set during the last part of SEG creation (should equal the number of annotated frames)
894 | ds.PixelData = pack_bits(np.array(self.imgs))
895 |
896 | for sop in self.unique_sops:
897 | sop_dicom_tags = self.tags_df[
898 | self.tags_df.SOPInstanceUID == sop
899 | ].dicomTags.values[0]
900 | refd_instance1 = Dataset()
901 | refd_instance_sequence.append(refd_instance1)
902 | if "SOPClassUID" in sop_dicom_tags:
903 | refd_instance1.ReferencedSOPClassUID = sop_dicom_tags[
904 | "SOPClassUID"
905 | ]
906 | refd_instance1.ReferencedSOPInstanceUID = sop_dicom_tags[
907 | "SOPInstanceUID"
908 | ]
909 |
910 | # Initially left out; added back in (empty), though it appears unnecessary in practice.
912 | # -----------------------------------------------------------------------
913 | # Dimension Index Sequence
914 | dimension_index_sequence = Sequence()
915 | ds.DimensionIndexSequence = dimension_index_sequence
916 | # -----------------------------------------------------------------------
917 |
918 | ds.ContentLabel = "MDAI_SEG"
919 | ds.ContentCreatorName = f"annotator {annotator_id}"
920 |
921 | # Initially left out; added back in (empty), though it appears unnecessary in practice.
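# (SharedFunctionalGroupsSequence is part of the multi-frame functional groups
# module; it is written below even though no shared groups are populated here.)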
923 | # -----------------------------------------------------------------------
924 | # Shared Functional Groups Sequence
925 | shared_functional_groups_sequence = Sequence()
926 | ds.SharedFunctionalGroupsSequence = (
927 | shared_functional_groups_sequence
928 | )
929 | # -----------------------------------------------------------------------
930 |
931 | # Per-frame Functional Groups Sequence (built up below, assigned to ds after the loop)
932 | per_frame_functional_groups_sequence = []
945 |
946 | # Loop through each annotated frame and create a unique Per-frame Functional Groups item
947 | # ---------------------------------------------------------------------------
948 | for segment_number, sop in self.included_sops:
949 | label_names = ", ".join(
950 | label_group_annots["labelName"].unique()
951 | )
952 | ds.SeriesDescription = (
953 | f"Segmentation of {label_names} by annotator {annotator_id}"
954 | )
955 |
956 | sop_dicom_tags = self.tags_df[
957 | self.tags_df.SOPInstanceUID == sop
958 | ].dicomTags.values[0]
959 |
960 | # Per-frame Functional Groups Sequence: Per-frame Functional Groups 1
961 | per_frame_functional_groups1 = Dataset()
962 | per_frame_functional_groups_sequence.append(
963 | per_frame_functional_groups1
964 | )
965 |
966 | # Derivation Image Sequence
967 | derivation_image_sequence = Sequence()
968 | per_frame_functional_groups1.DerivationImageSequence = (
969 | derivation_image_sequence
970 | )
971 |
972 | # Derivation Image Sequence: Derivation Image 1
973 | derivation_image1 = Dataset()
974 | derivation_image_sequence.append(derivation_image1)
975 |
976 | # Source Image Sequence
977 | source_image_sequence = Sequence()
978 | derivation_image1.SourceImageSequence = source_image_sequence
979 |
980 | # Source Image Sequence: Source Image 1
981 | source_image1 = Dataset()
982 | source_image_sequence.append(source_image1)
983 | if "SOPClassUID" in self.dicom_tags:
984 | source_image1.ReferencedSOPClassUID = self.dicom_tags[
985 | "SOPClassUID"
986 | ]
987 | source_image1.ReferencedSOPInstanceUID = self.dicom_tags[
988 | "SOPInstanceUID"
989 | ]
990 |
991 | # Purpose of Reference Code Sequence
992 | purpose_of_ref_code_sequence = Sequence()
993 | source_image1.PurposeOfReferenceCodeSequence = (
994 | purpose_of_ref_code_sequence
995 | )
996 |
997 | # Purpose of Reference Code Sequence: Purpose of Reference Code 1
998 | purpose_of_ref_code1 = Dataset()
999 | purpose_of_ref_code_sequence.append(purpose_of_ref_code1)
1000 | purpose_of_ref_code1.CodeValue = "121322"
1001 | purpose_of_ref_code1.CodingSchemeDesignator = "DCM"
1002 | purpose_of_ref_code1.CodeMeaning = (
1003 | "Source image for image processing operation"
1004 | )
1005 |
1006 | # Derivation Code Sequence
1007 | derivation_code_sequence = Sequence()
1008 | derivation_image1.DerivationCodeSequence = (
1009 | derivation_code_sequence
1010 | )
1011 |
1012 | # Derivation Code Sequence: Derivation Code 1
1013 | derivation_code1 = Dataset()
1014 | derivation_code_sequence.append(derivation_code1)
1015 | derivation_code1.CodeValue = "113076"
1016 | derivation_code1.CodingSchemeDesignator = "DCM"
1017 | derivation_code1.CodeMeaning = "Segmentation"
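# Each frame is tied back to its SegmentSequence entry below via
# ReferencedSegmentNumber, so viewers can map frames to segments.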
1018 |
1019 | # Segment Identification Sequence
1020 | segment_id_seq = Dataset()
1021 | per_frame_functional_groups1.SegmentIdentificationSequence = [
1022 | segment_id_seq
1023 | ]
1024 |
1025 | # Segment Number
1026 | segment_id_seq.ReferencedSegmentNumber = segment_number
1027 |
1028 | per_frame_functional_groups1 = self.place_tags(
1029 | sop_dicom_tags,
1030 | per_frame_functional_groups1,
1031 | self.dicom_tag_heirarchy[
1032 | "PerFrameFunctionalGroupsSequence"
1033 | ],
1034 | False,
1035 | )
1036 | # -------------------------------------------------------------------------
1037 |
1038 | ds.PerFrameFunctionalGroupsSequence = (
1039 | per_frame_functional_groups_sequence
1040 | )
1041 |
1042 | ds.file_meta = file_meta
1043 | ds.is_implicit_VR = False
1044 | ds.is_little_endian = True
1045 |
1046 | if self.included_sops:
1047 | if self.combine:
1048 | ds.save_as(
1049 | f"{os.getcwd()}/{out_dir}/DICOM_SEG_{dataset_id}_{series_uid}_annotator_{annotator_id}.dcm",
1050 | False,
1051 | )
1052 | else:
1053 | ds.save_as(
1054 | f"{os.getcwd()}/{out_dir}/DICOM_SEG_{dataset_id}_label_group_{label_group_set[0]}_series_{series_uid}_annotator_{annotator_id}.dcm",
1055 | False,
1056 | )
1057 | print(f"Successfully exported DICOM SEG files into {out_dir}")
1058 |
1059 |
1060 | def iterate_content_seq(content, content_seq_list):
1061 | """
1062 | Utility helper that walks DICOM SR content sequences recursively, appending extracted labels and notes to the given list.
1063 | """
1064 | for content_seq in content_seq_list:
1065 | parent_labels = []
1066 | child_labels = []
1067 | notes = []
1068 |
1069 | if "RelationshipType" in content_seq:
1070 | if content_seq.RelationshipType == "HAS ACQ CONTEXT":
1071 | continue
1072 |
1073 | if content_seq.ValueType == "IMAGE":
1074 | if "ReferencedSOPSequence" in content_seq:
1075 | for ref_seq in content_seq.ReferencedSOPSequence:
1076 | if "ReferencedSOPClassUID" in ref_seq:
1077 | notes.append(
1078 | f"\n Referenced SOP Class UID = {ref_seq.ReferencedSOPClassUID}"
1079 | )
1080 | if "ReferencedSOPInstanceUID" in ref_seq:
1081 | notes.append(
1082 | f"\n Referenced SOP Instance UID = {ref_seq.ReferencedSOPInstanceUID}"
1083 | )
1084 | if "ReferencedSegmentNumber" in ref_seq:
1085 | notes.append(
1086 | f"\n Referenced Segment Number = {ref_seq.ReferencedSegmentNumber}"
1087 | )
1088 | else:
1089 | continue
1090 |
1091 | if "ConceptNameCodeSequence" in content_seq:
1092 | if len(content_seq.ConceptNameCodeSequence) > 0:
1093 | parent_labels.append(content_seq.ConceptNameCodeSequence[0].CodeMeaning)
1094 | if "ConceptCodeSequence" in content_seq:
1095 | if len(content_seq.ConceptCodeSequence) > 0:
1096 | child_labels.append(content_seq.ConceptCodeSequence[0].CodeMeaning)
1097 |
1098 | if "DateTime" in content_seq:
1099 | notes.append(content_seq.DateTime)
1100 | if "Date" in content_seq:
1101 | notes.append(content_seq.Date)
1102 | if "PersonName" in content_seq:
1103 | notes.append(str(content_seq.PersonName))
1104 | if "UID" in content_seq:
1105 | notes.append(content_seq.UID)
1106 | if "TextValue" in content_seq:
1107 | child_labels.append(content_seq.TextValue)
1108 | if "MeasuredValueSequence" in content_seq:
1109 | if len(content_seq.MeasuredValueSequence) > 0:
1110 | units = (
1111 | content_seq.MeasuredValueSequence[0].MeasurementUnitsCodeSequence[0].CodeValue
1112 | )
1113 | notes.append(str(content_seq.MeasuredValueSequence[0].NumericValue) + units)
1114 |
1115 | if "ContentSequence" in content_seq:
1116 | iterate_content_seq(content, list(content_seq.ContentSequence))
1117 | else:
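# leaf item: record the accumulated parent/child labels and notes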
1118 | content.append([", ".join(parent_labels), ", ".join(child_labels), ", ".join(notes)])
1119 |
--------------------------------------------------------------------------------
/mdai/utils/keras_utils.py:
--------------------------------------------------------------------------------
1 | from mdai.visualize import load_dicom_image
2 | from keras.utils import Sequence, to_categorical
3 |
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | class DataGenerator(Sequence):
9 | def __init__(
10 | self,
11 | dataset,
12 | batch_size=32,
13 | dim=(32, 32),
14 | n_channels=1,
15 | n_classes=10,
16 | shuffle=True,
17 | to_RGB=True,
18 | rescale=False,
19 | ):
20 | """Generates data for Keras fit_generator() function.
21 | """
22 |
23 | # Initialization
24 | self.dim = dim
25 | self.batch_size = batch_size
26 |
27 | self.img_ids = dataset.image_ids
28 | self.imgs_anns_dict = dataset.imgs_anns_dict
29 | self.dataset = dataset
30 |
31 | self.n_channels = n_channels
32 | self.n_classes = n_classes
33 | self.shuffle = shuffle
34 | self.to_RGB = to_RGB
35 | self.rescale = rescale
36 | self.on_epoch_end()
37 |
38 | def __len__(self):
39 | "Denotes the number of batches per epoch"
40 | return int(np.floor(len(self.img_ids) / self.batch_size))
41 |
42 | def __getitem__(self, index):
43 | "Generate one batch of data"
44 |
45 | # Generate indexes of the batch
46 | indexes = self.indexes[index * self.batch_size : (index + 1) * self.batch_size]
47 |
48 | # Find list of IDs
49 | img_ids_temp = [self.img_ids[k] for k in indexes]
50 |
51 | # Generate data
52 | X, y = self.__data_generation(img_ids_temp)
53 |
54 | return X, y
55 |
56 | def on_epoch_end(self):
57 | "Updates indexes after each epoch"
58 | self.indexes = np.arange(len(self.img_ids))
59 | if self.shuffle:
60 | np.random.shuffle(self.indexes)
61 |
62 | def __data_generation(self, img_ids_temp):
63 | "Generates data containing batch_size samples"
64 |
65 | # Initialization
66 | X = np.empty((self.batch_size, *self.dim, self.n_channels))
67 | y = np.empty((self.batch_size), dtype=int)
68 |
69 | # Generate data
70 | for i, ID in enumerate(img_ids_temp):
71 | image = load_dicom_image(ID, to_RGB=self.to_RGB, rescale=self.rescale)
72 | try:
73 | image = Image.fromarray(image)
74 | except Exception:
75 | print(
76 | "PIL.Image can't read image. Possibly a 12 or 16 bit image. Try rescale=True to "
77 | + "scale to 8 bit."
78 | ) 79 | 80 | image = image.resize((self.dim[0], self.dim[1])) 81 | 82 | X[i,] = image 83 | 84 | ann = self.imgs_anns_dict[ID][0] 85 | y[i] = self.dataset.classes_dict[ann["labelId"]]["class_id"] 86 | return X, to_categorical(y, num_classes=self.n_classes) 87 | -------------------------------------------------------------------------------- /mdai/utils/sample_SR.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdai/mdai-client-py/28ab35436f3db1fb056b1e8d3bc13d9ae4c5c555/mdai/utils/sample_SR.dcm -------------------------------------------------------------------------------- /mdai/utils/sample_dicom.dcm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdai/mdai-client-py/28ab35436f3db1fb056b1e8d3bc13d9ae4c5c555/mdai/utils/sample_dicom.dcm -------------------------------------------------------------------------------- /mdai/utils/tensorflow_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import tensorflow as tf 3 | import sys 4 | import io 5 | import hashlib 6 | from object_detection.utils import dataset_util 7 | from mdai import visualize 8 | 9 | 10 | def create_tf_bbox_example(annotations, image_id, classes_dict): 11 | 12 | image = visualize.load_dicom_image(image_id) 13 | width = int(image.shape[1]) 14 | height = int(image.shape[0]) 15 | 16 | raw_img = visualize.load_dicom_image(image_id, to_RGB=True) 17 | img = Image.fromarray(raw_img) 18 | img_buffer = io.BytesIO() 19 | img.save(img_buffer, format="jpeg") 20 | encoded_jpg = Image.open(img_buffer) 21 | 22 | if encoded_jpg.format != "JPEG": 23 | raise ValueError("Image format not JPEG") 24 | 25 | key = hashlib.sha256(img_buffer.getvalue()).hexdigest() 26 | 27 | xmins = [] # List of normalized left x coordinates in bounding box (1 per box) 28 | xmaxs = [] # List of normalized right x coordinates in bounding box (1 per box) 29 | ymins = [] # List of normalized top y coordinates in bounding box (1 per box) 30 | ymaxs = [] # List of normalized bottom y coordinates in bounding box (1 per box) 31 | classes_text = [] # List of string class name of bounding box (1 per box) 32 | classes = [] # List of integer class id of bounding box (1 per box) 33 | 34 | # per annotation 35 | for a in annotations: 36 | w = int(a["data"]["width"]) 37 | h = int(a["data"]["height"]) 38 | 39 | x_min = int(a["data"]["x"]) 40 | y_min = int(a["data"]["y"]) 41 | x_max = x_min + w 42 | y_max = y_min + h 43 | 44 | # WARN: these are normalized 45 | xmins.append(float(x_min / width)) 46 | xmaxs.append(float(x_max / width)) 47 | ymins.append(float(y_min / height)) 48 | ymaxs.append(float(y_max / height)) 49 | 50 | classes_text.append(a["labelId"].encode("utf8")) 51 | classes.append(classes_dict[a["labelId"]]["class_id"]) 52 | 53 | # print(classes) 54 | 55 | tf_example = tf.train.Example( 56 | features=tf.train.Features( 57 | feature={ 58 | "image/height": dataset_util.int64_feature(height), 59 | "image/width": dataset_util.int64_feature(width), 60 | "image/filename": dataset_util.bytes_feature(image_id.encode("utf8")), 61 | "image/source_id": dataset_util.bytes_feature(image_id.encode("utf8")), 62 | "image/key/sha256": dataset_util.bytes_feature(key.encode("utf8")), 63 | "image/encoded": dataset_util.bytes_feature(img_buffer.getvalue()), 64 | "image/format": dataset_util.bytes_feature("jpg".encode("utf8")), 65 | "image/object/bbox/xmin": 
dataset_util.float_list_feature(xmins),
66 | "image/object/bbox/xmax": dataset_util.float_list_feature(xmaxs),
67 | "image/object/bbox/ymin": dataset_util.float_list_feature(ymins),
68 | "image/object/bbox/ymax": dataset_util.float_list_feature(ymaxs),
69 | "image/object/class/text": dataset_util.bytes_list_feature(classes_text),
70 | "image/object/class/label": dataset_util.int64_list_feature(classes),
71 | }
72 | )
73 | )
74 |
75 | return tf_example
76 |
77 |
78 | def write_to_tfrecords(output_path, dataset):
79 | """Write images and annotations to tfrecords.
80 | Args:
81 | output_path (str): Output file path of the TFRecord.
82 | dataset (object): MD.ai dataset object.
83 | Examples:
84 |
85 | >>> train_record_fp = os.path.abspath('./train.record')
86 | >>> export.write_to_tfrecords(train_record_fp, train_dataset)
87 | """
88 |
89 | def _print_progress(count, total):
90 | # Percentage completion.
91 | pct_complete = float(count) / total
92 |
93 | # Status message.
94 | # Note the \r which means the line should overwrite itself.
95 | msg = "\r- Progress: {0:.1%}".format(pct_complete)
96 |
97 | # Print it.
98 | sys.stdout.write(msg)
99 | sys.stdout.flush()
100 |
101 | print("\nOutput File Path: %s" % output_path)
102 | writer = tf.python_io.TFRecordWriter(output_path)
103 | num_images = len(dataset.image_ids)
104 | for i, image_id in enumerate(dataset.image_ids):
105 | _print_progress(count=i, total=num_images - 1)
106 | annotations = dataset.imgs_anns[image_id]
107 | tf_example = create_tf_bbox_example(annotations, image_id, dataset.classes_dict)
108 | writer.write(tf_example.SerializeToString())
109 | writer.close()
110 |
--------------------------------------------------------------------------------
/mdai/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import dicom2nifti
5 | import pydicom
6 |
7 | DEFAULT_IMAGE_SIZE = (512.0, 512.0)
8 |
9 |
10 | def apply_slope_intercept(dicom_file):
11 | """
12 | Applies rescale slope and rescale intercept transformation.
13 | """
14 | array = dicom_file.pixel_array.copy()
15 |
16 | scale_slope = 1
17 | scale_intercept = 0
18 | if "RescaleIntercept" in dicom_file:
19 | scale_intercept = int(dicom_file.RescaleIntercept)
20 | if "RescaleSlope" in dicom_file:
21 | scale_slope = int(dicom_file.RescaleSlope)
22 | array = array * scale_slope
23 | array = array + scale_intercept
24 | return array
25 |
26 |
27 | def remove_padding(array):
28 | """
29 | Removes background/padding from an 8bit numpy array.
30 | """
31 | arr = array.copy()
32 | nonzeros = np.nonzero(arr)
33 | x1 = np.min(nonzeros[0])
34 | x2 = np.max(nonzeros[0])
35 | y1 = np.min(nonzeros[1])
36 | y2 = np.max(nonzeros[1])
37 | return arr[x1:x2, y1:y2]
38 |
39 |
40 | def get_window_from_dicom(dicom_file):
41 | """
42 | Returns window width and window center values.
43 | If no window width/level is provided or available, returns None.
44 | """
45 | width, level = None, None
46 | if "WindowWidth" in dicom_file:
47 | width = dicom_file.WindowWidth
48 | if isinstance(width, pydicom.multival.MultiValue):
49 | width = int(width[0])
50 | else:
51 | width = int(width)
52 |
53 | if "WindowCenter" in dicom_file:
54 | level = dicom_file.WindowCenter
55 | if isinstance(level, pydicom.multival.MultiValue):
56 | level = int(level[0])
57 | else:
58 | level = int(level)
59 | return width, level
60 |
61 |
62 | def window(array, width, level):
63 | """
64 | Applies windowing operation.
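Clips pixel values to the range [level - width//2, level + width//2].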
65 | If window width/level is None, returns the array itself.
66 | """
67 | if width is not None and level is not None:
68 | array = np.clip(array, level - width // 2, level + width // 2)
69 | return array
70 |
71 |
72 | def rescale_to_8bit(array):
73 | """
74 | Convert an array to 8bit (0-255).
75 | """
76 | array = array - np.min(array)
77 | array = array / np.max(array)
78 | array = (array * 255).astype("uint8")
79 | return array
80 |
81 |
82 | def load_dicom_array(dicom_file, apply_rescale=True):
83 | """
84 | Returns the dicom image as a Numpy array.
85 | """
86 | array = dicom_file.pixel_array.copy()
87 | if apply_rescale:
88 | array = apply_slope_intercept(dicom_file)
89 | return array
90 |
91 |
92 | def convert_dicom_to_nifti(dicom_files, tempdir):
93 | """
94 | Converts a dicom series to nifti format.
95 | Saves nifti in directory provided with filename as SeriesInstanceUID.nii.gz
96 | Returns a sorted list of dicom files based on image position patient.
97 | """
98 | output_file = os.path.join(tempdir, dicom_files[0].SeriesInstanceUID + ".nii.gz")
99 | nifti_file = dicom2nifti.convert_dicom.dicom_array_to_nifti(
100 | dicom_files, output_file=output_file, reorient_nifti=True,
101 | )
102 | return dicom2nifti.common.sort_dicoms(dicom_files)
103 |
104 |
105 | def convert_dicom_to_8bit(dicom_file, imsize=None, width=None, level=None, keep_padding=True):
106 | """
107 | Given a DICOM file, window specifications, and image size,
108 | return the image as a Numpy array scaled to [0,255] of the specified size.
109 | """
110 | if width is None or level is None:
111 | width, level = get_window_from_dicom(dicom_file)
112 |
113 | array = apply_slope_intercept(dicom_file)
114 | array = window(array, width, level)
115 | array = rescale_to_8bit(array)
116 |
117 | if (
118 | "PhotometricInterpretation" in dicom_file
119 | and dicom_file.PhotometricInterpretation == "MONOCHROME1"
120 | ):
121 | array = 255 - array
122 |
123 | if not keep_padding:
124 | array = remove_padding(array)
125 |
126 | if imsize is not None:
127 | array = cv2.resize(array, imsize)
128 | return array
129 |
130 |
131 | def convert_to_RGB(array, imsize=None):
132 | """
133 | Converts a single channel monochrome image to a 3 channel RGB image.
134 | """
135 | img = np.stack((array,) * 3, axis=-1)
136 | if imsize is not None:
137 | img = cv2.resize(img, imsize)
138 | return img
139 |
140 |
141 | def convert_to_RGB_window(array, width, level, imsize=None):
142 | """
143 | Converts a monochrome image to 3 channel RGB with windowing.
144 | Width and level can be lists for different values per channel.
145 | """
146 | if type(width) is list and type(level) is list:
147 | R = window(array, width[0], level[0])
148 | G = window(array, width[1], level[1])
149 | B = window(array, width[2], level[2])
150 | img = np.stack([R, G, B], axis=-1)
151 | else:
152 | R = window(array, width, level)
153 | img = np.stack((R,) * 3, axis=-1)
154 |
155 | if imsize is not None:
156 | img = cv2.resize(img, imsize)
157 | return img
158 |
159 |
160 | def stack_slices(dicom_files):
161 | """
162 | Stacks the +-1 neighboring slices onto each slice in a dicom series.
163 | Returns the list of stacked images and sorted list of dicom files.
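The first and last slices reuse themselves in place of the missing neighbor.
Example (a sketch; assumes `files` is a list of pydicom datasets from one series):
>>> stacked_images, sorted_files = stack_slices(files)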
164 | """ 165 | dicom_files = dicom2nifti.common.sort_dicoms(dicom_files) 166 | dicom_images = [load_dicom_array(i) for i in dicom_files] 167 | 168 | stacked_images = [] 169 | for i, file in enumerate(dicom_images): 170 | if i == 0: 171 | img = np.stack([dicom_images[i], dicom_images[i], dicom_images[i + 1]], axis=-1) 172 | stacked_images.append(img) 173 | elif i == len(dicom_files) - 1: 174 | img = np.stack([dicom_images[i - 1], dicom_images[i], dicom_images[i]], axis=-1) 175 | stacked_images.append(img) 176 | else: 177 | img = np.stack([dicom_images[i - 1], dicom_images[i], dicom_images[i + 1]], axis=-1) 178 | stacked_images.append(img) 179 | 180 | return stacked_images, dicom_files 181 | -------------------------------------------------------------------------------- /mdai/visualize.py: -------------------------------------------------------------------------------- 1 | import pydicom 2 | import numpy as np 3 | import colorsys 4 | import random 5 | import cv2 6 | from skimage.measure import find_contours 7 | import matplotlib.pyplot as plt 8 | from matplotlib import patches 9 | 10 | 11 | def random_colors(N, bright=True): 12 | """Generate random colors. 13 | To get visually distinct colors, generate them in HSV space then convert to RGB. 14 | 15 | Args: 16 | N (int): 17 | Number of colors. 18 | """ 19 | brightness = 1.0 if bright else 0.7 20 | hsv = [(i / N, 1, brightness) for i in range(N)] 21 | colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv)) 22 | random.shuffle(colors) 23 | return colors 24 | 25 | 26 | # based on functions in: https://github.com/matterport/Mask_RCNN/blob/master/mrcnn/visualize.py 27 | def display_images(image_ids, titles=None, cols=3, cmap="gray", norm=None, interpolation=None): 28 | """Display images given image ids. 29 | 30 | Args: 31 | image_ids (list): 32 | List of image ids. 33 | 34 | TODO: figsize should not be hardcoded 35 | """ 36 | titles = titles if titles is not None else [""] * len(image_ids) 37 | rows = len(image_ids) // cols + 1 38 | plt.figure(figsize=(14, 14 * rows // cols)) 39 | i = 1 40 | for image_id, title in zip(image_ids, titles): 41 | plt.subplot(rows, cols, i) 42 | plt.title(title, fontsize=9) 43 | plt.axis("off") 44 | 45 | image = load_dicom_image(image_id, rescale=True) 46 | plt.imshow(image, cmap=cmap, norm=norm, interpolation=interpolation) 47 | 48 | i += 1 49 | plt.show() 50 | 51 | 52 | def load_dicom_image(image_id, to_RGB=False, rescale=False): 53 | """Load a DICOM image. 54 | 55 | Args: 56 | image_id (str): 57 | image id (filepath). 58 | to_RGB (bool, optional): 59 | Convert grayscale image to RGB. 60 | 61 | Returns: 62 | image array. 63 | """ 64 | ds = pydicom.dcmread(image_id) 65 | try: 66 | image = ds.pixel_array 67 | except Exception: 68 | msg = ( 69 | "Could not read pixel array from DICOM with TransferSyntaxUID " 70 | + ds.file_meta.TransferSyntaxUID 71 | + ". Likely unsupported compression format." 72 | ) 73 | print(msg) 74 | 75 | if rescale: 76 | max_pixel_value = np.amax(image) 77 | min_pixel_value = np.amin(image) 78 | 79 | if max_pixel_value >= 255: 80 | # print("Input image pixel range exceeds 255, rescaling for visualization.") 81 | pixel_range = np.abs(max_pixel_value - min_pixel_value) 82 | pixel_range = pixel_range if pixel_range != 0 else 1 83 | image = image.astype(np.float32) / pixel_range * 255 84 | image = image.astype(np.uint8) 85 | 86 | if to_RGB: 87 | # If grayscale. Convert to RGB for consistency. 
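# (stacking the single channel three times leaves pixel values unchanged)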
88 | if len(image.shape) != 3 or image.shape[2] != 3:
89 | image = np.stack((image,) * 3, -1)
90 |
91 | return image
92 |
93 |
94 | def load_mask(image_id, dataset):
95 | """Load instance masks for the given image. Masks can be different types,
96 | mask is a binary true/false map of the same size as the image.
97 |
98 | """
99 | # annotations = imgs_anns[image_id]
100 | annotations = dataset.get_annotations_by_image_id(image_id)
101 | count = len(annotations)
102 | print("Number of annotations: %d" % count)
103 |
104 | image = load_dicom_image(image_id)
105 | width = image.shape[1]
106 | height = image.shape[0]
107 |
108 | if count == 0:
109 | print("No annotations")
110 | mask = np.zeros((height, width, 1), dtype=np.uint8)
111 | class_ids = np.zeros((1,), dtype=np.int32)
112 | else:
113 | mask = np.zeros((height, width, count), dtype=np.uint8)
114 | class_ids = np.zeros((count,), dtype=np.int32)
115 |
116 | for i, a in enumerate(annotations):
117 | label_id = a["labelId"]
118 | annotation_mode = dataset.label_id_to_class_annotation_mode(label_id)
119 | # print(annotation_mode)
120 |
121 | if annotation_mode == "bbox":
122 | # Bounding Box
123 | x = int(a["data"]["x"])
124 | y = int(a["data"]["y"])
125 | w = int(a["data"]["width"])
126 | h = int(a["data"]["height"])
127 | mask_instance = mask[:, :, i].copy()
128 | cv2.rectangle(mask_instance, (x, y), (x + w, y + h), 255, -1)
129 | mask[:, :, i] = mask_instance
130 |
131 | # FreeForm or Polygon
132 | elif annotation_mode == "freeform" or annotation_mode == "polygon":
133 | vertices = np.array(a["data"]["vertices"])
134 | vertices = vertices.reshape((-1, 2))
135 | mask_instance = mask[:, :, i].copy()
136 | cv2.fillPoly(mask_instance, np.int32([vertices]), (255, 255, 255))
137 | mask[:, :, i] = mask_instance
138 |
139 | # Line
140 | elif annotation_mode == "line":
141 | vertices = np.array(a["data"]["vertices"])
142 | vertices = vertices.reshape((-1, 2))
143 | mask_instance = mask[:, :, i].copy()
144 | cv2.polylines(mask_instance, np.int32([vertices]), False, (255, 255, 255), 12)
145 | mask[:, :, i] = mask_instance
146 |
147 | elif annotation_mode == "location":
148 | # Point location, drawn as a small filled circle
149 | x = int(a["data"]["x"])
150 | y = int(a["data"]["y"])
151 | mask_instance = mask[:, :, i].copy()
152 | cv2.circle(mask_instance, (x, y), 7, (255, 255, 255), -1)
153 | mask[:, :, i] = mask_instance
154 |
155 | elif annotation_mode == "mask":
156 | mask_instance = mask[:, :, i].copy()
157 | if a["data"]["foreground"]:
158 | for poly in a["data"]["foreground"]:
159 | mask_instance = cv2.fillPoly(
160 | mask_instance, [np.array(poly, dtype=np.int32)], (255, 255, 255)
161 | )
162 | if a["data"]["background"]:
163 | for poly in a["data"]["background"]:
164 | mask_instance = cv2.fillPoly(
165 | mask_instance, [np.array(poly, dtype=np.int32)], (0, 0, 0)
166 | )
167 | mask[:, :, i] = mask_instance
168 |
169 | elif annotation_mode is None:
170 | print("Not a local instance")
171 |
172 | # load class id
173 | class_ids[i] = dataset.label_id_to_class_id(label_id)
174 |
175 | return mask.astype(bool), class_ids.astype(np.int32)
176 |
177 |
178 | def apply_mask(image, mask, color, alpha=0.3):
179 | """Apply the given mask to the image.
180 |
181 | Args:
182 | image: height, width, channel.
183 |
184 | Returns:
185 | image with applied color mask.
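Where mask == 1, channels are blended as image * (1 - alpha) + alpha * color * 255.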
186 | """
187 | for c in range(3):
188 | image[:, :, c] = np.where(
189 | mask == 1, image[:, :, c] * (1 - alpha) + alpha * color[c] * 255, image[:, :, c]
190 | )
191 | return image
192 |
193 |
194 | def extract_bboxes(mask):
195 | """Compute bounding boxes from masks.
196 |
197 | Args:
198 | mask [height, width, num_instances]:
199 | Mask pixels are either 1 or 0.
200 |
201 | Returns:
202 | bounding box array [num_instances, (y1, x1, y2, x2)].
203 | """
204 | boxes = np.zeros([mask.shape[-1], 4], dtype=np.int32)
205 | for i in range(mask.shape[-1]):
206 | m = mask[:, :, i]
207 | # Bounding box.
208 | horizontal_indicies = np.where(np.any(m, axis=0))[0]
209 | vertical_indicies = np.where(np.any(m, axis=1))[0]
210 | if horizontal_indicies.shape[0]:
211 | x1, x2 = horizontal_indicies[[0, -1]]
212 | y1, y2 = vertical_indicies[[0, -1]]
213 | # x2 and y2 should not be part of the box. Increment by 1.
214 | x2 += 1
215 | y2 += 1
216 | else:
217 | # No mask for this instance. Might happen due to
218 | # resizing or cropping. Set bbox to zeros
219 | x1, x2, y1, y2 = 0, 0, 0, 0
220 | boxes[i] = np.array([y1, x1, y2, x2])
221 | return boxes.astype(np.int32)
222 |
223 |
224 | def get_image_ground_truth(image_id, dataset):
225 | """Load and return ground truth data for an image (image, mask, bounding boxes).
226 |
227 | Args:
228 | image_id:
229 | Image id.
230 |
231 | Returns:
232 | image:
233 | [height, width, 3]
234 | class_ids:
235 | [instance_count] Integer class IDs
236 | bbox:
237 | [instance_count, (y1, x1, y2, x2)]
238 | mask:
239 | [height, width, instance_count]. The height and width are those of the image.
241 | """
242 | # image = load_dicom_image(image_id, to_RGB=True)
243 | image = load_dicom_image(image_id, to_RGB=True, rescale=True)
244 |
245 | mask, class_ids = load_mask(image_id, dataset)
246 |
247 | _idx = np.sum(mask, axis=(0, 1)) > 0
248 | mask = mask[:, :, _idx]
249 | class_ids = class_ids[_idx]
250 |
251 | # Bounding boxes. Note that some boxes might be all zeros
252 | # if the corresponding mask got cropped out.
253 | # bbox: [num_instances, (y1, x1, y2, x2)]
254 | bbox = extract_bboxes(mask)
255 |
256 | return image, class_ids, bbox, mask
257 |
258 |
259 | def display_annotations(
260 | image,
261 | boxes,
262 | masks,
263 | class_ids,
264 | scores=None,
265 | title="",
266 | figsize=(16, 16),
267 | ax=None,
268 | show_mask=True,
269 | show_bbox=True,
270 | colors=None,
271 | captions=None,
272 | ):
273 | """Display annotations for image.
274 |
275 | Args:
276 | boxes:
277 | [num_instance, (y1, x1, y2, x2, class_id)] in image coordinates.
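All-zero boxes are skipped when drawing (e.g. masks cropped out of the image).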
278 | masks:
279 | [height, width, num_instances]
280 | class_ids:
281 | [num_instances]
282 | scores:
283 | (optional) confidence scores for each box
284 | title:
285 | (optional) Figure title
286 | show_mask, show_bbox:
287 | Whether to show masks and bounding boxes
288 | figsize:
289 | (optional) the size of the image
290 | colors:
291 | (optional) An array of colors to use with each object
292 | captions:
293 | (optional) A list of strings to use as captions for each object
294 | """
295 |
296 | # Number of instances
297 | N = boxes.shape[0]
298 | if not N:
299 | print("\n*** No instances to display *** \n")
300 | else:
301 | assert boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]
302 |
303 | # If no axis is passed, create one and automatically call show()
304 | auto_show = False
305 | if not ax:
306 | _, ax = plt.subplots(1, figsize=figsize)
307 | auto_show = True
308 |
309 | # Generate random colors
310 | colors = colors or random_colors(N)
311 |
312 | # Show area outside image boundaries.
313 | height, width = image.shape[:2]
314 | ax.set_ylim(height + 10, -10)
315 | ax.set_xlim(-10, width + 10)
316 | ax.axis("off")
317 | ax.set_title(title)
318 |
319 | masked_image = image.astype(np.uint32).copy()
320 | for i in range(N):
321 | color = colors[i]
322 |
323 | # Bounding box
324 | if not np.any(boxes[i]):
325 | # Skip this instance. Has no bbox. Likely lost in image cropping.
326 | continue
327 | y1, x1, y2, x2 = boxes[i]
328 | if show_bbox:
329 | p = patches.Rectangle(
330 | (x1, y1),
331 | x2 - x1,
332 | y2 - y1,
333 | linewidth=2,
334 | alpha=0.7,
335 | linestyle="dashed",
336 | edgecolor=color,
337 | facecolor="none",
338 | )
339 | ax.add_patch(p)
340 |
341 | # Label
342 | if not captions:
343 | class_id = class_ids[i]
344 | score = scores[i] if scores is not None else None
345 |
346 | label = class_id
347 | caption = "{} {:.3f}".format(label, score) if score else label
348 | else:
349 | caption = captions[i]
350 | ax.text(x1, y1 + 8, caption, color="w", size=11, backgroundcolor="none")
351 |
352 | # Mask
353 | mask = masks[:, :, i]
354 | if show_mask:
355 | masked_image = apply_mask(masked_image, mask, color)
356 |
357 | # Mask Polygon
358 | # Pad to ensure proper polygons for masks that touch image edges.
359 | padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
360 | padded_mask[1:-1, 1:-1] = mask
361 | contours = find_contours(padded_mask, 0.5)
362 | for verts in contours:
363 | # Subtract the padding and flip (y, x) to (x, y)
364 | verts = np.fliplr(verts) - 1
365 | p = patches.Polygon(verts, facecolor="none", edgecolor=color)
366 | ax.add_patch(p)
367 | ax.imshow(masked_image.astype(np.uint8))
368 | # ax.imshow(masked_image)
369 | if auto_show:
370 | plt.show()
371 |
372 |
373 | def draw_box_on_image(image, boxes, h, w):
374 | """Draw box on an image.
375 |
376 | Args:
377 | image:
378 | three channel (e.g. RGB) image.
379 | boxes:
380 | normalized box coordinate (between 0.0 and 1.0).
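Per-box order is (xmin, ymin, xmax, ymax), as fractions of image size.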
381 | h:
382 | image height
383 | w:
384 | image width
385 | """
386 |
387 | for i in range(len(boxes)):
388 | (left, right, top, bottom) = (
389 | boxes[i][0] * w,
390 | boxes[i][2] * w,
391 | boxes[i][1] * h,
392 | boxes[i][3] * h,
393 | )
394 | p1 = (int(left), int(top))
395 | p2 = (int(right), int(bottom))
396 | cv2.rectangle(image, p1, p2, (77, 255, 9), 3, 1)
397 |
--------------------------------------------------------------------------------
/notebooks/MDai_Simple_API.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Using the MDai API\n",
8 | "### Get annotations, create a Pandas DataFrame, create a CSV, import new labels."
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "metadata": {},
15 | "outputs": [
16 | {
17 | "name": "stdout",
18 | "output_type": "stream",
19 | "text": [
20 | "\u001b[33mYou are using pip version 19.0.3, however version 19.1.1 is available.\r\n",
21 | "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\r\n"
22 | ]
23 | }
24 | ],
25 | "source": [
26 | "!pip3 install -q --upgrade mdai"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": null,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "import mdai\n",
36 | "mdai.__version__"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "metadata": {},
42 | "source": [
43 | "### We need some variables.\n",
44 | "- DOMAIN is the base portion of the project URL, e.g. company.md.ai\n",
45 | "- ACCESS_TOKEN can be obtained from User Icon -> User Settings -> Personal Access Tokens\n",
46 | "- PROJECT_ID is shown via the info icon on the left of the Annotator"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "DOMAIN = 'Fill this in with the appropriate value'\n",
56 | "ACCESS_TOKEN = 'Fill this in with the appropriate value'\n",
57 | "PROJECT_ID = 'Fill this in with the appropriate value'"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "### Create the MDai Client"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "mdai_client = mdai.Client(domain=DOMAIN, access_token=ACCESS_TOKEN)"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "metadata": {},
79 | "source": [
80 | "### Download the annotations for the project"
81 | ]
82 | },
83 | {
84 | "cell_type": "code",
85 | "execution_count": null,
86 | "metadata": {},
87 | "outputs": [],
88 | "source": [
89 | "mdai_client.project(PROJECT_ID, annotations_only=True)"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": 2,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "import json\n",
99 | "import pandas as pd"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "metadata": {},
105 | "source": [
106 | "### Another variable - JSON_FILE\n",
107 | "- The project download prints out the filename of the JSON annotations file.\n",
108 | "- Insert that into the JSON_FILE variable"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 3,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "JSON_FILE = 'Fill this in with the appropriate value'\n",
118 | "with open(JSON_FILE, 'r') as f:\n",
119 | " data = json.load(f)"
120 | ]
121 | },
122 | {
123 | "cell_type": "code", 124 | "execution_count": 6, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/html": [ 130 | "
\n", 131 | "\n", 144 | "\n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | "
SOPInstanceUIDSeriesInstanceUIDStudyInstanceUIDannotationNumbercreatedAtcreatedByIddataheightidlabelId...reviewsPositiveCountupdatedAtupdatedByIdwidthdatasetlabelGroupIdlabelGroupNamelabelNameannotationModenumber
0NaNNaN2.16.840.1.114274.4504117462496947907456710035...None2018-12-03T16:43:14.886ZU_KymW3ENoneNoneA_a4dnoVL_2l2P0B...02018-12-03T16:43:14.894ZU_KymW3ENoneDatasetG_76xRlaGroup 2NeckNone29
1NaNNaN2.16.840.1.114274.4504117462496947907456710035...None2018-12-03T16:43:11.872ZU_KymW3ENoneNoneA_W497pVL_8logl7...02018-12-03T16:43:11.879ZU_KymW3ENoneDatasetG_76xRlaGroup 2HeadNone29
2NaNNaN2.16.840.1.114274.4504117462496947907456710035...None2018-12-04T17:24:46.380ZU_6y5LdVNoneNoneA_YNjoA4L_rj6Vjm...02018-12-04T17:24:46.468ZU_6y5LdVNoneDatasetG_gnD1lmGroup 3HeadNone29
3NaNNaN2.16.840.1.114274.6360863111461398062925621323...None2018-12-03T16:25:03.432ZU_KymW3ENoneNoneA_OrPZNkL_2l2P0B...02018-12-03T16:25:03.440ZU_KymW3ENoneDatasetG_76xRlaGroup 2NeckNone20
4NaNNaN2.16.840.1.114274.6360863111461398062925621323...None2018-12-04T17:23:55.043ZU_6y5LdVNoneNoneA_KVpjz4L_Mj51jN...02018-12-04T17:23:55.049ZU_6y5LdVNoneDatasetG_gnD1lmGroup 3NeckNone20
\n", 294 | "

5 rows × 24 columns

\n", 295 | "
" 296 | ], 297 | "text/plain": [ 298 | " SOPInstanceUID SeriesInstanceUID \\\n", 299 | "0 NaN NaN \n", 300 | "1 NaN NaN \n", 301 | "2 NaN NaN \n", 302 | "3 NaN NaN \n", 303 | "4 NaN NaN \n", 304 | "\n", 305 | " StudyInstanceUID annotationNumber \\\n", 306 | "0 2.16.840.1.114274.4504117462496947907456710035... None \n", 307 | "1 2.16.840.1.114274.4504117462496947907456710035... None \n", 308 | "2 2.16.840.1.114274.4504117462496947907456710035... None \n", 309 | "3 2.16.840.1.114274.6360863111461398062925621323... None \n", 310 | "4 2.16.840.1.114274.6360863111461398062925621323... None \n", 311 | "\n", 312 | " createdAt createdById data height id labelId \\\n", 313 | "0 2018-12-03T16:43:14.886Z U_KymW3E None None A_a4dnoV L_2l2P0B \n", 314 | "1 2018-12-03T16:43:11.872Z U_KymW3E None None A_W497pV L_8logl7 \n", 315 | "2 2018-12-04T17:24:46.380Z U_6y5LdV None None A_YNjoA4 L_rj6Vjm \n", 316 | "3 2018-12-03T16:25:03.432Z U_KymW3E None None A_OrPZNk L_2l2P0B \n", 317 | "4 2018-12-04T17:23:55.043Z U_6y5LdV None None A_KVpjz4 L_Mj51jN \n", 318 | "\n", 319 | " ... reviewsPositiveCount updatedAt updatedById width \\\n", 320 | "0 ... 0 2018-12-03T16:43:14.894Z U_KymW3E None \n", 321 | "1 ... 0 2018-12-03T16:43:11.879Z U_KymW3E None \n", 322 | "2 ... 0 2018-12-04T17:24:46.468Z U_6y5LdV None \n", 323 | "3 ... 0 2018-12-03T16:25:03.440Z U_KymW3E None \n", 324 | "4 ... 0 2018-12-04T17:23:55.049Z U_6y5LdV None \n", 325 | "\n", 326 | " dataset labelGroupId labelGroupName labelName annotationMode number \n", 327 | "0 Dataset G_76xRla Group 2 Neck None 29 \n", 328 | "1 Dataset G_76xRla Group 2 Head None 29 \n", 329 | "2 Dataset G_gnD1lm Group 3 Head None 29 \n", 330 | "3 Dataset G_76xRla Group 2 Neck None 20 \n", 331 | "4 Dataset G_gnD1lm Group 3 Neck None 20 \n", 332 | "\n", 333 | "[5 rows x 24 columns]" 334 | ] 335 | }, 336 | "execution_count": 6, 337 | "metadata": {}, 338 | "output_type": "execute_result" 339 | } 340 | ], 341 | "source": [ 342 | "def unpackDictionary(df, column):\n", 343 | " ret = None\n", 344 | " ret = pd.concat([df, pd.DataFrame((d for idx, d in df[column].items()))], axis=1)\n", 345 | " del ret[column]\n", 346 | " return ret\n", 347 | "\n", 348 | "a = pd.DataFrame([])\n", 349 | "studies = pd.DataFrame([])\n", 350 | "\n", 351 | "# Gets annotations for all datasets\n", 352 | "for d in data['datasets']:\n", 353 | " annotations = pd.DataFrame(d['annotations'])\n", 354 | " annotations['dataset'] = d['name']\n", 355 | " study = pd.DataFrame(d['studies'])\n", 356 | " study['dataset'] = d['name']\n", 357 | " a = a.append(annotations,ignore_index=True)\n", 358 | " studies = studies.append(study,ignore_index=True)\n", 359 | "\n", 360 | "studies = studies[['StudyInstanceUID', 'dataset', 'number']]\n", 361 | "g = pd.DataFrame(data['labelGroups'])\n", 362 | "\n", 363 | "#unpack arrays\n", 364 | "result = pd.DataFrame([(d, tup.id, tup.name) for tup in g.itertuples() for d in tup.labels])\n", 365 | "result.columns = ['labels','id','name']\n", 366 | "\n", 367 | "labelGroups = unpackDictionary(result, 'labels')\n", 368 | "labelGroups = labelGroups[['id','name','annotationMode']]\n", 369 | "labelGroups.columns = ['labelId','labelGroupId','labelGroupName','labelName','annotationMode']\n", 370 | "\n", 371 | "a = a.merge(labelGroups, on='labelId')\n", 372 | "a = a.merge(studies[['StudyInstanceUID', 'number']], on='StudyInstanceUID')\n", 373 | "a.head()" 374 | ] 375 | }, 376 | { 377 | "cell_type": "code", 378 | "execution_count": 7, 379 | "metadata": {}, 380 | "outputs": [ 381 | { 382 | "data": { 383 | 
"text/plain": [ 384 | "Index(['SOPInstanceUID', 'SeriesInstanceUID', 'StudyInstanceUID',\n", 385 | " 'annotationNumber', 'createdAt', 'createdById', 'data', 'height', 'id',\n", 386 | " 'labelId', 'modelId', 'note', 'radlexTagIds',\n", 387 | " 'reviewsNegativeCount', 'reviewsPositiveCount', 'updatedAt',\n", 388 | " 'updatedById', 'width', 'dataset', 'labelGroupId', 'labelGroupName', 'labelName',\n", 389 | " 'annotationMode', 'number'],\n", 390 | " dtype='object')" 391 | ] 392 | }, 393 | "execution_count": 7, 394 | "metadata": {}, 395 | "output_type": "execute_result" 396 | } 397 | ], 398 | "source": [ 399 | "a.columns" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | "# Create csv" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "a.to_csv(\"annotations.csv\", index=False)" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "metadata": {}, 421 | "source": [ 422 | "# Importing Annotations or Predictions" 423 | ] 424 | }, 425 | { 426 | "cell_type": "markdown", 427 | "metadata": {}, 428 | "source": [ 429 | "### Importing labels, we need some more variables\n", 430 | "- LABEL_ID - create a new label, show id using Label Controls, copy and fill in\n", 431 | "- DATASET_ID s shown via the info icon on the left of the Annotator\n", 432 | "- MODEL_ID - go to the Models tab on the left, create a new Model and name it or use the id for a prior model" 433 | ] 434 | }, 435 | { 436 | "cell_type": "code", 437 | "execution_count": null, 438 | "metadata": {}, 439 | "outputs": [], 440 | "source": [ 441 | "LABEL_ID = 'Fill this in with the appropriate value'\n", 442 | "DATASET_ID = 'Fill this in with the appropriate value'\n", 443 | "MODEL_ID = 'Fill this in with the appropriate value'" 444 | ] 445 | }, 446 | { 447 | "cell_type": "markdown", 448 | "metadata": {}, 449 | "source": [ 450 | "### Create subset" 451 | ] 452 | }, 453 | { 454 | "cell_type": "code", 455 | "execution_count": null, 456 | "metadata": {}, 457 | "outputs": [], 458 | "source": [ 459 | "#For example, get all exams with MLA\n", 460 | "subset = a[~a.modelId.isnull()]" 461 | ] 462 | }, 463 | { 464 | "cell_type": "markdown", 465 | "metadata": {}, 466 | "source": [ 467 | "### Create imported annotations dictionary\n", 468 | "- Use correct format for type and scope of annotation|" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [ 477 | "#For example, this is a global label at the exam level\n", 478 | "annotations = []\n", 479 | "for i,row in subset.iterrows():\n", 480 | " annotations.append( {\n", 481 | " 'labelId': LABEL_ID,\n", 482 | " 'StudyInstanceUID': row.StudyInstanceUID\n", 483 | " })\n", 484 | " \n", 485 | "len(annotations)" 486 | ] 487 | }, 488 | { 489 | "cell_type": "markdown", 490 | "metadata": {}, 491 | "source": [ 492 | "### Or use this for bounding boxes" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": null, 498 | "metadata": {}, 499 | "outputs": [], 500 | "source": [ 501 | "# local image-scoped label where annotation mode is 'bbox' (Bounding Box)\n", 502 | "annotations = []\n", 503 | "for i,row in subset.iterrows():\n", 504 | " annotations.append( {\n", 505 | " 'labelId': LABEL_ID,\n", 506 | " 'SOPInstanceUID': row.StudyInstanceUID,\n", 507 | " 'data': {'x': 200, 'y': 200, 'width': 200, 'height': 400}\n", 508 | " })\n", 509 | " \n", 510 | 
"len(annotations)" 511 | ] 512 | }, 513 | { 514 | "cell_type": "markdown", 515 | "metadata": {}, 516 | "source": [ 517 | "### Import your annotations\n", 518 | "- If all works, you should see the labels in your project Progress " 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": null, 524 | "metadata": {}, 525 | "outputs": [], 526 | "source": [ 527 | "#Import\n", 528 | "mdai_client.load_model_annotations(PROJECT_ID, DATASET_ID, MODEL_ID, annotations)" 529 | ] 530 | } 531 | ], 532 | "metadata": { 533 | "kernelspec": { 534 | "display_name": "Python 3", 535 | "language": "python", 536 | "name": "python3" 537 | }, 538 | "language_info": { 539 | "codemirror_mode": { 540 | "name": "ipython", 541 | "version": 3 542 | }, 543 | "file_extension": ".py", 544 | "mimetype": "text/x-python", 545 | "name": "python", 546 | "nbconvert_exporter": "python", 547 | "pygments_lexer": "ipython3", 548 | "version": "3.7.1" 549 | } 550 | }, 551 | "nbformat": 4, 552 | "nbformat_minor": 2 553 | } 554 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | 4 | [tool.ruff] 5 | line-length = 100 6 | # Enable Pyflakes and pycodestyle rules. 7 | select = ["E", "F"] 8 | # E501 is the "Line too long" error. We disable it because we use Black for 9 | # code formatting. Black makes a best effort to keep lines under the max 10 | # length, but can go over in some cases. 11 | # E203: Whitespace before ':'. Conflicts with black formatting. 12 | # E231: Missing whitespace after ',', ';', or ':'. Conflicts with black formatting. 13 | ignore = ["E501", "E203", "E231"] 14 | 15 | [tool.poetry] 16 | name = "mdai" 17 | version = "0.15.1" 18 | description = "MD.ai Python client library" 19 | license = "Apache-2.0" 20 | authors = ["MD.ai "] 21 | readme = "README.md" 22 | homepage = "https://github.com/mdai/mdai-client-py" 23 | repository = "https://github.com/mdai/mdai-client-py" 24 | documentation = "https://docs.md.ai/annotator/python/installation/" 25 | classifiers = [ 26 | "Intended Audience :: Developers", 27 | "Intended Audience :: Education", 28 | "Intended Audience :: Healthcare Industry", 29 | "Intended Audience :: Science/Research", 30 | "License :: OSI Approved :: Apache Software License", 31 | "Programming Language :: Python :: 3", 32 | "Programming Language :: Python :: 3.8", 33 | "Programming Language :: Python :: 3.9", 34 | "Programming Language :: Python :: 3.10", 35 | "Programming Language :: Python :: 3.11", 36 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 37 | "Topic :: Scientific/Engineering :: Medical Science Apps.", 38 | "Topic :: Software Development :: Libraries", 39 | "Topic :: Software Development :: Libraries :: Python Modules", 40 | ] 41 | 42 | [tool.poetry.dependencies] 43 | python = ">=3.8" 44 | arrow = "^1.3.0" 45 | matplotlib = "^3.7.3" 46 | nibabel = "^5.2.1" 47 | numpy = "^1.24.0" 48 | opencv-python = "^4.8.1.78" 49 | pandas = "^2.0.0" 50 | pillow = "^10.0.0" 51 | pydicom = "^2.4.0" 52 | requests = "^2.32.0" 53 | retrying = "^1.3.4" 54 | scikit-image = ">=0.21.0, <1.0.0" 55 | tqdm = "^4.66.5" 56 | dicom2nifti = "<2.6.0" 57 | PyYAML = "^6.0.2" 58 | 59 | [tool.poetry.group.dev.dependencies] 60 | black = "23.3.0" 61 | ruff = "0.0.272" 62 | pytest = "*" 63 | sphinx = "*" 64 | recommonmark = "*" 65 | 66 | [build-system] 67 | requires = ["poetry-core"] 68 | build-backend = "poetry.core.masonry.api" 69 
| -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdai/mdai-client-py/28ab35436f3db1fb056b1e8d3bc13d9ae4c5c555/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import shutil 4 | import requests 5 | import pytest 6 | 7 | from mdai.preprocess import Project 8 | 9 | FIXTURES_BASE_URL = "https://storage.googleapis.com/mdai-app-data/test-fixtures/mdai-client-py/" 10 | IMG_FILE = "mdai_staging_project_bwRnkNW2_images_2018-08-25-192424.zip" 11 | ANNO_FILE = "mdai_staging_project_bwRnkNW2_annotations_labelgroup_all_2018-08-25-204133.json" 12 | TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data") 13 | 14 | 15 | def download_file(url): 16 | local_filename = os.path.join(TEST_DATA_DIR, url.split("/")[-1]) 17 | r = requests.get(url, stream=True) 18 | if r.status_code == requests.codes.ok: 19 | with open(local_filename, "wb") as f: 20 | shutil.copyfileobj(r.raw, f) 21 | return local_filename 22 | else: 23 | r.raise_for_status() 24 | 25 | 26 | @pytest.fixture 27 | def p(): 28 | 29 | os.makedirs(TEST_DATA_DIR, exist_ok=True) 30 | annotations_fp = download_file(FIXTURES_BASE_URL + ANNO_FILE) 31 | images_dir_zipped = download_file(FIXTURES_BASE_URL + IMG_FILE) 32 | with zipfile.ZipFile(images_dir_zipped) as zf: 33 | zf.extractall(TEST_DATA_DIR) 34 | (images_dir, ext) = os.path.splitext(images_dir_zipped) 35 | 36 | p = Project(annotations_fp=annotations_fp, images_dir=images_dir) 37 | return p 38 | -------------------------------------------------------------------------------- /tests/test_preprocess.py: -------------------------------------------------------------------------------- 1 | from mdai.utils.common_utils import train_test_split 2 | 3 | 4 | def test_project(p): 5 | 6 | # label groups 7 | label_groups = p.get_label_groups() 8 | 9 | assert label_groups[0].id == "G_L3dP31" 10 | assert label_groups[1].id == "G_WVRrVJ" 11 | 12 | 13 | def test_dataset(p): 14 | 15 | # two datasets 16 | datasets = p.get_datasets() 17 | assert len(datasets) == 2 18 | 19 | labels_dict = { 20 | "L_egJRyg": 1, # bounding box 21 | "L_MgevP2": 2, # polygon 22 | "L_D21YL2": 3, # freeform 23 | "L_lg7klg": 4, # line 24 | "L_eg69RZ": 5, # location 25 | "L_GQoaJg": 6, # global_image 26 | "L_JQVWjZ": 7, # global_series 27 | "L_3QEOpg": 8, # global_exam 28 | } 29 | p.set_labels_dict(labels_dict) 30 | 31 | assert p.get_label_id_annotation_mode("L_MgevP2") == "polygon" 32 | assert p.get_label_id_annotation_mode("L_3QEOpg") is None 33 | 34 | ct_dataset = p.get_dataset_by_id("D_qGQdpN") 35 | ct_dataset.prepare() 36 | 37 | xray_dataset = p.get_dataset_by_id("D_0Z4nDG") 38 | xray_dataset.prepare() 39 | 40 | assert ct_dataset.classes_dict == xray_dataset.classes_dict 41 | 42 | image_ids = ct_dataset.get_image_ids() 43 | assert len(image_ids) == len(ct_dataset.imgs_anns_dict.keys()) 44 | 45 | image_id = ct_dataset.get_image_ids()[7] 46 | 47 | ann_mode = [ 48 | (ct_dataset.label_id_to_class_annotation_mode(ann["labelId"]), ann["labelId"]) 49 | for ann in ct_dataset.imgs_anns_dict[image_id] 50 | ] 51 | 52 | assert ann_mode == [ 53 | ("line", "L_lg7klg"), 54 | ("polygon", "L_MgevP2"), 55 | ("freeform", "L_D21YL2"), 56 | ("bbox", "L_egJRyg"), 57 | ("location", "L_eg69RZ"), 58 | (None, 
"L_3QEOpg"), 59 | ] 60 | 61 | train_ds, valid_ds = train_test_split(ct_dataset, shuffle=False, validation_split=0.2) 62 | 63 | assert len(train_ds.get_image_ids()) == 9 64 | assert len(valid_ds.get_image_ids()) == 3 65 | -------------------------------------------------------------------------------- /tests/test_visualize.py: -------------------------------------------------------------------------------- 1 | from mdai import visualize 2 | import numpy as np 3 | 4 | # TODO: test load_dicom_image (with RGB option or not) 5 | 6 | 7 | def test_visualize(p): 8 | labels_dict = { 9 | "L_egJRyg": 1, # bounding box 10 | "L_MgevP2": 2, # polygon 11 | "L_D21YL2": 3, # freeform 12 | "L_lg7klg": 4, # line 13 | "L_eg69RZ": 5, # location 14 | "L_GQoaJg": 6, # global_image 15 | "L_JQVWjZ": 7, # global_series 16 | "L_3QEOpg": 8, # global_exam 17 | } 18 | 19 | p.set_labels_dict(labels_dict) 20 | 21 | ct_dataset = p.get_dataset_by_id("D_qGQdpN") 22 | ct_dataset.prepare() 23 | 24 | # image with multiple annotations 25 | image_id = ct_dataset.get_image_ids()[7] 26 | 27 | grey_image = visualize.load_dicom_image(image_id) 28 | rgb_image = visualize.load_dicom_image(image_id, to_RGB=True) 29 | scaled_image_1 = visualize.load_dicom_image(image_id, rescale=True) 30 | 31 | assert np.amax(grey_image) == 701 32 | assert np.amax(scaled_image_1) == 255 33 | 34 | assert grey_image.shape == (256, 256) 35 | assert rgb_image.shape == (256, 256, 3) 36 | 37 | scaled_image_2, gt_class_id, gt_bbox, gt_mask = visualize.get_image_ground_truth( 38 | image_id, ct_dataset 39 | ) 40 | 41 | assert len(gt_class_id) == len(gt_class_id) 42 | assert gt_mask.shape == (256, 256, 5) 43 | --------------------------------------------------------------------------------