├── .github
│   └── workflows
│       └── test.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.rst
├── chat-assistant
│   ├── Dockerfile
│   ├── OPENAI_API_KEY.txt
│   ├── README.md
│   ├── code-corpus-generation.py
│   ├── corpus
│   │   ├── pyhealth-code.pkl
│   │   ├── pyhealth-code.txt
│   │   ├── pyhealth-text.pkl
│   │   ├── pyhealth-text.txt
│   │   └── pyhealth.txt
│   ├── ingest.py
│   ├── main.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── introduction_prompt.py
│   │   ├── qa_prompt.py
│   │   └── summary_prompt.py
│   ├── qa_chain.py
│   ├── query_data.py
│   ├── requirements.txt
│   └── text-corpus-generation.py
├── docs
│   ├── Makefile
│   ├── _static
│   │   ├── css
│   │   │   ├── overwrite.css
│   │   │   └── sphinx_gallery.css
│   │   └── pyhealth_logos
│   │       └── pyhealth-logo.png
│   ├── about.rst
│   ├── advance_tutorials.rst
│   ├── api
│   │   ├── calib.rst
│   │   ├── calib
│   │   │   ├── pyhealth.calib.calibration.rst
│   │   │   └── pyhealth.calib.predictionset.rst
│   │   ├── data.rst
│   │   ├── data
│   │   │   ├── pyhealth.data.Event.rst
│   │   │   └── pyhealth.data.Patient.rst
│   │   ├── datasets.rst
│   │   ├── datasets
│   │   │   ├── pyhealth.datasets.BaseDataset.rst
│   │   │   ├── pyhealth.datasets.CardiologyDataset.rst
│   │   │   ├── pyhealth.datasets.ISRUCDataset.rst
│   │   │   ├── pyhealth.datasets.MIMIC3Dataset.rst
│   │   │   ├── pyhealth.datasets.MIMIC4Dataset.rst
│   │   │   ├── pyhealth.datasets.MIMICExtractDataset.rst
│   │   │   ├── pyhealth.datasets.MedicalTranscriptionsDataset.rst
│   │   │   ├── pyhealth.datasets.OMOPDataset.rst
│   │   │   ├── pyhealth.datasets.SHHSDataset.rst
│   │   │   ├── pyhealth.datasets.SampleDataset.rst
│   │   │   ├── pyhealth.datasets.SleepEDFDataset.rst
│   │   │   ├── pyhealth.datasets.TUABDataset.rst
│   │   │   ├── pyhealth.datasets.TUEVDataset.rst
│   │   │   ├── pyhealth.datasets.eICUDataset.rst
│   │   │   ├── pyhealth.datasets.splitter.rst
│   │   │   └── pyhealth.datasets.utils.rst
│   │   ├── medcode.rst
│   │   ├── metrics.rst
│   │   ├── metrics
│   │   │   ├── pyhealth.metrics.binary.rst
│   │   │   ├── pyhealth.metrics.calibration.rst
│   │   │   ├── pyhealth.metrics.fairness.rst
│   │   │   ├── pyhealth.metrics.multiclass.rst
│   │   │   ├── pyhealth.metrics.multilabel.rst
│   │   │   └── pyhealth.metrics.prediction_set.rst
│   │   ├── models.rst
│   │   ├── models
│   │   │   ├── pyhealth.models.AdaCare.rst
│   │   │   ├── pyhealth.models.Agent.rst
│   │   │   ├── pyhealth.models.BaseModel.rst
│   │   │   ├── pyhealth.models.CNN.rst
│   │   │   ├── pyhealth.models.ConCare.rst
│   │   │   ├── pyhealth.models.ContraWR.rst
│   │   │   ├── pyhealth.models.Deepr.rst
│   │   │   ├── pyhealth.models.GAMENet.rst
│   │   │   ├── pyhealth.models.GAN.rst
│   │   │   ├── pyhealth.models.GNN.rst
│   │   │   ├── pyhealth.models.GRASP.rst
│   │   │   ├── pyhealth.models.MICRON.rst
│   │   │   ├── pyhealth.models.MLP.rst
│   │   │   ├── pyhealth.models.MedLink.rst
│   │   │   ├── pyhealth.models.MoleRec.rst
│   │   │   ├── pyhealth.models.RETAIN.rst
│   │   │   ├── pyhealth.models.RNN.rst
│   │   │   ├── pyhealth.models.SafeDrug.rst
│   │   │   ├── pyhealth.models.SparcNet.rst
│   │   │   ├── pyhealth.models.StageNet.rst
│   │   │   ├── pyhealth.models.TCN.rst
│   │   │   ├── pyhealth.models.Transformer.rst
│   │   │   ├── pyhealth.models.TransformersModel.rst
│   │   │   └── pyhealth.models.VAE.rst
│   │   ├── tasks.rst
│   │   ├── tasks
│   │   │   ├── pyhealth.tasks.BaseTask.rst
│   │   │   ├── pyhealth.tasks.COVID19CXRClassification.rst
│   │   │   ├── pyhealth.tasks.EEG_abnormal.rst
│   │   │   ├── pyhealth.tasks.EEG_events.rst
│   │   │   ├── pyhealth.tasks.InHospitalMortalityMIMIC4.rst
│   │   │   ├── pyhealth.tasks.MIMIC3ICD9Coding.rst
│   │   │   ├── pyhealth.tasks.MedicalTranscriptionsClassification.rst
│   │   │   ├── pyhealth.tasks.Readmission30DaysMIMIC4.rst
│   │   │   ├── pyhealth.tasks.SleepStagingSleepEDF.rst
│   │   │   ├── pyhealth.tasks.cardiology_detect.rst
│   │   │   ├── pyhealth.tasks.drug_recommendation.rst
│   │   │   ├── pyhealth.tasks.length_of_stay_prediction.rst
│   │   │   ├── pyhealth.tasks.mortality_prediction.rst
│   │   │   ├── pyhealth.tasks.patient_linkage_mimic3_fn.rst
│   │   │   ├── pyhealth.tasks.readmission_prediction.rst
│   │   │   ├── pyhealth.tasks.sleep_staging.rst
│   │   │   └── pyhealth.tasks.temple_university_EEG_tasks.rst
│   │   ├── tokenizer.rst
│   │   └── trainer.rst
│   ├── conf.py
│   ├── developer.rst
│   ├── examples.rst
│   ├── examples_bak.rst
│   ├── faq.rst
│   ├── how_to_get_started.rst
│   ├── index.rst
│   ├── install.rst
│   ├── live.rst
│   ├── log.rst
│   ├── make.bat
│   ├── requirements.txt
│   ├── tutorials.rst
│   └── use_cases.rst
├── examples
│   ├── ChestXray-image-generation-GAN.ipynb
│   ├── EEG_events_SparcNet.py
│   ├── EEG_isAbnormal_SparcNet.py
│   ├── benchmark_ehrshot_xgboost.ipynb
│   ├── cardiology_detection_isAR_SparcNet.py
│   ├── chextXray_image_generation_VAE.py
│   ├── drug_recommendation_eICU_transformer.py
│   ├── drug_recommendation_mimic3_gamenet.py
│   ├── drug_recommendation_mimic3_micron.py
│   ├── drug_recommendation_mimic3_molerec.py
│   ├── drug_recommendation_mimic3_safedrug.py
│   ├── drug_recommendation_mimic3_transformer.py
│   ├── drug_recommendation_mimic4_gamenet.py
│   ├── graph_torchvision_model.ipynb
│   ├── interpret_demo.ipynb
│   ├── kg_embedding.ipynb
│   ├── length_of_stay_mimic3_rnn.py
│   ├── lm_embedding_huggingface.ipynb
│   ├── lm_embedding_openai.ipynb
│   ├── medcode.py
│   ├── mortality_mimic3_adacare.py
│   ├── mortality_mimic3_agent.py
│   ├── mortality_mimic3_concare.py
│   ├── mortality_mimic3_grasp.py
│   ├── mortality_mimic3_rnn.py
│   ├── mortality_mimic3_stagenet.py
│   ├── mortality_mimic3_tcn.py
│   ├── patient_linkage_mimic3_medlink.py
│   ├── prepare_mapping.ipynb
│   ├── readmission_mimic3_fairness.py
│   ├── readmission_mimic3_rnn.py
│   ├── sleep_staging_ISRUC_SparcNet.py
│   ├── sleep_staging_shhs_contrawr.py
│   ├── sleep_staging_sleepEDF_contrawr.py
│   ├── test.py
│   ├── test_dataloader.sh
│   └── test_eICU_addition.py
├── figure
│   ├── five-stage-pipeline.png
│   ├── overview.png
│   ├── poster.png
│   └── structured-dataset.png
├── hackthon
│   └── 202410-sunlab-hackthon.md
├── leaderboard
│   ├── README.rst
│   ├── __init__.py
│   ├── ext_plot.py
│   ├── leaderboard_gen.py
│   ├── rtd_build_trigger.py
│   └── utils.py
├── makefile
├── pyhealth
│   ├── __init__.py
│   ├── calib
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── base_classes.py
│   │   ├── calibration
│   │   │   ├── __init__.py
│   │   │   ├── dircal.py
│   │   │   ├── hb.py
│   │   │   ├── kcal
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bw.py
│   │   │   │   ├── embed_data.py
│   │   │   │   └── kde.py
│   │   │   └── temperature_scale.py
│   │   ├── predictionset
│   │   │   ├── __init__.py
│   │   │   ├── favmac
│   │   │   │   ├── __init__.py
│   │   │   │   ├── core.py
│   │   │   │   └── quantiletree.py
│   │   │   ├── label.py
│   │   │   └── scrib
│   │   │       ├── __init__.py
│   │   │       ├── quicksearch.py
│   │   │       └── quicksearch_cython.pyx
│   │   └── utils.py
│   ├── data
│   │   ├── __init__.py
│   │   └── data.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   ├── cardiology.py
│   │   ├── configs
│   │   │   ├── __init__.py
│   │   │   ├── config.py
│   │   │   ├── covid19_cxr.yaml
│   │   │   ├── ehrshot.yaml
│   │   │   ├── medical_transcriptions.yaml
│   │   │   ├── mimic3.yaml
│   │   │   ├── mimic4_cxr.yaml
│   │   │   ├── mimic4_ehr.yaml
│   │   │   └── mimic4_note.yaml
│   │   ├── covid19_cxr.py
│   │   ├── ehrshot.py
│   │   ├── eicu.py
│   │   ├── featurizers
│   │   │   ├── signal.py
│   │   │   └── text.py
│   │   ├── isruc.py
│   │   ├── medical_transcriptions.py
│   │   ├── mimic3.py
│   │   ├── mimic4.py
│   │   ├── mimicextract.py
│   │   ├── omop.py
│   │   ├── sample_dataset.py
│   │   ├── shhs.py
│   │   ├── sleepedf.py
│   │   ├── splitter.py
│   │   ├── tuab.py
│   │   ├── tuev.py
│   │   └── utils.py
│   ├── interpret
│   │   ├── __init__.py
│   │   └── methods
│   │       ├── __init__.py
│   │       └── chefer.py
│   ├── medcode
│   │   ├── __init__.py
│   │   ├── codes
│   │   │   ├── __init__.py
│   │   │   ├── atc.py
│   │   │   ├── ccscm.py
│   │   │   ├── ccsproc.py
│   │   │   ├── icd10cm.py
│   │   │   ├── icd10proc.py
│   │   │   ├── icd9cm.py
│   │   │   ├── icd9proc.py
│   │   │   ├── ndc.py
│   │   │   ├── rxnorm.py
│   │   │   └── umls.py
│   │   ├── cross_map.py
│   │   ├── inner_map.py
│   │   ├── pretrained_embeddings
│   │   │   ├── __init__.py
│   │   │   ├── kg_emb
│   │   │   │   ├── __init__.py
│   │   │   │   ├── datasets
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── base_kg_dataset.py
│   │   │   │   │   ├── sample_kg_dataset.py
│   │   │   │   │   ├── splitter.py
│   │   │   │   │   └── umls.py
│   │   │   │   ├── examples
│   │   │   │   │   └── train_kge_model.py
│   │   │   │   ├── models
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── complex.py
│   │   │   │   │   ├── distmult.py
│   │   │   │   │   ├── kg_base.py
│   │   │   │   │   ├── rotate.py
│   │   │   │   │   └── transe.py
│   │   │   │   └── tasks
│   │   │   │       ├── __init__.py
│   │   │   │       └── link_prediction.py
│   │   │   └── lm_emb
│   │   │       ├── __init__.py
│   │   │       ├── huggingface_retriever.py
│   │   │       └── openai_retriever.py
│   │   └── utils.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── binary.py
│   │   ├── calibration.py
│   │   ├── drug_recommendation.py
│   │   ├── fairness.py
│   │   ├── fairness_utils
│   │   │   ├── __init__.py
│   │   │   ├── group.py
│   │   │   └── utils.py
│   │   ├── multiclass.py
│   │   ├── multilabel.py
│   │   ├── prediction_set.py
│   │   ├── ranking.py
│   │   └── regression.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── adacare.py
│   │   ├── agent.py
│   │   ├── base_model.py
│   │   ├── cnn.py
│   │   ├── concare.py
│   │   ├── contrawr.py
│   │   ├── deepr.py
│   │   ├── embedding.py
│   │   ├── gamenet.py
│   │   ├── gan.py
│   │   ├── gnn.py
│   │   ├── graph_torchvision_model.py
│   │   ├── grasp.py
│   │   ├── medlink
│   │   │   ├── __init__.py
│   │   │   ├── bm25.py
│   │   │   ├── model.py
│   │   │   └── utils.py
│   │   ├── micron.py
│   │   ├── mlp.py
│   │   ├── molerec.py
│   │   ├── retain.py
│   │   ├── rnn.py
│   │   ├── safedrug.py
│   │   ├── sparcnet.py
│   │   ├── stagenet.py
│   │   ├── tcn.py
│   │   ├── torchvision_model.py
│   │   ├── transformer.py
│   │   ├── transformers_model.py
│   │   ├── utils.py
│   │   └── vae.py
│   ├── nlp
│   │   └── metrics.py
│   ├── processors
│   │   ├── __init__.py
│   │   ├── base_processor.py
│   │   ├── image_processor.py
│   │   ├── label_processor.py
│   │   ├── raw_processor.py
│   │   ├── sequence_processor.py
│   │   ├── signal_processor.py
│   │   ├── text_processor.py
│   │   └── timeseries_processor.py
│   ├── sampler
│   │   ├── __init__.py
│   │   └── sage_sampler.py
│   ├── tasks
│   │   ├── EEG_abnormal.py
│   │   ├── EEG_events.py
│   │   ├── __init__.py
│   │   ├── base_task.py
│   │   ├── benchmark_ehrshot.py
│   │   ├── cardiology_detect.py
│   │   ├── covid19_cxr_classification.py
│   │   ├── drug_recommendation.py
│   │   ├── in_hospital_mortality_mimic4.py
│   │   ├── length_of_stay_prediction.py
│   │   ├── medical_coding.py
│   │   ├── medical_transcriptions_classification.py
│   │   ├── mortality_prediction.py
│   │   ├── patient_linkage.py
│   │   ├── readmission_30days_mimic4.py
│   │   ├── readmission_prediction.py
│   │   ├── sleep_staging.py
│   │   ├── sleep_staging_v2.py
│   │   └── temple_university_EEG_tasks.py
│   ├── tokenizer.py
│   ├── trainer.py
│   ├── unittests
│   │   ├── test.ipynb
│   │   ├── test.py
│   │   ├── test_data
│   │   │   └── test_data.py
│   │   ├── test_datasets
│   │   │   ├── test_eicu.py
│   │   │   ├── test_mimic3.py
│   │   │   ├── test_mimic4.py
│   │   │   ├── test_omop.py
│   │   │   └── utils.py
│   │   ├── test_medcode.py
│   │   ├── test_mortality_prediction.py
│   │   └── test_tokenizer.py
│   └── utils.py
├── readthedocs.yml
├── requirements-nlp.txt
├── requirements.txt
├── setup.cfg
├── setup.py
├── test-resources
│   └── nlp
│       └── metrics.csv
└── tests
    ├── base.py
    └── nlp
        ├── __init__.py
        └── test_metrics.py
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | pull_request:
5 | push:
6 | paths-ignore:
7 | - '**.md'
8 |       - 'docs/**'
9 | branches:
10 | - master
11 |
12 | jobs:
13 | build:
14 | runs-on: ubuntu-latest
15 | strategy:
16 | matrix:
17 | python-version: ['3.11']
18 | steps:
19 |       - name: Checkout repository
20 | uses: actions/checkout@v3
21 | with:
22 | submodules: recursive
23 |
24 | - name: Set up Python
25 | uses: actions/setup-python@v3
26 | with:
27 | python-version: ${{ matrix.python-version }}
28 |
29 | - name: Display Python version
30 | run: python -c "import sys; print(sys.version)"
31 |
32 | - name: Install dependent packages
33 | run: 'make deps'
34 |
35 | - name: Run tests
36 | run: 'make test'
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *DS_Store*
3 | *output*
4 | resource*
5 |
6 | debug_entry*
7 | playground.ipynb
8 |
9 | # Byte-compiled / optimized / DLL files
10 | __pycache__/
11 | *.py[cod]
12 | *$py.class
13 | .pytest_cache
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # celery beat schedule file
88 | celerybeat-schedule
89 |
90 | # SageMath parsed files
91 | *.sage.py
92 |
93 | # Environments
94 | .env
95 | .venv
96 | env/
97 | venv/
98 | ENV/
99 | env.bak/
100 | venv.bak/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 |
115 | # Sphinx documentation
116 | docs/build/
117 | # /docs/api/data/
118 | # /docs/api/datasets/
119 | # /docs/api/evaluator/
120 | # /docs/api/models/
121 | # /docs/api/tasks/
122 |
123 | #docs/examples.rst
124 | #docs/examples_bak.rst
125 | #docs/faq.rst
126 | #docs/usecase.rst
127 |
128 | # leaderboard credentials
129 | leaderboard/credentials.json
130 | leaderboard/rtd_token.txt
131 |
132 | # locally pre-trained models
133 | pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 SunLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/chat-assistant/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use an official Python runtime as a parent image
2 | FROM python:3.10
3 |
4 | # Set the working directory in the container to /app
5 | WORKDIR /app
6 |
7 | # Copy the current directory contents into the container at /app
8 | COPY . /app
9 |
10 | # Install any needed packages specified in requirements.txt
11 | RUN pip install --upgrade pip
12 | RUN pip install --no-cache-dir -r requirements.txt
13 |
14 | # Make port 7861 available to the world outside this container
15 | EXPOSE 7861
16 |
17 |
18 | # Define environment variable
19 | ENV LOG_PATH=/app/logs/
20 | ENV CORPUS_PATH=/app/corpus/
21 | ENV GRADIO_SERVER_NAME=0.0.0.0
22 | ENV GRADIO_SERVER_PORT=7861
23 |
24 | # put openai api key here
25 | ENV OPENAI_API_KEY=sk-xxxx
26 |
27 | # Run main.py when the container launches
28 | CMD ["python", "main.py"]
29 |
--------------------------------------------------------------------------------
/chat-assistant/OPENAI_API_KEY.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/chat-assistant/OPENAI_API_KEY.txt
--------------------------------------------------------------------------------
/chat-assistant/README.md:
--------------------------------------------------------------------------------
1 | ## LLM for PyHealth
2 | The current LLM interface for PyHealth is deployed here: http://35.208.88.194:7861/.
3 |
4 | ### Step 1:
5 | Merge all PyHealth-related information (code and documentation text) into `pyhealth.txt`.
6 | - Currently, we use https://github.com/mpoon/gpt-repository-loader
7 |
8 | ### Step 2:
9 | Run `ingest.py` to transform `pyhealth.txt` into a FAISS vector database:
10 | ```bash
11 | python ingest.py
12 | ```
13 |
14 | ### Step 3:
15 | Run the retrieval-augmented generation (RAG) app for Q&A over the `pyhealth.txt` document:
16 | ```bash
17 | python main.py
18 | ```
19 |
20 |
21 | ### Launch in Docker
22 |
23 | 1. Modify environment variables (OPENAI_API_KEY, server address, ...) in `Dockerfile`.
24 | 2. Build image by `docker build -t chat-pyhealth .`.
25 | 3. Run a container in the foreground (for debugging) by `docker run -p 0.0.0.0:7861:7861 --name chat-pyhealth-c -v ./logs/:/app/logs/ chat-pyhealth`.
26 | 4. Run a container in detached mode by `docker run -d -p 0.0.0.0:7861:7861 --name chat-pyhealth-c -v ./logs/:/app/logs/ chat-pyhealth`.
27 |
28 | ```shell
29 | ## run a container
30 | docker run -d -p [host address and port]:[container port] --name [name] -v [host path]:[container path] [image]
31 |
32 | # -d: detached
33 | # -p: port
34 | # --name: container name
35 | # -v: mount directory of container to local host path
36 |
37 | ## check
38 | docker ps
39 | docker images
40 |
41 | ## remove
42 | docker stop / restart chat-pyhealth-c # container
43 | docker rm chat-pyhealth-c # container
44 | docker rmi chat-pyhealth # image
45 |
46 | ## modify directly
47 | docker cp [local file in host] chat-pyhealth-c:[container path]
48 | ```
49 |
50 |
51 | **Let me know if you want to join and help us improve the current interface.**
52 |
--------------------------------------------------------------------------------
/chat-assistant/code-corpus-generation.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | def merge_py_files(directory, output_file):
4 | with open(output_file, 'a') as outfile:
5 | for root, dirs, files in os.walk(directory):
6 | for filename in files:
7 | if filename.endswith(".py"):
8 | with open(os.path.join(root, filename), 'r') as infile:
9 | outfile.write("Here is the code content for " + filename + ":\n")
10 | outfile.write(infile.read())
11 | outfile.write("\n\n")
12 |
13 | # --------- code ------------
14 | with open('pyhealth-code.txt', 'w') as f:
15 |     f.write("The code content for pyhealth\n")
16 |
17 | merge_py_files('../pyhealth', 'pyhealth-code.txt')
18 | merge_py_files('../examples', 'pyhealth-code.txt')
19 | merge_py_files('/home/chaoqiy2/pyhealth-tutorial', 'pyhealth-code.txt')
--------------------------------------------------------------------------------
/chat-assistant/corpus/pyhealth-code.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/chat-assistant/corpus/pyhealth-code.pkl
--------------------------------------------------------------------------------
/chat-assistant/corpus/pyhealth-text.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/chat-assistant/corpus/pyhealth-text.pkl
--------------------------------------------------------------------------------
/chat-assistant/ingest.py:
--------------------------------------------------------------------------------
1 | from langchain.text_splitter import CharacterTextSplitter
2 | from langchain.document_loaders import UnstructuredFileLoader
3 | from langchain.vectorstores.faiss import FAISS
4 | from langchain.embeddings import OpenAIEmbeddings
5 | import pickle
6 | import os
7 | from env import OPENAI_API_KEY
8 |
9 | os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
10 |
11 |
12 | def vectorize_corpus(corpus_path):
13 | print(f"Vectorizing {corpus_path}...")
14 | # print(f"Loading data {corpus_path}...")
15 | loader = UnstructuredFileLoader(corpus_path)
16 | raw_documents = loader.load()
17 |
18 |
19 | # print("Splitting text...")
20 | text_splitter = CharacterTextSplitter(
21 | separator="\n\n",
22 | chunk_size=600,
23 | chunk_overlap=100,
24 | length_function=len,
25 | )
26 | documents = text_splitter.split_documents(raw_documents)
27 |
28 | # print("Creating vectorstore...")
29 | embeddings = OpenAIEmbeddings()
30 | vectorstore = FAISS.from_documents(documents, embeddings)
31 |
32 | base_name, ext = os.path.splitext(corpus_path)
33 | vectorstore_path = base_name + '.pkl'
34 | with open(vectorstore_path, "wb") as f:
35 | pickle.dump(vectorstore, f)
36 |
37 |
38 | vectorize_corpus('corpus/pyhealth-text.txt')
39 | vectorize_corpus('corpus/pyhealth-code.txt')
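40 |
41 | # A minimal sketch (illustrative only, not executed by this script) of how the
42 | # pickled vectorstore produced above can be loaded back and queried later:
43 | #
44 | #   import pickle
45 | #   with open("corpus/pyhealth-text.pkl", "rb") as f:
46 | #       vectorstore = pickle.load(f)
47 | #   docs = vectorstore.similarity_search("How do I load MIMIC-III?", k=4)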
--------------------------------------------------------------------------------
/chat-assistant/prompts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/chat-assistant/prompts/__init__.py
--------------------------------------------------------------------------------
/chat-assistant/prompts/introduction_prompt.py:
--------------------------------------------------------------------------------
1 | USER_INTRO = '''I am here for some help with PyHealth.'''
2 |
3 |
4 |
5 | AI_INTRO = '''Hi, I am a chat assistant for PyHealth. I am here to answer all kinds of questions and solve problems you run into when using PyHealth. My specific features include:
6 |
7 | 1. I can give you a brief introduction to PyHealth and its functions
8 | 2. I can provide tutorials and explain how to use different modules
9 | 3. I can even provide sample code for specific tasks using PyHealth
10 | 4. I can provide examples on mapping functions between different medical coding systems
11 | 5. I can tell you a little bit about our development team
12 |
13 | How can I assist you with PyHealth today?'''
--------------------------------------------------------------------------------
/chat-assistant/prompts/qa_prompt.py:
--------------------------------------------------------------------------------
1 | QA_PROMPT_TEMPLATE = '''You are a very warm and helpful PyHealth usage assistant for users who want to use PyHealth or are already using PyHealth. You need to answer questions related to PyHealth.
2 |
3 | PyHealth is a comprehensive deep learning toolkit for supporting clinical predictive modeling, which is designed for both ML researchers and medical practitioners. We can make your healthcare AI applications easier to deploy and more flexible and customizable.
4 |
5 | You need to refer to the Reference Documents and the Source Code to get information about PyHealth and improve your answer quality.
6 | The Reference Documents are the PyHealth documentation, which contains information on how to use it, sample code, etc.
7 | The Source Code is the source code of PyHealth, which contains code implementation details.
8 |
9 | When users describe a use case for PyHealth or want to write code using PyHealth, you should give a paragraph of sample code with some modifications and comments based on the user's needs.
10 |
11 |
12 | ##### Reference Documents START #####
13 | {ref_doc}
14 | ###### Reference Documents END ######
15 |
16 | ##### Source Code START ##### (All the code you generate should be based on this)
17 | {source_code}
18 | ###### Source Code END ######
19 |
20 | ###### Chat History START ###### (This is the chat history between you and the user.)
21 | {chat_history}
22 | ###### Chat History END ######
23 |
24 |
25 | Notice: You should not express any subjective opinion.
26 |
27 | When you generate PyHealth code for users, you must ensure all code is based on the source code or the sample code in the reference documents.
28 | The code you generate must be executable and cannot contain errors, like calling functions that are undefined in PyHealth.
29 |
30 | You cannot answer any questions that are not related to PyHealth!
31 | You cannot answer any questions that are not related to PyHealth!
32 | You cannot answer any questions that are not related to PyHealth!
33 |
34 |
35 | ##### Current Conversation #####
36 | User: {human_input}
37 | AI: (give your response in Markdown format.)'''
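38 |
39 | # A minimal illustration of how this template is filled in; the variable names
40 | # below are hypothetical, for demonstration only:
41 | #
42 | #   prompt = QA_PROMPT_TEMPLATE.format(
43 | #       ref_doc=retrieved_docs,      # chunks retrieved from the text corpus
44 | #       source_code=retrieved_code,  # chunks retrieved from the code corpus
45 | #       chat_history=summary,        # running summary of the conversation
46 | #       human_input=user_question,
47 | #   )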
--------------------------------------------------------------------------------
/chat-assistant/prompts/summary_prompt.py:
--------------------------------------------------------------------------------
1 | SUMMARY_PROMPT_TEMPLATE = '''You are an AI who is responsible for summarizing chat history between one another AI and the user.
2 |
3 | I will provide you with the summary of the previous chat history and the new message in the current round of conversation. You then need to generate a new summary based on this information.
4 |
5 | Summary of Previous Chat History: {previous_summary}
6 | New Message: {new_message}
7 |
8 | The length of your new summary is limited to {summary_token_limitation} tokens.
9 |
10 | Now, give the new summary of chat history:'''
--------------------------------------------------------------------------------
/chat-assistant/requirements.txt:
--------------------------------------------------------------------------------
1 | langchain==0.1.11
2 | openai
3 | faiss-cpu
4 | unstructured
5 | tiktoken
6 | rich  # for console formatting
7 | gradio==4.19.2
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/css/overwrite.css:
--------------------------------------------------------------------------------
1 | /*
2 | Furo CSS variables
3 | https://github.com/pradyunsg/furo/blob/main/src/furo/assets/styles/variables/_index.scss
4 | https://github.com/pradyunsg/furo/blob/main/src/furo/theme/partials/_head_css_variables.html
5 | https://github.com/streamlink/streamlink/blob/17a4088c38709123c0bcab4a150549bd16d19e07/docs/_static/styles/custom.css
6 | */
7 |
8 | /* for the sphinx design cards */
9 | body {
10 | --sd-color-shadow: dimgrey;
11 | }
12 |
13 | dt:target,
14 | span.highlighted {
15 | background-color: #f0f0f0;
16 | }
17 |
18 | dl.citation > dt {
19 | float: left;
20 | margin-right: 15px;
21 | font-weight: bold;
22 | }
23 |
24 | /* Parameters: normalize size and capitalization */
25 | dl .field-list dt {
26 | font-size: var(--font-size--normal) !important;
27 | text-transform: none !important;
28 | }
29 |
30 | /* examples and headings in classes */
31 | p.rubric {
32 | font-size: var(--font-size--normal);
33 | text-transform: none;
34 | font-weight: 500;
35 | }
36 |
37 | /* adapted from https://github.com/dask/dask-sphinx-theme/blob/main/dask_sphinx_theme/static/css/nbsphinx.css */
38 |
39 | .nbinput .prompt,
40 | .nboutput .prompt {
41 | display: none;
42 | }
43 | .nboutput .stderr {
44 | display: none;
45 | }
46 |
47 | div.nblast.container {
48 | padding-bottom: 10px !important;
49 | padding-right: 0px;
50 | padding-left: 0px;
51 | }
52 |
53 | div.nbinput.container {
54 | padding-top: 10px !important;
55 | padding-right: 0px;
56 | padding-left: 0px;
57 | }
58 |
59 | div.nbinput.container div.input_area div[class*="highlight"] > pre {
60 | padding: 10px !important;
61 | margin: 0;
62 | }
63 |
64 | p.topic-title {
65 | margin-top: 0;
66 | }
67 |
68 | /* so that api methods are small in sidebar */
69 | li.toctree-l3 {
70 | font-size: 81.25% !important;
71 | }
72 | li.toctree-l4 {
73 | font-size: 75% !important;
74 | }
75 |
--------------------------------------------------------------------------------
/docs/_static/css/sphinx_gallery.css:
--------------------------------------------------------------------------------
1 | .sphx-glr-thumbcontainer {
2 | background: inherit !important;
3 | min-height: 250px !important;
4 | margin: 10px !important;
5 | }
6 |
7 | .sphx-glr-thumbcontainer .headerlink {
8 | display: none !important;
9 | }
10 |
11 | div.sphx-glr-thumbcontainer span {
12 | font-style: normal !important;
13 | }
14 |
15 | .sphx-glr-thumbcontainer a.internal {
16 | padding: 140px 10px 0 !important;
17 | }
18 |
19 | .sphx-glr-thumbcontainer .figure {
20 | width: 200px !important;
21 | }
22 |
23 | .sphx-glr-thumbcontainer .figure.align-center {
24 | text-align: center;
25 | margin-left: 0%;
26 | transform: translate(0%);
27 | }
28 |
--------------------------------------------------------------------------------
/docs/_static/pyhealth_logos/pyhealth-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/docs/_static/pyhealth_logos/pyhealth-logo.png
--------------------------------------------------------------------------------
/docs/about.rst:
--------------------------------------------------------------------------------
1 | About us
2 | ========
3 |
4 | We are the `SunLab `_ healthcare research team at UIUC.
5 |
6 | Current Maintainers
7 | -------------------
8 |
9 | `Zhenbang Wu `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
10 |
11 | `John Wu `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
12 |
13 | `Junyi Gao `_ (M.S. @ UIUC, Ph.D. Student @ University of Edinburgh)
14 |
15 | `Jimeng Sun `_ (Professor @ University of Illinois Urbana-Champaign)
16 |
17 | Alumni
18 | ------
19 |
20 | `Chaoqi Yang `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
21 |
22 | `Zhen Lin `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
23 |
24 | `Benjamin Danek `_ (M.S. Student @ University of Illinois Urbana-Champaign)
25 |
26 | `Patrick Jiang `_ (M.S. Student @ University of Illinois Urbana-Champaign)
--------------------------------------------------------------------------------
/docs/advance_tutorials.rst:
--------------------------------------------------------------------------------
1 | Advanced Tutorials
2 | ========================
3 |
4 | We provide advanced tutorials to support various needs.
5 |
6 | `Advanced Tutorial 1: Fit your dataset into our pipeline `_ `[Video] `_
7 |
8 | `Advanced Tutorial 2: Define your own healthcare task `_
9 |
10 | `Advanced Tutorial 3: Adopt customized model into pyhealth `_ `[Video] `_
11 |
12 | `Advanced Tutorial 4: Load your own processed data into pyhealth and try out our ML models `_ `[Video] `_
13 |
14 | ----------
15 |
--------------------------------------------------------------------------------
/docs/api/calib.rst:
--------------------------------------------------------------------------------
1 | Calibration and Uncertainty Quantification
2 | ==========================================
3 |
4 | In this module, we implement the following prediction set
5 | constructors and model calibration methods, which can be
6 | combined with any PyHealth model.
7 |
8 | .. toctree::
9 | :maxdepth: 3
10 |
11 | calib/pyhealth.calib.calibration
12 |
13 | calib/pyhealth.calib.predictionset
14 |
15 |
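16 | A minimal usage sketch (assuming a trained PyHealth ``model`` and a held-out
17 | calibration split ``cal_data``; see the classes below for the exact
18 | constructor arguments):
19 |
20 | .. code-block:: python
21 |
22 |     from pyhealth.calib.calibration import TemperatureScaling
23 |
24 |     # wrap a trained model and fit the temperature on the calibration set
25 |     cal_model = TemperatureScaling(model)
26 |     cal_model.calibrate(cal_dataset=cal_data)
27 |     # cal_model can then be evaluated like any other PyHealth model
28 |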
--------------------------------------------------------------------------------
/docs/api/calib/pyhealth.calib.calibration.rst:
--------------------------------------------------------------------------------
1 | pyhealth.calib.calibration
2 | ===================================
3 |
4 |
5 | .. automodule:: pyhealth.calib.calibration
6 | :members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/api/calib/pyhealth.calib.predictionset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.calib.predictionset
2 | ===================================
3 |
4 |
5 | .. automodule:: pyhealth.calib.predictionset
6 | :members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/api/data.rst:
--------------------------------------------------------------------------------
1 | Data
2 | ===============
3 |
4 | **pyhealth.data** defines the atomic data structures of this package.
5 |
6 | .. toctree::
7 | :maxdepth: 3
8 |
9 | data/pyhealth.data.Event
10 | data/pyhealth.data.Patient
11 |
12 |
--------------------------------------------------------------------------------
/docs/api/data/pyhealth.data.Event.rst:
--------------------------------------------------------------------------------
1 | pyhealth.data.Event
2 | =========================
3 |
4 | One basic data structure in the package. It is a simple container for a single event.
5 | It contains all necessary attributes for supporting various healthcare tasks.
6 |
7 | .. autoclass:: pyhealth.data.Event
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
--------------------------------------------------------------------------------
/docs/api/data/pyhealth.data.Patient.rst:
--------------------------------------------------------------------------------
1 | pyhealth.data.Patient
2 | =========================
3 |
4 | Another basic data structure in the package. A Patient is a collection of Visits for
5 | a single patient. It contains all necessary attributes of a patient, such as
6 | ethnicity, mortality status, gender, etc. It can support various healthcare tasks.
7 |
8 | .. autoclass:: pyhealth.data.Patient
9 | :members:
10 | :undoc-members:
11 | :show-inheritance:
12 |
--------------------------------------------------------------------------------
/docs/api/datasets.rst:
--------------------------------------------------------------------------------
1 | Datasets
2 | ===============
3 |
4 | .. toctree::
5 | :maxdepth: 3
6 |
7 | datasets/pyhealth.datasets.BaseDataset
8 | datasets/pyhealth.datasets.SampleDataset
9 | datasets/pyhealth.datasets.MIMIC3Dataset
10 | datasets/pyhealth.datasets.MIMIC4Dataset
11 | datasets/pyhealth.datasets.MedicalTranscriptionsDataset
12 | datasets/pyhealth.datasets.CardiologyDataset
13 | datasets/pyhealth.datasets.eICUDataset
14 | datasets/pyhealth.datasets.ISRUCDataset
15 | datasets/pyhealth.datasets.MIMICExtractDataset
16 | datasets/pyhealth.datasets.OMOPDataset
17 | datasets/pyhealth.datasets.SHHSDataset
18 | datasets/pyhealth.datasets.SleepEDFDataset
19 | datasets/pyhealth.datasets.TUABDataset
20 | datasets/pyhealth.datasets.TUEVDataset
21 | datasets/pyhealth.datasets.splitter
22 | datasets/pyhealth.datasets.utils
23 |
24 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.BaseDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.BaseDataset
2 | ===================================
3 |
4 | This is the base dataset class. All specific datasets inherit from this class.
5 |
6 | .. autoclass:: pyhealth.datasets.BaseDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.CardiologyDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.CardiologyDataset
2 | ===================================
3 |
4 | The Cardiology dataset includes six portions: "cpsc_2018", "cpsc_2018_extra", "georgia", "ptb", "ptb-xl", and "st_petersburg_incart"; refer to `doc `_ for more information.
5 |
6 | .. autoclass:: pyhealth.datasets.CardiologyDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.ISRUCDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.ISRUCDataset
2 | ===================================
3 |
4 | The open ISRUC EEG database; refer to `doc `_ for more information.
5 |
6 | .. autoclass:: pyhealth.datasets.ISRUCDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.MIMIC3Dataset
2 | ===================================
3 |
4 | The open Medical Information Mart for Intensive Care (MIMIC-III) database; refer to `doc `_ for more information. We process this database into a well-structured dataset object and give users the **best flexibility and convenience** for supporting modeling and analysis.
5 |
6 | .. autoclass:: pyhealth.datasets.MIMIC3Dataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
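16 | A minimal loading sketch (the ``root``, ``tables``, and ``code_mapping``
17 | values below are illustrative; point ``root`` at your local MIMIC-III files):
18 |
19 | .. code-block:: python
20 |
21 |     from pyhealth.datasets import MIMIC3Dataset
22 |
23 |     dataset = MIMIC3Dataset(
24 |         root="/path/to/mimic-iii/1.4",
25 |         tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
26 |         code_mapping={"NDC": "ATC"},  # optional on-the-fly code conversion
27 |     )
28 |     dataset.stat()
29 |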
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.MIMIC4Dataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.MIMIC4Dataset
2 | ===================================
3 |
4 | The open Medical Information Mart for Intensive Care (MIMIC-IV) database; refer to `doc `_ for more information. We process this database into a well-structured dataset object and give users the **best flexibility and convenience** for supporting modeling and analysis.
5 |
6 | .. autoclass:: pyhealth.datasets.MIMIC4Dataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.MIMICExtractDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.MIMICExtractDataset
2 | ======================================
3 |
4 | The open Medical Information Mart for Intensive Care (MIMIC-III) database, as preprocessed by the MIMIC-Extract pipeline; refer to `doc `_ for more information. We process this database into a well-structured dataset object and give users the **best flexibility and convenience** for supporting modeling and analysis.
5 |
6 | .. autoclass:: pyhealth.datasets.MIMICExtractDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.MedicalTranscriptionsDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.MedicalTranscriptionsDataset
2 | ==============================================
3 |
4 | The Medical Transcriptions dataset, refer to `doc `_.
5 |
6 | .. autoclass:: pyhealth.datasets.MedicalTranscriptionsDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.OMOPDataset
2 | ===================================
3 |
4 | We can process any OMOP-CDM formatted database; refer to `doc `_ for more information. The raw data is processed into a well-structured dataset object, giving users the **best flexibility and convenience** for supporting modeling and analysis.
5 |
6 | .. autoclass:: pyhealth.datasets.OMOPDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.SHHSDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.SHHSDataset
2 | ===================================
3 |
4 | The open Sleep Heart Health Study (SHHS) database; refer to `doc `_ for more information.
5 |
6 | .. autoclass:: pyhealth.datasets.SHHSDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.SampleDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.SampleDataset
2 | ===================================
3 |
4 | This is the basic sample dataset class. All sample datasets inherit from this class.
5 |
6 | .. autoclass:: pyhealth.datasets.SampleDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.SleepEDFDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.SleepEDFDataset
2 | ===================================
3 |
4 | The open Sleep-EDF Expanded database; refer to `doc `_ for more information.
5 |
6 | .. autoclass:: pyhealth.datasets.SleepEDFDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.TUABDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.TUABDataset
2 | ===================================
3 |
4 | Dataset is available at https://isip.piconepress.com/projects/nedc/html/tuh_eeg/#c_tuab
5 |
6 | The TUAB dataset (or Temple University Hospital EEG Abnormal Corpus) is a collection of EEG data acquired at the Temple University Hospital.
7 |
8 | The dataset contains both normal and abnormal EEG readings.
9 |
10 | .. autoclass:: pyhealth.datasets.TUABDataset
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.TUEVDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.TUEVDataset
2 | ===================================
3 |
4 | Dataset is available at https://isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml
5 |
6 | This corpus is a subset of TUEG that contains annotations of EEG segments as one of six classes: (1) spike and sharp wave (SPSW), (2) generalized periodic epileptiform discharges (GPED), (3) periodic lateralized epileptiform discharges (PLED), (4) eye movement (EYEM), (5) artifact (ARTF) and (6) background (BCKG).
7 |
8 | .. autoclass:: pyhealth.datasets.TUEVDataset
9 | :members:
10 | :undoc-members:
11 | :show-inheritance:
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.eICUDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.eICUDataset
2 | ===================================
3 |
4 | The open eICU Collaborative Research Database; refer to `doc `_ for more information. We process this database into a well-structured dataset object and give users the **best flexibility and convenience** for supporting modeling and analysis.
5 |
6 | .. autoclass:: pyhealth.datasets.eICUDataset
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.splitter.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.splitter
2 | ===================================
3 |
4 | Several data splitting functions for the ``pyhealth.datasets`` module to obtain training / validation / test sets.
5 |
6 | .. automodule:: pyhealth.datasets.splitter
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
14 |
15 |
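16 | A minimal usage sketch (``sample_dataset`` is any task-processed sample
17 | dataset; the 80/10/10 ratios are illustrative):
18 |
19 | .. code-block:: python
20 |
21 |     from pyhealth.datasets import split_by_patient
22 |
23 |     # split at the patient level so no patient appears in two splits
24 |     train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
25 |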
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.utils.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.utils
2 | ===================================
3 |
4 | Several utility functions.
5 |
6 | .. automodule:: pyhealth.datasets.utils
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/docs/api/medcode.rst:
--------------------------------------------------------------------------------
1 | MedCode
2 | ===================================
3 |
4 | We provide medical code mapping tools for (i) ontology mapping within one coding system and
5 | (ii) mapping the same concept across different coding systems.
6 |
7 |
8 | .. autoclass:: pyhealth.medcode.InnerMap
9 | :members:
10 | :undoc-members:
11 | :show-inheritance:
12 |
13 | .. autoclass:: pyhealth.medcode.CrossMap
14 | :members:
15 | :undoc-members:
16 | :show-inheritance:
17 |
18 | Diagnosis codes:
19 | ----------------------
20 |
21 | .. autoclass:: pyhealth.medcode.ICD9CM
22 | :members:
23 | :undoc-members:
24 | :show-inheritance:
25 |
26 | .. autoclass:: pyhealth.medcode.ICD10CM
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | .. autoclass:: pyhealth.medcode.CCSCM
32 | :members:
33 | :undoc-members:
34 | :show-inheritance:
35 |
36 | Procedure codes:
37 | ----------------------
38 |
39 | .. autoclass:: pyhealth.medcode.ICD9PROC
40 | :members:
41 | :undoc-members:
42 | :show-inheritance:
43 |
44 | .. autoclass:: pyhealth.medcode.ICD10PROC
45 | :members:
46 | :undoc-members:
47 | :show-inheritance:
48 |
49 | .. autoclass:: pyhealth.medcode.CCSPROC
50 | :members:
51 | :undoc-members:
52 | :show-inheritance:
53 |
54 | Medication codes:
55 | -----------------------
56 |
57 | .. autoclass:: pyhealth.medcode.NDC
58 | :members:
59 | :undoc-members:
60 | :show-inheritance:
61 |
62 | .. autoclass:: pyhealth.medcode.RxNorm
63 | :members:
64 | :undoc-members:
65 | :show-inheritance:
66 |
67 | .. autoclass:: pyhealth.medcode.ATC
68 | :members:
69 | :undoc-members:
70 | :show-inheritance:
71 |
72 |
73 |
74 |
75 |
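76 | A minimal usage sketch (the codes below are illustrative):
77 |
78 | .. code-block:: python
79 |
80 |     from pyhealth.medcode import InnerMap, CrossMap
81 |
82 |     # (i) ontology lookup within one coding system
83 |     icd9cm = InnerMap.load("ICD9CM")
84 |     print(icd9cm.lookup("428.0"))         # name of the diagnosis code
85 |     print(icd9cm.get_ancestors("428.0"))  # parents in the ICD-9-CM hierarchy
86 |
87 |     # (ii) mapping a concept across coding systems
88 |     codemap = CrossMap.load("ICD9CM", "CCSCM")
89 |     print(codemap.map("82101"))
90 |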
--------------------------------------------------------------------------------
/docs/api/metrics.rst:
--------------------------------------------------------------------------------
1 | Metrics
2 | ===============
3 |
4 | We provide easy-to-use metrics (the same style and args as `sklearn.metrics`) for binary classification,
5 | multiclass classification, and multilabel classification.
6 | For applicable tasks, we also provide metrics related to uncertainty quantification:
7 | metrics for model calibration, as well as metrics that measure the quality of prediction sets.
8 | We also provide other metrics specifically for healthcare
9 | tasks, such as the drug-drug interaction (DDI) rate.
10 |
11 |
12 | .. toctree::
13 | :maxdepth: 3
14 |
15 | metrics/pyhealth.metrics.multiclass
16 | metrics/pyhealth.metrics.multilabel
17 | metrics/pyhealth.metrics.binary
18 | metrics/pyhealth.metrics.calibration
19 | metrics/pyhealth.metrics.prediction_set
20 | metrics/pyhealth.metrics.fairness
21 |
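22 | A minimal usage sketch (the arrays below are illustrative):
23 |
24 | .. code-block:: python
25 |
26 |     import numpy as np
27 |     from pyhealth.metrics import binary_metrics_fn
28 |
29 |     y_true = np.array([0, 1, 1, 0, 1])
30 |     y_prob = np.array([0.1, 0.8, 0.6, 0.3, 0.9])
31 |     print(binary_metrics_fn(y_true, y_prob, metrics=["accuracy", "f1", "roc_auc"]))
32 |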
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.binary.rst:
--------------------------------------------------------------------------------
1 | pyhealth.metrics.binary
2 | ===================================
3 |
4 | .. currentmodule:: pyhealth.metrics.binary
5 |
6 | .. autofunction:: binary_metrics_fn
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.calibration.rst:
--------------------------------------------------------------------------------
1 | [core] calibration
2 | ===================================
3 |
4 | .. automodule:: pyhealth.metrics.calibration
5 | :members:
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.fairness.rst:
--------------------------------------------------------------------------------
1 | pyhealth.metrics.fairness
2 | ===================================
3 |
4 | .. currentmodule:: pyhealth.metrics.fairness
5 |
6 | .. autofunction:: fairness_metrics_fn
7 |
8 | .. currentmodule:: pyhealth.metrics.fairness_utils
9 |
10 | .. autofunction:: disparate_impact
11 |
12 | .. autofunction:: statistical_parity_difference
13 |
14 | .. autofunction:: sensitive_attributes_from_patient_ids
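15 |
16 | A minimal usage sketch (the array values and the protected-group encoding are
17 | illustrative; see ``fairness_metrics_fn`` above for the exact arguments):
18 |
19 | .. code-block:: python
20 |
21 |     import numpy as np
22 |     from pyhealth.metrics import fairness_metrics_fn
23 |
24 |     y_true = np.array([0, 1, 1, 0])
25 |     y_prob = np.array([0.2, 0.7, 0.6, 0.4])
26 |     sensitive = np.array([1, 0, 1, 0])  # 1 marks the protected group
27 |     print(fairness_metrics_fn(y_true, y_prob, sensitive))
28 |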
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.multiclass.rst:
--------------------------------------------------------------------------------
1 | pyhealth.metrics.multiclass
2 | ===================================
3 |
4 | .. currentmodule:: pyhealth.metrics.multiclass
5 |
6 | .. autofunction:: multiclass_metrics_fn
7 |
8 |
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.multilabel.rst:
--------------------------------------------------------------------------------
1 | pyhealth.metrics.multilabel
2 | ===================================
3 |
4 | .. currentmodule:: pyhealth.metrics.multilabel
5 |
6 | .. autofunction:: multilabel_metrics_fn
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.prediction_set.rst:
--------------------------------------------------------------------------------
1 | [core] prediction_set
2 | ===================================
3 |
4 | .. automodule:: pyhealth.metrics.prediction_set
5 | :members:
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/docs/api/models.rst:
--------------------------------------------------------------------------------
1 | Models
2 | ===============
3 |
4 | We implement the following models to support multiple healthcare predictive tasks.
5 |
6 |
7 | .. toctree::
8 | :maxdepth: 3
9 |
10 | models/pyhealth.models.BaseModel
11 | models/pyhealth.models.MLP
12 | models/pyhealth.models.CNN
13 | models/pyhealth.models.RNN
14 | models/pyhealth.models.GNN
15 | models/pyhealth.models.Transformer
16 | models/pyhealth.models.TransformersModel
17 | models/pyhealth.models.RETAIN
18 | models/pyhealth.models.GAMENet
19 | models/pyhealth.models.MICRON
20 | models/pyhealth.models.SafeDrug
21 | models/pyhealth.models.MoleRec
22 | models/pyhealth.models.Deepr
23 | models/pyhealth.models.ContraWR
24 | models/pyhealth.models.SparcNet
25 | models/pyhealth.models.StageNet
26 | models/pyhealth.models.AdaCare
27 | models/pyhealth.models.ConCare
28 | models/pyhealth.models.Agent
29 | models/pyhealth.models.GRASP
30 | models/pyhealth.models.MedLink
31 | models/pyhealth.models.TCN
32 | models/pyhealth.models.GAN
33 | models/pyhealth.models.VAE
34 |
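35 | A minimal usage sketch (``sample_dataset`` is a task-processed sample dataset;
36 | the feature and label keys are illustrative):
37 |
38 | .. code-block:: python
39 |
40 |     from pyhealth.models import Transformer
41 |
42 |     model = Transformer(
43 |         dataset=sample_dataset,
44 |         feature_keys=["conditions", "procedures"],
45 |         label_key="drugs",
46 |         mode="multilabel",
47 |     )
48 |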
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.AdaCare.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.AdaCare
2 | ===================================
3 |
4 | The separate callable AdaCareLayer and the complete AdaCare model.
5 |
6 | .. autoclass:: pyhealth.models.AdaCareLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.AdaCare
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.Agent.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.Agent
2 | ===================================
3 |
4 | The separate callable AgentLayer and the complete Agent model.
5 |
6 | .. autoclass:: pyhealth.models.AgentLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.Agent
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.BaseModel.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.BaseModel
2 | ===================================
3 |
4 | The base class for all models.
5 |
6 | .. autoclass:: pyhealth.models.BaseModel
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.CNN.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.CNN
2 | ===================================
3 |
4 | The separate callable CNNLayer and the complete CNN model.
5 |
6 | .. autoclass:: pyhealth.models.CNNLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.CNN
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
15 |
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.ConCare.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.ConCare
2 | ===================================
3 |
4 | The separate callable ConCareLayer and the complete ConCare model.
5 |
6 | .. autoclass:: pyhealth.models.ConCareLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.ConCare
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.ContraWR.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.ContraWR
2 | ===================================
3 |
4 |
5 | The separate callable ResBlock2D and the complete ContraWR model.
6 |
7 | .. autoclass:: pyhealth.models.ResBlock2D
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
12 | .. autoclass:: pyhealth.models.ContraWR
13 | :members:
14 | :undoc-members:
15 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.Deepr.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.Deepr
2 | ===================================
3 |
4 | The separate callable DeeprLayer and the complete Deepr model.
5 |
6 | .. autoclass:: pyhealth.models.DeeprLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.Deepr
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
15 |
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.GAMENet.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.GAMENet
2 | ===================================
3 |
4 | The separate callable GAMENetLayer and the complete GAMENet model.
5 |
6 | .. autoclass:: pyhealth.models.GAMENetLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.GAMENet
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
15 |
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.GAN.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.GAN
2 | ===================================
3 |
4 | The GAN model (the PyHealth trainer does not apply to GAN; refer to examples/ChestXray-image-generation-GAN.ipynb for examples of using the GAN model).
5 |
6 | .. autoclass:: pyhealth.models.GAN
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.GNN.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.GNN
2 | ===================================
3 |
4 | The GNN model with its GAT and GCN variants (the PyHealth trainer does not apply to GNN; see the examples directory for usage).
5 |
6 | .. autoclass:: pyhealth.models.GAT
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.GCN
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
15 |
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.GRASP.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.GRASP
2 | ===================================
3 |
4 | The separate callable GRASPLayer and the complete GRASP model.
5 |
6 | .. autoclass:: pyhealth.models.GRASPLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.GRASP
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.MICRON.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.MICRON
2 | ===================================
3 |
4 | The separate callable MICRONLayer and the complete MICRON model.
5 |
6 | .. autoclass:: pyhealth.models.MICRONLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.MICRON
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.MLP.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.MLP
2 | ===================================
3 |
4 | The complete MLP model.
5 |
6 | .. autoclass:: pyhealth.models.MLP
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.MedLink.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.MedLink
2 | ===================================
3 |
4 | The complete MedLink model.
5 |
6 |
7 | .. autoclass:: pyhealth.models.MedLink
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.MoleRec.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.MoleRec
2 | ===================================
3 |
4 | The separate callable MoleRecLayer and the complete MoleRec model.
5 |
6 | .. autoclass:: pyhealth.models.MoleRecLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.MoleRec
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.RETAIN.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.RETAIN
2 | ===================================
3 |
4 | The separate callable RETAINLayer and the complete RETAIN model.
5 |
6 | .. autoclass:: pyhealth.models.RETAINLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.RETAIN
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.RNN.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.RNN
2 | ===================================
3 |
4 |
5 | The separate callable RNNLayer and the complete RNN model.
6 |
7 | .. autoclass:: pyhealth.models.RNNLayer
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
12 | .. autoclass:: pyhealth.models.RNN
13 | :members:
14 | :undoc-members:
15 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.SafeDrug.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.SafeDrug
2 | ===================================
3 |
4 |
5 | The separate callable SafeDrugLayer and the complete SafeDrug model.
6 |
7 | .. autoclass:: pyhealth.models.SafeDrugLayer
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
12 | .. autoclass:: pyhealth.models.SafeDrug
13 | :members:
14 | :undoc-members:
15 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.SparcNet.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.SparcNet
2 | ===================================
3 |
4 |
5 | The SparcNet model, from Jin Jing et al., "Development of Expert-level Classification of Seizures and Rhythmic and Periodic Patterns During EEG Interpretation," Neurology, 2023.
6 |
7 | .. autoclass:: pyhealth.models.DenseLayer
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
12 | .. autoclass:: pyhealth.models.DenseBlock
13 | :members:
14 | :undoc-members:
15 | :show-inheritance:
16 |
17 | .. autoclass:: pyhealth.models.TransitionLayer
18 | :members:
19 | :undoc-members:
20 | :show-inheritance:
21 |
22 | .. autoclass:: pyhealth.models.SparcNet
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.StageNet.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.StageNet
2 | ===================================
3 |
4 | The separate callable StageNetLayer and the complete StageNet model.
5 |
6 | .. autoclass:: pyhealth.models.StageNetLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.StageNet
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.TCN.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.TCN
2 | ===================================
3 |
4 | The separate callable TCNLayer and the complete TCN model.
5 |
6 | .. autoclass:: pyhealth.models.TCNLayer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 | .. autoclass:: pyhealth.models.TCN
12 | :members:
13 | :undoc-members:
14 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.Transformer.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.Transformer
2 | ===================================
3 |
4 |
5 | The separate callable TransformerLayer and the complete Transformer model.
6 |
7 | .. autoclass:: pyhealth.models.TransformerLayer
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
12 | .. autoclass:: pyhealth.models.Transformer
13 | :members:
14 | :undoc-members:
15 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.TransformersModel.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.TransformersModel
2 | ===================================
3 |
4 |
5 | The complete TransformersModel model.
6 |
7 | .. autoclass:: pyhealth.models.TransformersModel
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.VAE.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.VAE
2 | ===================================
3 |
4 | The VAE model (image reconstruction is treated as a regression task).
5 |
6 | .. autoclass:: pyhealth.models.VAE
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
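A minimal construction sketch, mirroring the chest X-ray VAE example shipped under ``examples/`` (``sample_dataset`` and the argument values below are illustrative):

.. code-block:: python

    from pyhealth.models import VAE

    model = VAE(
        dataset=sample_dataset,   # a SampleDataset of images
        feature_keys=["path"],    # input image
        label_key="path",         # reconstruct the same image
        mode="regression",        # reconstruction as regression
        input_channel=3,
        input_size=128,
        hidden_dim=128,
    )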
--------------------------------------------------------------------------------
/docs/api/tasks.rst:
--------------------------------------------------------------------------------
1 | Tasks
2 | ===============
3 |
4 | We support various real-world healthcare predictive tasks defined by **function calls**. The following example tasks are collected from top AI/Medical venues, such as:
5 |
6 | (i) Drug Recommendation [Yang et al. IJCAI 2021a, Yang et al. IJCAI 2021b, Shang et al. AAAI 2020]
7 |
8 | (ii) Readmission Prediction [Choi et al. AAAI 2021]
9 |
10 | (iii) Mortality Prediction [Choi et al. AAAI 2021]
11 |
12 | (iv) Length of Stay Prediction
13 |
14 | (v) Sleep Staging [Yang et al. ArXiv 2021]
15 |
16 | .. toctree::
17 | :maxdepth: 3
18 |
19 | tasks/pyhealth.tasks.BaseTask
20 | tasks/pyhealth.tasks.Readmission30DaysMIMIC4
21 | tasks/pyhealth.tasks.InHospitalMortalityMIMIC4
22 | tasks/pyhealth.tasks.MIMIC3ICD9Coding
23 | tasks/pyhealth.tasks.cardiology_detect
24 | tasks/pyhealth.tasks.COVID19CXRClassification
25 | tasks/pyhealth.tasks.drug_recommendation
26 | tasks/pyhealth.tasks.EEG_abnormal
27 | tasks/pyhealth.tasks.EEG_events
28 | tasks/pyhealth.tasks.length_of_stay_prediction
29 | tasks/pyhealth.tasks.MedicalTranscriptionsClassification
30 | tasks/pyhealth.tasks.mortality_prediction
31 | tasks/pyhealth.tasks.patient_linkage_mimic3_fn
32 | tasks/pyhealth.tasks.readmission_prediction
33 | tasks/pyhealth.tasks.sleep_staging
34 | tasks/pyhealth.tasks.SleepStagingSleepEDF
35 | tasks/pyhealth.tasks.temple_university_EEG_tasks
36 |
37 |
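In practice, a task function is passed to a dataset's ``set_task`` method, which yields a sample-level dataset ready for dataloaders; a minimal sketch (the dataset root below is illustrative):

.. code-block:: python

    from pyhealth.datasets import MIMIC3Dataset
    from pyhealth.tasks import mortality_prediction_mimic3_fn

    base_dataset = MIMIC3Dataset(
        root="/path/to/mimiciii/1.4",  # illustrative path
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    )
    sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn)
    sample_dataset.stat()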
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.BaseTask.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.BaseTask
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.BaseTask
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.COVID19CXRClassification.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.COVID19CXRClassification
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.COVID19CXRClassification
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.EEG_abnormal.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.EEG_abnormal
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.EEG_abnormal
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.EEG_events.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.EEG_events
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.EEG_events
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.InHospitalMortalityMIMIC4.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.InHospitalMortalityMIMIC4
2 | ===========================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.InHospitalMortalityMIMIC4
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.MIMIC3ICD9Coding.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.MIMIC3ICD9Coding
2 | ===========================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.MIMIC3ICD9Coding
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.MedicalTranscriptionsClassification.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.MedicalTranscriptionsClassification
2 | ===================================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.MedicalTranscriptionsClassification
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.Readmission30DaysMIMIC4.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.Readmission30DaysMIMIC4
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.Readmission30DaysMIMIC4
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.SleepStagingSleepEDF.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.SleepStagingSleepEDF
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.SleepStagingSleepEDF
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.cardiology_detect.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.cardiology_detect
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.cardiology_detect
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.drug_recommendation.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.drug_recommendation
2 | ===================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.drug_recommendation
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.length_of_stay_prediction.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.length_of_stay_prediction
2 | ===========================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.length_of_stay_prediction
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.mortality_prediction.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.mortality_prediction
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.mortality_prediction
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.patient_linkage_mimic3_fn.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.patient_linkage_mimic3_fn
2 | ===========================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.patient_linkage_mimic3_fn
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.readmission_prediction.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.readmission_prediction
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.readmission_prediction
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.sleep_staging.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.sleep_staging
2 | =======================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.sleep_staging
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tasks/pyhealth.tasks.temple_university_EEG_tasks.rst:
--------------------------------------------------------------------------------
1 | pyhealth.tasks.temple_university_EEG_tasks
2 | ===========================================
3 |
4 |
5 | .. automodule:: pyhealth.tasks.temple_university_EEG_tasks
6 | :members:
7 | :undoc-members:
8 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/tokenizer.rst:
--------------------------------------------------------------------------------
1 | Tokenizer
2 | ===============
3 |
4 | The tokenizer functionality can be used to support token-to-index and index-to-token mappings in general ML settings.
5 |
6 | .. automodule:: pyhealth.tokenizer
7 | :members:
8 | :undoc-members:
9 | :show-inheritance:
10 |
11 |
12 |
13 |
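A minimal sketch of the two mappings (the vocabulary below is illustrative; ``convert_tokens_to_indices`` / ``convert_indices_to_tokens`` are the mapping helpers assumed here):

.. code-block:: python

    from pyhealth.tokenizer import Tokenizer

    tokenizer = Tokenizer(tokens=["A01", "B02", "C03"], special_tokens=["<pad>", "<unk>"])
    indices = tokenizer.convert_tokens_to_indices(["A01", "C03"])  # tokens -> indices
    tokens = tokenizer.convert_indices_to_tokens(indices)          # indices -> tokens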
--------------------------------------------------------------------------------
/docs/api/trainer.rst:
--------------------------------------------------------------------------------
1 | Trainer
2 | ===================================
3 |
4 | .. autoclass:: pyhealth.trainer.Trainer
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
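
A minimal usage sketch (``model`` and the dataloaders are assumed to be constructed as in the examples elsewhere in this documentation):

.. code-block:: python

    from pyhealth.trainer import Trainer

    trainer = Trainer(model=model)  # any pyhealth.models model
    trainer.train(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        epochs=10,
        monitor="roc_auc",
    )
    print(trainer.evaluate(test_dataloader))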
--------------------------------------------------------------------------------
/docs/developer.rst:
--------------------------------------------------------------------------------
1 | .. _how_to_contribute:
2 |
3 | =====================
4 | How to Contribute
5 | =====================
6 |
7 | Thank you for your interest in contributing to PyHealth! We welcome contributions of all types, including bug fixes, new features, documentation improvements, and more.
8 |
9 | Contribution Guidelines
10 | =======================
11 |
12 | - For **bug fixes** and **documentation improvements**, you can get started on your own.
13 | - For **major framework updates, new datasets, or new models**, it's best to **open an issue and discuss with the PyHealth team first** before proceeding.
14 |
15 | Getting Started
16 | ===============
17 |
18 | 1. **Fork the Repository**: Navigate to the PyHealth GitHub repository and click the **Fork** button.
19 | 2. **Clone Your Fork**:
20 | .. code-block:: bash
21 |
22 | git clone https://github.com/your-username/PyHealth.git
23 | cd PyHealth
24 | 3. **Create a New Branch**:
25 | .. code-block:: bash
26 |
27 | git checkout -b feature-branch
28 |
29 | Setting Up the Development Environment
30 | ======================================
31 |
32 | To avoid interfering with an existing PyHealth installation via `pip`, it is recommended to create a new Python environment specifically for development.
33 | You can use `conda` or `virtualenv` to do so.
34 |
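For example, with ``conda`` (the environment name and Python version below are illustrative; the editable install assumes you are inside your cloned fork):

.. code-block:: bash

    conda create -n pyhealth-dev python=3.8
    conda activate pyhealth-dev
    pip install -e .
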
35 | Making Changes
36 | ==============
37 |
38 | 1. **Write Code**: Implement your changes in a clean and modular way.
39 | 2. **Follow Code Style**: Ensure your code follows `black` formatting:
40 | .. code-block:: bash
41 |
42 | black .
43 | 3. **Write Tests**: Add unit tests for your changes in `unittests/`.
44 | 4. **Run Tests**: Before submitting, make sure all tests pass:
45 | .. code-block:: bash
46 |
47 | pytest
48 |
49 | Submitting Your Changes
50 | =======================
51 |
52 | 1. **Commit Changes**:
53 | .. code-block:: bash
54 |
55 | git add .
56 | git commit -m "Description of your changes"
57 | 2. **Push to Your Fork**:
58 | .. code-block:: bash
59 |
60 | git push origin feature-branch
61 | 3. **Open a Pull Request (PR)**:
62 | - Go to the original PyHealth repository.
63 | - Click **New Pull Request**.
64 | - Select your fork and branch.
65 | - Provide a clear description of your changes.
66 | - Submit the PR.
67 |
68 | Review Process
69 | ==============
70 |
71 | - The PyHealth maintainers will review your PR and provide feedback.
72 | - Make any requested changes and push updates to your PR.
73 | - Once approved, your changes will be merged into the main branch!
74 |
75 | Need Help?
76 | ==========
77 |
78 | - Check the **Issues** tab on GitHub.
79 | - Reach out to the maintainers via email if applicable.
80 |
81 | We appreciate your contributions and look forward to working with you!
--------------------------------------------------------------------------------
/docs/faq.rst:
--------------------------------------------------------------------------------
1 | Frequently Asked Questions
2 | ==========================
3 |
4 | ----
5 |
6 |
7 | Blueprint & Development Plan
8 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
9 |
10 | The long-term goal of PyHealth is to become a comprehensive healthcare AI toolkit that supports
11 | not only EHR data, but also images and clinical notes.
12 |
13 | This is the central place to track important things to be fixed/added:
14 |
15 | - The support of image datasets and clinical notes
16 | - The compatibility and the support of OMOP format datasets
17 | - Model persistence (save, load, and portability)
18 | - The release of a benchmark paper with PyHealth
19 | - Add contact channel with `Gitter `_
20 | - Support additional languages, see `Manage Translations `_
21 |
22 | Feel free to open an issue if needed.
23 | See `Issues `_.
24 |
25 |
26 | Inclusion Criteria
27 | ^^^^^^^^^^^^^^^^^^
28 |
29 | Similarly to scikit-learn, we mainly consider well-established algorithms for inclusion.
30 | A rule of thumb is at least two years since publication, 50+ citations, and demonstrated usefulness.
31 |
32 | However, we encourage the author(s) of newly proposed models to share and add their implementation to PyHealth
33 | to boost ML accessibility and reproducibility.
34 | This exception only applies if you can commit to maintaining your model for at least a two-year period.
35 |
36 |
--------------------------------------------------------------------------------
/docs/install.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | You can install from PyPI:
5 |
6 | .. code-block:: bash
7 |
8 | pip install pyhealth
9 |
10 |
11 | or from the GitHub source:
12 |
13 | .. code-block:: bash
14 |
15 | git clone https://github.com/sunlabuiuc/PyHealth.git
16 |    cd PyHealth
17 | pip install .
18 |
19 |
20 | **Required Dependencies**\ :
21 |
22 | .. code-block:: bash
23 |
24 | python>=3.8
25 | torch>=1.8.0
26 | rdkit>=2022.03.4
27 | scikit-learn>=0.24.2
28 | networkx>=2.6.3
29 | pandas>=1.3.2
30 | tqdm
31 |
32 | **Warning 1**\ :
33 |
34 | PyHealth includes multiple neural-network-based models, e.g., LSTM, which are
35 | implemented in PyTorch. However, PyHealth does **NOT** install these DL libraries for you.
36 | This reduces the risk of interfering with your local copies.
37 | If you want to use neural-network-based models, please make sure PyTorch is installed.
38 | Similarly, models depending on **xgboost** will **NOT** enforce an xgboost installation by default.
39 |
40 |
41 | **CUDA Setting**\ :
42 |
43 | To run PyHealth on GPU, you also need CUDA and a cudatoolkit version that supports your GPU. `More info `_
44 |
45 | For example, if you use an NVIDIA RTX A6000 GPU for training, you can install a compatible cudatoolkit with:
46 |
47 | .. code-block:: bash
48 |
49 |     conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
50 |
51 | ----
--------------------------------------------------------------------------------
/docs/live.rst:
--------------------------------------------------------------------------------
1 | PyHealth Webinar 4/9/25
2 | =======================
3 |
4 | **Start Time**: 2 PM Central Time, Wednesday
5 |
6 | **Recurrence**: There are no repeat sessions for now.
7 |
8 | **Zoom**: `Join Link `_
9 |
10 | **YouTube**: `Recorded Live Sessions `_
11 |
12 | **User/Developer Discord**: `Click to join `_
13 |
14 | Past Scheduled Videos
15 | ^^^^^^^^^^^^^^^^^^^^^
16 | **(Dec 21, 2022)** Live 01 - What is PyHealth and How to Get Started? `[Recap] `_
17 |
18 | **(Dec 28, 2022)** Live 02 - Data & Datasets & Tasks: store unstructured data in a structured way. `[Recap I] `_ `[II] `_ `[III] `_ `[IV] `_
19 |
20 | **(Jan 4, 2023)** Live 03 - Models & Trainer & Metrics: initialize and train a deep learning model. `[Recap I] `_ `[II] `_ `[III] `_
21 |
22 | **(Jan 11, 2023)** Live 04 - Tokenizer & Medcode: master the medical code lookup and mapping `[Recap I] `_ `[II] `_
23 |
24 | **(Jan 18, 2023)** Live 05 - PyHealth can support a complete healthcare ML pipeline `[Recap I] `_ `[II] `_
25 |
26 | **(Jan 25, 2023)** Live 06 - Fit your own dataset into pipeline and use our model `[Recap] `_
27 |
28 | **(Feb 1, 2023)** Live 07 - Adopt your customized model and quickly try it on our data `[Recap] `_
29 |
30 | **(Feb 8, 2023)** Live 08 - New feature: support for biosignal data (EEG, ECG, etc.) classification `[Recap I] `_ `[II] `_
31 |
32 | **(Feb 15, 2023)** Live 09 - New feature: parallel and faster data loading architecture
33 |
34 | **(Feb 22, 2023)** Live 10 - Add a covid prediction benchmark (new datasets, new models)
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | Sphinx==5.2.3
2 | sphinx-automodapi
3 | sphinx-autodoc-annotation
4 | sphinx_last_updated_by_git
5 | sphinxcontrib-spelling
6 | sphinx-autodoc-typehints
7 | sphinx_gallery==0.11.1
8 | sphinxcontrib-bibtex==2.5.0
9 | nbsphinx==0.8.9
10 | nbsphinx-link==1.3.0
11 | sphinx-copybutton==0.5.0
12 | furo
13 | myst-parser==0.18.1
14 | sphinx-remove-toctrees==0.0.3
15 | sphinx-design==0.3.0
16 | sphinxext-opengraph==0.6.3
17 | sphinxcontrib-httpdomain==1.8.0
18 | sphinx_toggleprompt
19 | pandas-bokeh==0.5.5
20 | bokeh==3.0.1
21 | gspread==5.6.2
22 | google-cloud-storage==2.6.0
23 | oauth2client==4.1.3
24 | Jinja2==3.1.4
25 | flask==2.2.5
--------------------------------------------------------------------------------
/docs/use_cases.rst:
--------------------------------------------------------------------------------
1 | .. _use_cases:
2 |
3 | =====================
4 | Use Cases of PyHealth
5 | =====================
6 |
7 | PyHealth enables various healthcare machine learning applications. Below are some practical use cases, categorized by data modality, each linked to an interactive Google Colab notebook.
8 |
9 | Structured Data
10 | ===============
11 |
12 | Predicting Hospital Readmission
13 | -------------------------------
14 | Hospital readmission prediction helps identify patients at high risk of returning to the hospital shortly after discharge. This can assist healthcare providers in taking preventive measures.
15 |
16 | - `Colab Notebook <https://colab.research.google.com/drive/1bhCwbXce1YFtVaQLsOt4FcyZJ1_my7Cs?usp=sharing>`_
17 |
18 | Drug Recommendation System
19 | --------------------------
20 | Personalized drug recommendation models can suggest appropriate medications based on a patient’s medical history, improving treatment outcomes.
21 |
22 | - `Colab Notebook <https://colab.research.google.com/drive/10CSb4F4llYJvv42yTUiRmvSZdoEsbmFF?usp=sharing>`_
23 |
24 | Length of Stay Prediction
25 | --------------------------
26 | Predicting hospital length of stay aids resource allocation, bed management, and patient care planning in hospitals.
27 |
28 | - `Colab Notebook <https://colab.research.google.com/drive/1JoPpXqqB1_lGF1XscBOsDHMLtgvlOYI1?usp=sharing>`_
29 |
30 | Mortality Prediction from ICU Data
31 | -----------------------------------
32 | Predicting ICU patient mortality using clinical data can help prioritize critical care and optimize resource usage.
33 |
34 | - `Colab Notebook <https://colab.research.google.com/drive/1Qblpcv4NWjrnADT66TjBcNwOe8x6wU4c?usp=sharing>`_
35 |
36 | Time-Series Data
37 | ================
38 |
39 | Sleep Staging
40 | -------------
41 | Sleep staging classification uses EEG data to determine different sleep stages, aiding in the diagnosis and treatment of sleep disorders.
42 |
43 | - `Colab Notebook <https://colab.research.google.com/drive/1mpSeNCAthXG3cqROkdUcUdozIPIMTCuo?usp=sharing>`_
44 |
45 | Imaging Data
46 | ============
47 |
48 | X-ray Classification
49 | ---------------------
50 | X-ray classification models can assist radiologists by automatically detecting abnormalities in chest X-rays and other radiographic images.
51 |
52 | - `Colab Notebook <https://drive.google.com/file/d/1XokhV8dN3lis7gMdjpMhZBEGs03sus9R/view?usp=share_link>`_
53 |
54 | Text Data
55 | =========
56 |
57 | Medical Transcription Classification
58 | -------------------------------------
59 | Classifying medical transcriptions enables automated processing of clinical notes, improving documentation efficiency and accessibility.
60 |
61 | - `Colab Notebook <https://drive.google.com/file/d/1JxQYEj94WjEsRifAOyEfrFWWIjqXxrqH/view?usp=share_link>`_
62 |
63 | Each notebook provides step-by-step guidance on data processing, model training, and evaluation using PyHealth.
64 |
--------------------------------------------------------------------------------
/examples/EEG_events_SparcNet.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import split_by_visit, get_dataloader
2 | from pyhealth.trainer import Trainer
3 | from pyhealth.datasets import TUEVDataset
4 | from pyhealth.tasks import EEG_events_fn
5 | from pyhealth.models import SparcNet
6 |
7 | # STEP 1: load signal data
8 | dataset = TUEVDataset(root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/",
9 | dev=True,
10 | refresh_cache=True,
11 | )
12 |
13 | # STEP 2: set task
14 | TUEV_ds = dataset.set_task(EEG_events_fn)
15 | TUEV_ds.stat()
16 |
17 | # split dataset
18 | train_dataset, val_dataset, test_dataset = split_by_visit(
19 | TUEV_ds, [0.6, 0.2, 0.2]
20 | )
21 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
22 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
23 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
24 | print(
25 | "loader size: train/val/test",
26 | len(train_dataset),
27 | len(val_dataset),
28 | len(test_dataset),
29 | )
30 |
31 | # STEP 3: define model
32 | model = SparcNet(
33 | dataset=TUEV_ds,
34 | feature_keys=["signal"],
35 | label_key="label",
36 | mode="multiclass",
37 | )
38 |
39 | # STEP 4: define trainer
40 | trainer = Trainer(model=model, device="cuda:4")
41 | trainer.train(
42 | train_dataloader=train_dataloader,
43 | val_dataloader=val_dataloader,
44 | epochs=10,
45 | optimizer_params={"lr": 1e-3},
46 | )
47 |
48 | # STEP 5: evaluate
49 | print(trainer.evaluate(test_dataloader))
50 |
--------------------------------------------------------------------------------
/examples/EEG_isAbnormal_SparcNet.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import split_by_visit, get_dataloader
2 | from pyhealth.trainer import Trainer
3 | from pyhealth.datasets import TUABDataset
4 | from pyhealth.tasks import EEG_isAbnormal_fn
5 | from pyhealth.models import SparcNet
6 |
7 | # STEP 1: load signal data
8 | dataset = TUABDataset(root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/",
9 | dev=True,
10 | refresh_cache=True,
11 | )
12 |
13 | # STEP 2: set task
14 | TUAB_ds = dataset.set_task(EEG_isAbnormal_fn)
15 | TUAB_ds.stat()
16 |
17 | # split dataset
18 | train_dataset, val_dataset, test_dataset = split_by_visit(
19 | TUAB_ds, [0.6, 0.2, 0.2]
20 | )
21 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
22 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
23 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
24 | print(
25 | "loader size: train/val/test",
26 | len(train_dataset),
27 | len(val_dataset),
28 | len(test_dataset),
29 | )
30 |
31 | # STEP 3: define model
32 | model = SparcNet(
33 | dataset=TUAB_ds,
34 | feature_keys=["signal"],
35 | label_key="label",
36 | mode="binary",
37 | )
38 |
39 | # STEP 4: define trainer
40 | trainer = Trainer(model=model, device="cuda:4")
41 | trainer.train(
42 | train_dataloader=train_dataloader,
43 | val_dataloader=val_dataloader,
44 | epochs=10,
45 | monitor="pr_auc",
46 | optimizer_params={"lr": 1e-3},
47 | )
48 |
49 | # STEP 5: evaluate
50 | print(trainer.evaluate(test_dataloader))
51 |
--------------------------------------------------------------------------------
/examples/cardiology_detection_isAR_SparcNet.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import split_by_visit, get_dataloader
2 | from pyhealth.trainer import Trainer
3 | from pyhealth.datasets import CardiologyDataset
4 | from pyhealth.tasks import cardiology_isAR_fn
5 | from pyhealth.models import ContraWR, SparcNet
6 |
7 | # STEP 1: load signal data
8 | dataset = CardiologyDataset(root="/srv/local/data/physionet.org/files/challenge-2020/1.0.2/training",
9 | chosen_dataset=[1,1,1,1,1,1],
10 | refresh_cache=False,
11 | dev=True)
12 |
13 | # STEP 2: set task
14 | cardiology_ds = dataset.set_task(cardiology_isAR_fn)
15 | cardiology_ds.stat()
16 |
17 | # split dataset
18 | train_dataset, val_dataset, test_dataset = split_by_visit(
19 | cardiology_ds, [0.6, 0.2, 0.2]
20 | )
21 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
22 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
23 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
24 | print(
25 | "loader size: train/val/test",
26 | len(train_dataset),
27 | len(val_dataset),
28 | len(test_dataset),
29 | )
30 |
31 | # STEP 3: define model
32 | model = SparcNet(
33 | dataset=cardiology_ds,
34 | feature_keys=["signal"],
35 | label_key="label",
36 | mode="binary",
37 | )
38 |
39 | # STEP 4: define trainer
40 | trainer = Trainer(model=model, device="cuda:4")
41 | trainer.train(
42 | train_dataloader=train_dataloader,
43 | val_dataloader=val_dataloader,
44 | epochs=5,
45 | monitor="pr_auc",
46 | )
47 |
48 | # STEP 5: evaluate
49 | print(trainer.evaluate(test_dataloader))
50 |
--------------------------------------------------------------------------------
/examples/chextXray_image_generation_VAE.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import split_by_visit, get_dataloader
2 | from pyhealth.trainer import Trainer
3 | from pyhealth.datasets import COVID19CXRDataset
4 | from pyhealth.models import VAE
5 | from torchvision import transforms
6 |
7 | import torch
8 | import numpy as np
9 |
10 | # STEP 1: load image data
11 | root = "/srv/local/data/COVID-19_Radiography_Dataset"
12 | base_dataset = COVID19CXRDataset(root)
13 |
14 | # STEP 2: set task
15 | sample_dataset = base_dataset.set_task()
16 |
17 | # the transformation automatically normalizes the pixel intensity into [0, 1]
18 | transform = transforms.Compose([
19 |     transforms.Lambda(lambda x: x if x.shape[0] == 3 else x.repeat(3, 1, 1)),  # repeat single-channel images to 3 channels
20 | transforms.Resize((128, 128)),
21 | ])
22 |
23 | def encode(sample):
24 | sample["path"] = transform(sample["path"])
25 | return sample
26 |
27 | sample_dataset.set_transform(encode)
28 |
29 |
30 | # split dataset
31 | train_dataset, val_dataset, test_dataset = split_by_visit(
32 | sample_dataset, [0.6, 0.2, 0.2]
33 | )
34 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
35 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
36 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)
37 |
38 | data = next(iter(train_dataloader))
39 | print (data)
40 |
41 | print (data["path"][0].shape)
42 |
43 | print(
44 | "loader size: train/val/test",
45 | len(train_dataset),
46 | len(val_dataset),
47 | len(test_dataset),
48 | )
49 |
50 | # STEP 3: define model
51 | model = VAE(
52 | dataset=sample_dataset,
53 | input_channel=3,
54 | input_size=128,
55 | feature_keys=["path"],
56 | label_key="path",
57 | mode="regression",
58 |     hidden_dim=128,
59 | )
60 |
61 | # STEP 4: define trainer
62 | trainer = Trainer(model=model, device="cuda:4", metrics=["kl_divergence", "mse", "mae"])
63 | trainer.train(
64 | train_dataloader=train_dataloader,
65 | val_dataloader=val_dataloader,
66 | epochs=10,
67 | monitor="kl_divergence",
68 | monitor_criterion="min",
69 | optimizer_params={"lr": 1e-3},
70 | )
71 |
72 | # # STEP 5: evaluate
73 | # print(trainer.evaluate(test_dataloader))
74 |
75 |
76 | import matplotlib.pyplot as plt
77 |
78 | # EXP 1: check the real chestxray image and the reconstructed image
79 | X, X_rec, _ = trainer.inference(test_dataloader)
80 |
81 | plt.figure()
82 | plt.subplot(1, 2, 1)
83 | plt.imshow(X[0].reshape(128, 128), cmap="gray")
84 | plt.subplot(1, 2, 2)
85 | plt.imshow(X_rec[0].reshape(128, 128), cmap="gray")
86 | plt.savefig("chestxray_vae_comparison.png")
87 |
88 | # EXP 2: random images
89 | model = trainer.model
90 |
91 | model.eval()
92 | with torch.no_grad():
93 | x = np.random.normal(0, 1, 128)
94 | x = x.astype(np.float32)
95 | x = torch.from_numpy(x).to(trainer.device)
96 | rec = model.decoder(x).detach().cpu().numpy()
97 | rec = rec.reshape((128, 128))
98 | plt.figure()
99 | plt.imshow(rec, cmap="gray")
100 | plt.savefig("chestxray_vae_synthetic.png")
--------------------------------------------------------------------------------
/examples/drug_recommendation_mimic3_gamenet.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import GAMENet
4 | from pyhealth.tasks import drug_recommendation_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | # STEP 1: load data
8 | base_dataset = MIMIC3Dataset(
9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
12 | dev=True,
13 | refresh_cache=False,
14 | )
15 | base_dataset.stat()
16 |
17 | # STEP 2: set task
18 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
19 | sample_dataset.stat()
20 |
21 | train_dataset, val_dataset, test_dataset = split_by_patient(
22 | sample_dataset, [0.8, 0.1, 0.1]
23 | )
24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
27 |
28 | # STEP 3: define model
29 | model = GAMENet(
30 | sample_dataset,
31 | )
32 |
33 | # STEP 4: define trainer
34 | trainer = Trainer(
35 | model=model,
36 | metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"],
37 | )
38 |
39 | trainer.train(
40 | train_dataloader=train_dataloader,
41 | val_dataloader=val_dataloader,
42 | epochs=20,
43 | monitor="pr_auc_samples",
44 | )
45 |
46 | # STEP 5: evaluate
47 | print (trainer.evaluate(test_dataloader))
48 |
--------------------------------------------------------------------------------
/examples/drug_recommendation_mimic3_micron.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import MICRON
4 | from pyhealth.tasks import drug_recommendation_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | # STEP 1: load data
8 | base_dataset = MIMIC3Dataset(
9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
12 | dev=True,
13 | refresh_cache=False,
14 | )
15 | base_dataset.stat()
16 |
17 | # STEP 2: set task
18 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
19 | sample_dataset.stat()
20 |
21 | train_dataset, val_dataset, test_dataset = split_by_patient(
22 | sample_dataset, [0.8, 0.1, 0.1]
23 | )
24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
27 |
28 | # STEP 3: define model
29 | model = MICRON(
30 | sample_dataset,
31 | )
32 |
33 | # STEP 4: define trainer
34 | trainer = Trainer(model=model, metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"])
35 | trainer.train(
36 | train_dataloader=train_dataloader,
37 | val_dataloader=val_dataloader,
38 | epochs=5,
39 | monitor="pr_auc_samples",
40 | )
41 |
42 | # STEP 5: evaluate
43 | print (trainer.evaluate(test_dataloader))
44 |
--------------------------------------------------------------------------------
/examples/drug_recommendation_mimic3_molerec.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import MoleRec
4 | from pyhealth.tasks import drug_recommendation_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | # STEP 1: load data
8 | base_dataset = MIMIC3Dataset(
9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
12 | dev=True,
13 | refresh_cache=False,
14 | )
15 | base_dataset.stat()
16 |
17 | # STEP 2: set task
18 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
19 | sample_dataset.stat()
20 |
21 | train_dataset, val_dataset, test_dataset = split_by_patient(
22 | sample_dataset, [0.8, 0.1, 0.1]
23 | )
24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
27 |
28 | # STEP 3: define model
29 | model = MoleRec(
30 | sample_dataset,
31 | feature_keys=["conditions", "procedures"],
32 | label_key="drugs",
33 | mode="multilabel",
34 | )
35 |
36 | # STEP 4: define trainer
37 | trainer = Trainer(
38 | model=model,
39 | metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"],
40 | )
41 |
42 | trainer.train(
43 | train_dataloader=train_dataloader,
44 | val_dataloader=val_dataloader,
45 | epochs=3,
46 | monitor="pr_auc_samples",
47 | )
48 |
49 | # STEP 5: evaluate
50 | print (trainer.evaluate(test_dataloader))
51 |
--------------------------------------------------------------------------------
/examples/drug_recommendation_mimic3_safedrug.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import SafeDrug
4 | from pyhealth.tasks import drug_recommendation_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | # STEP 1: load data
8 | base_dataset = MIMIC3Dataset(
9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
12 | dev=True,
13 | refresh_cache=False,
14 | )
15 | base_dataset.stat()
16 |
17 | # STEP 2: set task
18 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
19 | sample_dataset.stat()
20 |
21 | train_dataset, val_dataset, test_dataset = split_by_patient(
22 | sample_dataset, [0.8, 0.1, 0.1]
23 | )
24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
27 |
28 | # STEP 3: define model
29 | model = SafeDrug(
30 | sample_dataset,
31 | )
32 |
33 | # STEP 4: define trainer
34 | trainer = Trainer(
35 | model=model,
36 | metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"],
37 | )
38 |
39 | trainer.train(
40 | train_dataloader=train_dataloader,
41 | val_dataloader=val_dataloader,
42 | epochs=25,
43 | monitor="pr_auc_samples",
44 | )
45 |
46 | # STEP 5: evaluate
47 | print (trainer.evaluate(test_dataloader))
48 |
--------------------------------------------------------------------------------
/examples/drug_recommendation_mimic3_transformer.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import Transformer
4 | from pyhealth.tasks import drug_recommendation_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | # STEP 1: load data
8 | base_dataset = MIMIC3Dataset(
9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})},
12 | dev=True,
13 | refresh_cache=True,
14 | )
15 |
16 | base_dataset.stat()
17 |
18 | # STEP 2: set task
19 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn)
20 | sample_dataset.stat()
21 |
22 | train_dataset, val_dataset, test_dataset = split_by_patient(
23 | sample_dataset, [0.8, 0.1, 0.1]
24 | )
25 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
26 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
27 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
28 |
29 | # STEP 3: define model
30 | model = Transformer(
31 | dataset=sample_dataset,
32 | feature_keys=["conditions", "procedures"],
33 | label_key="drugs",
34 | mode="multilabel",
35 | )
36 |
37 | # STEP 4: define trainer
38 | trainer = Trainer(
39 | model=model,
40 | metrics=["jaccard_samples", "f1_samples", "pr_auc_samples"],
41 | )
42 |
43 | trainer.train(
44 | train_dataloader=train_dataloader,
45 | val_dataloader=val_dataloader,
46 | epochs=20,
47 | monitor="pr_auc_samples",
48 | )
49 |
50 | # STEP 5: evaluate
51 | print (trainer.evaluate(test_dataloader))
52 |
--------------------------------------------------------------------------------
/examples/length_of_stay_mimic3_rnn.py:
--------------------------------------------------------------------------------
1 | from pyhealth.calib import calibration, predictionset
2 | from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient
3 | from pyhealth.models import Transformer
4 | from pyhealth.tasks import length_of_stay_prediction_mimic3_fn
5 | from pyhealth.trainer import Trainer, get_metrics_fn
6 |
7 | # STEP 1: load data
8 | base_dataset = MIMIC3Dataset(
9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
11 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
12 | dev=False,
13 | refresh_cache=True,
14 | )
15 | base_dataset.stat()
16 |
17 | # STEP 2: set task
18 | sample_dataset = base_dataset.set_task(length_of_stay_prediction_mimic3_fn)
19 | sample_dataset.stat()
20 |
21 | train_dataset, val_dataset, test_dataset = split_by_patient(
22 | sample_dataset, [0.8, 0.1, 0.1]
23 | )
24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
27 |
28 | # STEP 3: define model
29 | model = Transformer(
30 | dataset=sample_dataset,
31 | feature_keys=["conditions", "procedures", "drugs"],
32 | label_key="label",
33 | mode="multiclass",
34 | )
35 |
36 | # STEP 4: define trainer
37 | trainer = Trainer(model=model)
38 | trainer.train(
39 | train_dataloader=train_dataloader,
40 | val_dataloader=val_dataloader,
41 | epochs=50,
42 | monitor="accuracy",
43 | )
44 |
45 | # STEP 5: evaluate
46 | metrics = ['accuracy', 'f1_macro', 'f1_micro'] + ['ECE_adapt', 'cwECEt_adapt']
47 | y_true_all, y_prob_all = trainer.inference(test_dataloader)[:2]
48 | print(get_metrics_fn(model.mode)(y_true_all, y_prob_all, metrics=metrics))
49 |
50 | # STEP 6: calibrate the model
51 | cal_model = calibration.HistogramBinning(model, debug=True)
52 | cal_model.calibrate(cal_dataset=val_dataset)
53 | y_true_all, y_prob_all = Trainer(model=cal_model).inference(test_dataloader)[:2]
54 | print(get_metrics_fn(cal_model.mode)(y_true_all, y_prob_all, metrics=metrics))
55 |
56 |
57 | # STEP 7: Construct prediction set, controlling overall miscoverage rate (<0.1)
58 | # Note that if you use the calibrated model, the coverage rate cannot be controlled, because,
59 | # with respect to the calibrated model (which was trained on the calibration set), the
60 | # test set and calibration set are not i.i.d.
61 | ps_model = predictionset.LABEL(model, 0.1, debug=True)
62 | ps_model.calibrate(cal_dataset=val_dataset)
63 | y_true_all, y_prob_all, _, extra_output = Trainer(model=ps_model).inference(test_dataloader, additional_outputs=['y_predset'])
64 | print(get_metrics_fn(ps_model.mode)(y_true_all, y_prob_all,
65 | metrics=metrics + ['miscoverage_overall_ps', 'rejection_rate'],
66 | y_predset=extra_output['y_predset']))
67 |
--------------------------------------------------------------------------------
/examples/medcode.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode import CrossMap, InnerMap
2 |
3 | ndc = InnerMap.load("NDC")
4 | print("Looking up for NDC code 00597005801")
5 | print(ndc.lookup("00597005801"))
6 |
7 | codemap = CrossMap.load("NDC", "ATC")
8 | print("Mapping NDC code 00597005801 to ATC")
9 | print(codemap.map("00597005801"))
10 |
11 | atc = InnerMap.load("ATC")
12 | print("Looking up for ATC code G04CA02")
13 | print(atc.lookup("G04CA02"))
14 |
--------------------------------------------------------------------------------
/examples/mortality_mimic3_adacare.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import AdaCare
4 | from pyhealth.tasks import mortality_prediction_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | if __name__ == "__main__":
8 | # STEP 1: load data
9 | base_dataset = MIMIC3Dataset(
10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
13 | dev=False,
14 | refresh_cache=False,
15 | )
16 | base_dataset.stat()
17 |
18 | # STEP 2: set task
19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn)
20 | sample_dataset.stat()
21 |
22 | train_dataset, val_dataset, test_dataset = split_by_patient(
23 | sample_dataset, [0.8, 0.1, 0.1]
24 | )
25 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
26 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
27 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
28 |
29 | # STEP 3: define model
30 | model = AdaCare(
31 | dataset=sample_dataset,
32 | feature_keys=["conditions", "procedures"],
33 | label_key="label",
34 | mode="binary",
35 | use_embedding=[True, True, True],
36 | embedding_dim=32,
37 | hidden_dim=32,
38 | dropout=0.5,
39 | )
40 |
41 | # STEP 4: define trainer
42 | trainer = Trainer(model=model)
43 | trainer.train(
44 | train_dataloader=train_dataloader,
45 | val_dataloader=val_dataloader,
46 | epochs=50,
47 | monitor="roc_auc",
48 | )
49 |
50 | # STEP 5: evaluate
51 | print(trainer.evaluate(test_dataloader))
52 |
--------------------------------------------------------------------------------
/examples/mortality_mimic3_agent.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import Agent
4 | from pyhealth.tasks import mortality_prediction_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | if __name__ == "__main__":
8 | # STEP 1: load data
9 | base_dataset = MIMIC3Dataset(
10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
13 | dev=False,
14 | refresh_cache=False,
15 | )
16 | base_dataset.stat()
17 |
18 | # STEP 2: set task
19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn)
20 | sample_dataset.stat()
21 |
22 | train_dataset, val_dataset, test_dataset = split_by_patient(
23 | sample_dataset, [0.8, 0.1, 0.1]
24 | )
25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)
28 |
29 | # STEP 3: define model
30 | model = Agent(
31 | dataset=sample_dataset,
32 | feature_keys=["conditions", "procedures"],
33 | label_key="label",
34 | mode="binary",
35 | embedding_dim=32,
36 | hidden_dim=32,
37 | )
38 |
39 | # STEP 4: define trainer
40 | trainer = Trainer(model=model)
41 | trainer.train(
42 | train_dataloader=train_dataloader,
43 | val_dataloader=val_dataloader,
44 | epochs=50,
45 | monitor="roc_auc",
46 | )
47 |
48 | # STEP 5: evaluate
49 | print(trainer.evaluate(test_dataloader))
50 |
--------------------------------------------------------------------------------
/examples/mortality_mimic3_concare.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import ConCare
4 | from pyhealth.tasks import mortality_prediction_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | if __name__ == "__main__":
8 | # STEP 1: load data
9 | base_dataset = MIMIC3Dataset(
10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
13 | dev=False,
14 | refresh_cache=False,
15 | )
16 | base_dataset.stat()
17 |
18 | # STEP 2: set task
19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn)
20 | sample_dataset.stat()
21 |
22 | train_dataset, val_dataset, test_dataset = split_by_patient(
23 | sample_dataset, [0.8, 0.1, 0.1]
24 | )
25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)
28 |
29 | # STEP 3: define model
30 | model = ConCare(
31 | dataset=sample_dataset,
32 | feature_keys=["conditions", "procedures"],
33 | label_key="label",
34 | mode="binary",
35 | use_embedding=[True, True, True],
36 | hidden_dim=32,
37 | )
38 |
39 | # STEP 4: define trainer
40 | trainer = Trainer(model=model)
41 | trainer.train(
42 | train_dataloader=train_dataloader,
43 | val_dataloader=val_dataloader,
44 | epochs=3,
45 | monitor="roc_auc",
46 | )
47 |
48 | # STEP 5: evaluate
49 | print(trainer.evaluate(test_dataloader))
50 |
--------------------------------------------------------------------------------
/examples/mortality_mimic3_grasp.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import GRASP
4 | from pyhealth.tasks import mortality_prediction_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | if __name__ == "__main__":
8 | # STEP 1: load data
9 | base_dataset = MIMIC3Dataset(
10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
13 | dev=False,
14 | refresh_cache=False,
15 | )
16 | base_dataset.stat()
17 |
18 | # STEP 2: set task
19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn)
20 | sample_dataset.stat()
21 |
22 | train_dataset, val_dataset, test_dataset = split_by_patient(
23 | sample_dataset, [0.8, 0.1, 0.1]
24 | )
25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)
28 |
29 | # STEP 3: define model
30 | model = GRASP(
31 | dataset=sample_dataset,
32 | feature_keys=["conditions", "procedures"],
33 | label_key="label",
34 | mode="binary",
35 |         use_embedding=[True, True],
36 | embedding_dim=32,
37 | hidden_dim=32,
38 | )
39 |
40 | # STEP 4: define trainer
41 | trainer = Trainer(model=model)
42 | trainer.train(
43 | train_dataloader=train_dataloader,
44 | val_dataloader=val_dataloader,
45 | epochs=5,
46 | monitor="roc_auc",
47 | )
48 |
49 | # STEP 5: evaluate
50 | print(trainer.evaluate(test_dataloader))
51 |
--------------------------------------------------------------------------------
/examples/mortality_mimic3_rnn.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import RNN
4 | from pyhealth.tasks import mortality_prediction_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | # STEP 1: load data
8 | base_dataset = MIMIC3Dataset(
9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
11 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
12 | dev=False,
13 | refresh_cache=False,
14 | )
15 | base_dataset.stat()
16 |
17 | # STEP 2: set task
18 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn)
19 | sample_dataset.stat()
20 |
21 | train_dataset, val_dataset, test_dataset = split_by_patient(
22 | sample_dataset, [0.8, 0.1, 0.1]
23 | )
24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
27 |
28 | # STEP 3: define model
29 | model = RNN(
30 | dataset=sample_dataset,
31 | feature_keys=["conditions", "procedures", "drugs"],
32 | label_key="label",
33 | mode="binary",
34 | )
35 |
36 | # STEP 4: define trainer
37 | trainer = Trainer(model=model)
38 | trainer.train(
39 | train_dataloader=train_dataloader,
40 | val_dataloader=val_dataloader,
41 | epochs=50,
42 | monitor="roc_auc",
43 | )
44 |
45 | # STEP 5: evaluate
46 | print(trainer.evaluate(test_dataloader))
47 |
--------------------------------------------------------------------------------
/examples/mortality_mimic3_stagenet.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import StageNet
4 | from pyhealth.tasks import mortality_prediction_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | if __name__ == "__main__":
8 | # STEP 1: load data
9 | base_dataset = MIMIC3Dataset(
10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
13 | dev=False,
14 | refresh_cache=False,
15 | )
16 | base_dataset.stat()
17 |
18 | # STEP 2: set task
19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn)
20 | sample_dataset.stat()
21 |
22 | train_dataset, val_dataset, test_dataset = split_by_patient(
23 | sample_dataset, [0.8, 0.1, 0.1]
24 | )
25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)
28 |
29 | # STEP 3: define model
30 | model = StageNet(
31 | dataset=sample_dataset,
32 | feature_keys=["conditions", "procedures"],
33 | label_key="label",
34 | mode="binary",
35 | embedding_dim=32,
36 | )
37 |
38 | # STEP 4: define trainer
39 | trainer = Trainer(model=model)
40 | trainer.train(
41 | train_dataloader=train_dataloader,
42 | val_dataloader=val_dataloader,
43 | epochs=50,
44 | monitor="roc_auc",
45 | )
46 |
47 | # STEP 5: evaluate
48 | print(trainer.evaluate(test_dataloader))
49 |
--------------------------------------------------------------------------------
/examples/mortality_mimic3_tcn.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import TCN
4 | from pyhealth.tasks import mortality_prediction_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | if __name__ == "__main__":
8 | # STEP 1: load data
9 | base_dataset = MIMIC3Dataset(
10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
13 | dev=False,
14 | refresh_cache=False,
15 | )
16 | base_dataset.stat()
17 |
18 | # STEP 2: set task
19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn)
20 | sample_dataset.stat()
21 |
22 | train_dataset, val_dataset, test_dataset = split_by_patient(
23 | sample_dataset, [0.8, 0.1, 0.1]
24 | )
25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True)
26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False)
27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False)
28 |
29 | # STEP 3: define model
30 | model = TCN(
31 | dataset=sample_dataset,
32 | feature_keys=["conditions", "procedures"],
33 | label_key="label",
34 | mode="binary",
35 | embedding_dim=32,
36 | )
37 |
38 | # STEP 4: define trainer
39 | trainer = Trainer(model=model)
40 | trainer.train(
41 | train_dataloader=train_dataloader,
42 | val_dataloader=val_dataloader,
43 | epochs=50,
44 | monitor="roc_auc",
45 | )
46 |
47 | # STEP 5: evaluate
48 | print(trainer.evaluate(test_dataloader))
49 |
--------------------------------------------------------------------------------
/examples/readmission_mimic3_fairness.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.tasks import readmission_prediction_mimic3_fn
3 | from pyhealth.datasets import split_by_patient, get_dataloader
4 | from pyhealth.metrics import fairness_metrics_fn
5 | from pyhealth.models import Transformer
6 | from pyhealth.trainer import Trainer
7 | from pyhealth.metrics.fairness_utils.utils import sensitive_attributes_from_patient_ids
8 |
9 | # STEP 1: load data
10 | base_dataset = MIMIC3Dataset(
11 | root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
12 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
13 | )
14 | base_dataset.stat()
15 |
16 | # STEP 2: set task
17 | sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn)
18 | sample_dataset.stat()
19 |
20 | train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
21 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
22 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
23 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
24 |
25 | # STEP 3: define model
26 | model = Transformer(
27 | dataset=sample_dataset,
28 | # look up what are available for "feature_keys" and "label_keys" in dataset.samples[0]
29 | feature_keys=["conditions", "procedures"],
30 | label_key="label",
31 | mode="binary",
32 | )
33 |
34 | # STEP 4: define trainer
35 | trainer = Trainer(model=model)
36 | trainer.train(
37 | train_dataloader=train_dataloader,
38 | val_dataloader=val_dataloader,
39 | epochs=3,
40 | monitor="pr_auc",
41 | )
42 |
43 | # STEP 5: inference, return patient_ids
44 | y_true, y_prob, loss, patient_ids = trainer.inference(test_dataloader, return_patient_ids=True)
45 |
46 | # STEP 6: get sensitive attribute array from patient_ids
47 | sensitive_attribute_array = sensitive_attributes_from_patient_ids(base_dataset, patient_ids,
48 | 'gender', 'F')
49 |
50 | # STEP 7: use pyhealth.metrics to evaluate fairness
51 | fairness_metrics = fairness_metrics_fn(y_true, y_prob, sensitive_attribute_array,
52 | favorable_outcome=0)
53 | print(fairness_metrics)
--------------------------------------------------------------------------------
/examples/readmission_mimic3_rnn.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import MIMIC3Dataset
2 | from pyhealth.datasets import split_by_patient, get_dataloader
3 | from pyhealth.models import RNN
4 | from pyhealth.tasks import readmission_prediction_mimic3_fn
5 | from pyhealth.trainer import Trainer
6 |
7 | # STEP 1: load data
8 | base_dataset = MIMIC3Dataset(
9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4",
10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
11 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"},
12 | dev=False,
13 | refresh_cache=True,
14 | )
15 | base_dataset.stat()
16 |
17 | # STEP 2: set task
18 | sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn)
19 | sample_dataset.stat()
20 |
21 | train_dataset, val_dataset, test_dataset = split_by_patient(
22 | sample_dataset, [0.8, 0.1, 0.1]
23 | )
24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
27 |
28 | # STEP 3: define model
29 | model = RNN(
30 | dataset=sample_dataset,
31 | feature_keys=["conditions", "procedures", "drugs"],
32 | label_key="label",
33 | mode="binary",
34 | )
35 |
36 | # STEP 4: define trainer
37 | trainer = Trainer(model=model)
38 | trainer.train(
39 | train_dataloader=train_dataloader,
40 | val_dataloader=val_dataloader,
41 | epochs=50,
42 | monitor="roc_auc",
43 | )
44 |
45 | # STEP 5: evaluate
46 | print(trainer.evaluate(test_dataloader))
47 |
--------------------------------------------------------------------------------
/examples/sleep_staging_ISRUC_SparcNet.py:
--------------------------------------------------------------------------------
1 | from pyhealth.calib import calibration
2 | from pyhealth.datasets import ISRUCDataset, get_dataloader, split_by_patient
3 | from pyhealth.models import ContraWR, SparcNet
4 | from pyhealth.tasks import sleep_staging_isruc_fn
5 | from pyhealth.trainer import Trainer, get_metrics_fn
6 |
7 | # step 1: load signal data
8 | dataset = ISRUCDataset(
9 | root="/srv/local/data/trash/",
10 | dev=True,
11 | refresh_cache=False,
12 | # download=True,
13 | )
14 |
15 | print(dataset.stat())
16 |
17 | # step 2: set task
18 | sleep_staging_ds = dataset.set_task(sleep_staging_isruc_fn)
19 | sleep_staging_ds.stat()
20 | print(sleep_staging_ds.samples[0])
21 |
22 | # split dataset
23 | train_dataset, val_dataset, test_dataset = split_by_patient(
24 | sleep_staging_ds, [0.34, 0.33, 0.33]
25 | )
26 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
27 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
28 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
29 | print(
30 | "loader size: train/val/test",
31 | len(train_dataset),
32 | len(val_dataset),
33 | len(test_dataset),
34 | )
35 |
36 | # STEP 3: define model
37 | model = SparcNet(
38 | dataset=sleep_staging_ds,
39 | feature_keys=["signal"],
40 | label_key="label",
41 | mode="multiclass",
42 | )
43 |
44 | # STEP 4: define trainer
45 | trainer = Trainer(model=model)
46 | trainer.train(
47 | train_dataloader=train_dataloader,
48 | val_dataloader=val_dataloader,
49 | epochs=5,
50 | monitor="accuracy",
51 | )
52 |
53 | # STEP 5: evaluate
54 | metrics = ['accuracy', 'f1_macro', 'f1_micro'] + ['cwECEt_adapt']  # cwECEt_adapt is a calibration metric (class-wise ECE)
55 | y_true_all, y_prob_all = trainer.inference(test_dataloader)[:2]
56 | print(get_metrics_fn(model.mode)(y_true_all, y_prob_all, metrics=metrics))
57 |
58 | # STEP 6: calibrate the model
59 | cal_model = calibration.KCal(model, debug=True, dim=32)
60 | cal_model.calibrate(
61 | cal_dataset=val_dataset,
62 | # Uncomment the following line if you want to re-train the embeddings
63 | # train_dataset=train_dataset,
64 | )
65 | y_true_all, y_prob_all = Trainer(model=cal_model).inference(test_dataloader)[:2]
66 | print(get_metrics_fn(cal_model.mode)(y_true_all, y_prob_all, metrics=metrics))
--------------------------------------------------------------------------------
/examples/sleep_staging_shhs_contrawr.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import split_by_patient, get_dataloader
2 | from pyhealth.trainer import Trainer
3 | from pyhealth.datasets import SHHSDataset
4 | from pyhealth.tasks import sleep_staging_shhs_fn
5 | from pyhealth.models import ContraWR, SparcNet
6 |
7 | # step 1: load signal data
8 | dataset = SHHSDataset(
9 | root="/srv/local/data/SHHS/polysomnography",
10 | dev=True,
11 | refresh_cache=False,
12 | )
13 |
14 | # step 2: set task
15 | sleep_staging_ds = dataset.set_task(sleep_staging_shhs_fn)
16 | sleep_staging_ds.stat()
17 |
18 | # split dataset
19 | train_dataset, val_dataset, test_dataset = split_by_patient(
20 | sleep_staging_ds, [0.6, 0.2, 0.2]
21 | )
22 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
23 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
24 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
25 | print(
26 | "loader size: train/val/test",
27 | len(train_dataset),
28 | len(val_dataset),
29 | len(test_dataset),
30 | )
31 |
32 | # STEP 3: define model
33 | model = SparcNet(
34 | dataset=sleep_staging_ds,
35 | feature_keys=["signal"],
36 | label_key="label",
37 | mode="multiclass",
38 | )
39 |
40 | # STEP 4: define trainer
41 | trainer = Trainer(model=model)
42 | trainer.train(
43 | train_dataloader=train_dataloader,
44 | val_dataloader=val_dataloader,
45 | epochs=5,
46 | monitor="accuracy",
47 | )
48 |
49 | # STEP 5: evaluate
50 | print(trainer.evaluate(test_dataloader))
51 |
--------------------------------------------------------------------------------
/examples/sleep_staging_sleepEDF_contrawr.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import split_by_patient, get_dataloader
2 | from pyhealth.trainer import Trainer
3 | from pyhealth.datasets import SleepEDFDataset
4 | from pyhealth.tasks import sleep_staging_sleepedf_fn
5 | from pyhealth.models import ContraWR, SparcNet
6 |
7 | # step 1: load signal data
8 | dataset = SleepEDFDataset(
9 | root="/srv/local/data/SLEEPEDF/sleep-edf-database-expanded-1.0.0/sleep-cassette",
10 | dev=True,
11 | refresh_cache=False,
12 | )
13 |
14 | # step 2: set task
15 | sleep_staging_ds = dataset.set_task(sleep_staging_sleepedf_fn)
16 | sleep_staging_ds.stat()
17 |
18 | # split dataset
19 | train_dataset, val_dataset, test_dataset = split_by_patient(
20 | sleep_staging_ds, [0.6, 0.2, 0.2]
21 | )
22 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
23 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
24 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)
25 | print(
26 | "loader size: train/val/test",
27 | len(train_dataset),
28 | len(val_dataset),
29 | len(test_dataset),
30 | )
31 |
32 | # STEP 3: define model
33 | model = SparcNet(
34 | dataset=sleep_staging_ds,
35 | feature_keys=["signal"],
36 | label_key="label",
37 | mode="multiclass",
38 | )
39 |
40 | # STEP 4: define trainer
41 | trainer = Trainer(model=model)
42 | trainer.train(
43 | train_dataloader=train_dataloader,
44 | val_dataloader=val_dataloader,
45 | epochs=5,
46 | monitor="accuracy",
47 | )
48 |
49 | # STEP 5: evaluate
50 | print(trainer.evaluate(test_dataloader))
51 |
--------------------------------------------------------------------------------
/examples/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # Sanity check: reducing a boolean mask via `sum(...) != 0` should agree with
4 | # `torch.any(...)` along the same dim; each line prints tensor(False) if so.
5 | for i in range(100):
6 |     a = torch.randn((5, 5, 5)) > 0
7 |     print(((torch.sum(a, dim=2) != 0) != (torch.any(a != 0, dim=2))).any())
8 |
--------------------------------------------------------------------------------
/examples/test_dataloader.sh:
--------------------------------------------------------------------------------
1 | python pyhealth/datasets/mimic3.py
2 | python pyhealth/datasets/eicu.py
3 | python pyhealth/datasets/omop.py
4 | python pyhealth/datasets/mimic4.py
--------------------------------------------------------------------------------
/examples/test_eICU_addition.py:
--------------------------------------------------------------------------------
1 | from pyhealth.datasets import eICUDataset
2 | from pyhealth.tasks import mortality_prediction_eicu_fn, mortality_prediction_eicu_fn2
3 |
4 | base_dataset = eICUDataset(
5 | root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
6 | tables=["diagnosis", "admissionDx", "treatment"],
7 | dev=False,
8 | refresh_cache=False,
9 | )
10 | sample_dataset = base_dataset.set_task(task_fn=mortality_prediction_eicu_fn2)
11 | sample_dataset.stat()
12 | print(sample_dataset.available_keys)
13 |
14 | # base_dataset = eICUDataset(
15 | # root="/srv/local/data/physionet.org/files/eicu-crd/2.0",
16 | # tables=["diagnosis", "admissionDx", "treatment"],
17 | # dev=True,
18 | # refresh_cache=False,
19 | # )
20 | # sample_dataset = base_dataset.set_task(task_fn=mortality_prediction_eicu_fn2)
21 | # sample_dataset.stat()
22 | # print(sample_dataset.available_keys)
23 |
--------------------------------------------------------------------------------
/figure/five-stage-pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/figure/five-stage-pipeline.png
--------------------------------------------------------------------------------
/figure/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/figure/overview.png
--------------------------------------------------------------------------------
/figure/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/figure/poster.png
--------------------------------------------------------------------------------
/figure/structured-dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/figure/structured-dataset.png
--------------------------------------------------------------------------------
/leaderboard/__init__.py:
--------------------------------------------------------------------------------
1 | from .leaderboard_gen import leaderboard_generation, plots_generation, construct_args
2 | from .utils import *
3 |
--------------------------------------------------------------------------------
/leaderboard/ext_plot.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | import argparse
3 | import os
4 |
5 | # bokeh helpers used below (they may also be re-exported by `utils`)
6 | from bokeh.io import output_file, save
7 | from bokeh.layouts import column
8 |
9 | # environment variable values must be strings, not booleans
10 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
11 |
12 | credential_file = './credentials.json'
13 |
14 | def plot_html(args):
15 |     dfs = read_dataframes_by_time_from_gcp(args.credentials)
16 |     output_file(filename='leaderboard_plotting.html', title='PyHealth Leaderboard')
17 |
18 |     bokeh_figures = []
19 |
20 |     for task in args.tasks:
21 |         df = get_typed_df_with_time(dfs, task)
22 |         bokeh_figure = generate_bokeh_figure(df)
23 |         bokeh_figures.append(bokeh_figure)
24 |
25 |     save(column(bokeh_figures))
26 |
27 |
28 | def construct_args():
29 |     parser = argparse.ArgumentParser()
30 |     parser.add_argument("--credentials", type=str, default=credential_file)
31 |     # nargs="+" accepts multiple task names; type=list would split a string into characters
32 |     parser.add_argument("--tasks", type=str, nargs="+",
33 |                         default=[
34 |                             "drugrec",
35 |                             "lenOfStay",
36 |                             "mortality",
37 |                             "readmission"
38 |                         ])
39 |     return parser.parse_args()
40 |
41 |
42 | def main():
43 |     args = construct_args()
44 |     plot_html(args)
45 |
46 |
47 | if __name__ == '__main__':
48 |     main()
--------------------------------------------------------------------------------
/leaderboard/rtd_build_trigger.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import requests
4 |
5 | def build(args):
6 |     version = args.version
7 |     # NOTE: hardcoded override of --version, kept from the original script
8 |     version = 'leaderboard-and-unitest-update'
9 |     URL = f'https://readthedocs.org/api/v3/projects/pyhealth/versions/{version}/builds/'
10 |     with open(args.rtd_token) as f:
11 |         # strip the trailing newline so the Authorization header is well-formed
12 |         TOKEN = f.readlines()[0].strip()
13 |     HEADERS = {'Authorization': f'token {TOKEN}'}
14 |     response = requests.post(URL, headers=HEADERS)
15 |     print(response.json())
16 |
17 |
18 | def construct_args():
19 |     parser = argparse.ArgumentParser()
20 |     parser.add_argument("--rtd_token", type=str, default='./rtd_token.txt')
21 |     parser.add_argument("--version", type=str, default='latest')
22 |     return parser.parse_args()
23 |
24 |
25 | def main():
26 |     args = construct_args()
27 |     build(args)
28 |
29 |
30 | if __name__ == '__main__':
31 |     main()
--------------------------------------------------------------------------------
/makefile:
--------------------------------------------------------------------------------
1 | #@meta {author: "Paul Landes"}
2 | #@meta {desc: "PyHealth build automation", date: "2025-05-22"}
3 |
4 |
5 | ## Build
6 | #
7 | # directory with the unit tests
8 | PY_TEST_DIR ?= tests
9 | # test file glob pattern
10 | PY_TEST_GLOB ?= test_metrics.py
11 |
12 |
13 | ## Targets
14 | #
15 | # install dependencies
16 | .PHONY: deps
17 | deps:
18 | pip install -r requirements-nlp.txt
19 |
20 | # run the unit test cases
21 | .PHONY: test
22 | test:
23 | @echo "Running tests in $(PY_TEST_DIR)/$(PY_TEST_GLOB)"
24 | python -m unittest discover \
25 | -s $(PY_TEST_DIR) -p '$(PY_TEST_GLOB)' -v
26 |
27 | # clean derived objects
28 | .PHONY: clean
29 | clean:
30 | @echo "removing __pycache__"
31 | @find . -type d -name __pycache__ -prune -exec rm -r {} \;
32 |
--------------------------------------------------------------------------------
/pyhealth/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 | import sys
5 |
6 | __version__ = "1.1.4"
7 |
8 | # package-level cache path
9 | BASE_CACHE_PATH = os.path.join(str(Path.home()), ".cache/pyhealth/")
10 | # BASE_CACHE_PATH = "/srv/local/data/pyhealth-cache"
11 | if not os.path.exists(BASE_CACHE_PATH):
12 | os.makedirs(BASE_CACHE_PATH)
13 |
14 | # logging
15 | logger = logging.getLogger(__name__)
16 | logger.setLevel(logging.INFO)
17 | handler = logging.StreamHandler(sys.stdout)
18 | formatter = logging.Formatter("%(message)s")
19 | handler.setFormatter(formatter)
20 | logger.addHandler(handler)
21 |
--------------------------------------------------------------------------------
/pyhealth/calib/README.md:
--------------------------------------------------------------------------------
1 | This module contains several post-hoc uncertainty quantification methods.
2 |
3 | # Model Calibration:
4 |
5 | ### KCal
6 |
7 | Used in: `multiclass`.
8 |
9 | The model needs to accept `embed=True` in `forward` and include the penultimate embedding in its output.
10 |
11 | Lin, Zhen, Shubhendu Trivedi, and Jimeng Sun.
12 | "Taking a Step Back with KCal: Multi-Class Kernel-Based Calibration for Deep Neural Networks."
13 | ICLR 2023.
14 |
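A minimal usage sketch (mirroring `/examples/sleep_staging_ISRUC_SparcNet.py` in this repo; `model`, `val_dataset`, and `test_dataloader` are assumed to come from a standard PyHealth pipeline):

```python
from pyhealth.calib.calibration import KCal
from pyhealth.trainer import Trainer

cal_model = KCal(model, dim=32)               # wrap a trained multiclass model
cal_model.calibrate(cal_dataset=val_dataset)  # fit the calibrator on held-out data
y_true, y_prob = Trainer(model=cal_model).inference(test_dataloader)[:2]
```
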
15 | ### Temperature Scaling
16 |
17 | Used in: `multiclass`, `multilabel` and `binary`.
18 |
19 | Guo, Chuan, Geoff Pleiss, Yu Sun, and Kilian Q. Weinberger.
20 | "On calibration of modern neural networks."
21 | ICML 2017.
22 |
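The usage pattern is the same as above; a hedged sketch (constructor arguments beyond `model` are assumptions, see `pyhealth.calib.calibration.TemperatureScaling` for the exact signature):

```python
from pyhealth.calib.calibration import TemperatureScaling

cal_model = TemperatureScaling(model)
cal_model.calibrate(cal_dataset=val_dataset)
```
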
23 | ### Histogram Binning
24 |
25 | Used in: `multiclass`, `multilabel` and `binary`.
26 |
27 |
28 | Zadrozny, Bianca, and Charles Elkan.
29 | "Learning and making decisions when costs and probabilities are both unknown."
30 | In Proceedings of the seventh ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 204-213. 2001.
31 |
32 | Gupta, Chirag, and Aaditya Ramdas.
33 | "Top-label calibration and multiclass-to-binary reductions."
34 | ICLR 2022.
35 |
36 | # Prediction Set:
37 |
38 | ### SCRIB
39 |
40 | Used in: `multiclass`.
41 |
42 | Lin, Zhen, Lucas Glass, M. Brandon Westover, Cao Xiao, and Jimeng Sun.
43 | "SCRIB: Set-classifier with Class-specific Risk Bounds for Blackbox Models."
44 | AAAI 2022.
45 |
46 | ### LABEL
47 |
48 | Used in: `multiclass`.
49 |
50 | Sadinle, Mauricio, Jing Lei, and Larry Wasserman.
51 | "Least ambiguous set-valued classifiers with bounded error levels."
52 | Journal of the American Statistical Association 114, no. 525 (2019): 223-234.
--------------------------------------------------------------------------------
/pyhealth/calib/__init__.py:
--------------------------------------------------------------------------------
1 | from pyhealth.calib import calibration, predictionset
2 |
--------------------------------------------------------------------------------
/pyhealth/calib/base_classes.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Dict
3 |
4 | import torch
5 |
6 |
7 | class PostHocCalibrator(ABC, torch.nn.Module):
8 | def __init__(self, model, **kwargs) -> None:
9 | super().__init__()
10 | self.model = model
11 |
12 | def calibrate(self, cal_dataset):
13 | ...
14 |
15 | def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
16 | ...
17 |
18 |
19 | def to(self, device):
20 | super().to(device)
21 | self.device = device
22 | return self
23 |
24 |
25 | class SetPredictor(ABC, torch.nn.Module):
26 | def __init__(self, model, **kwargs) -> None:
27 | super().__init__()
28 | self.model = model
29 |
30 | def calibrate(self, cal_dataset):
31 | ...
32 |
33 | def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
34 | ...
35 |
36 | def to(self, device):
37 | super().to(device)
38 | self.device = device
39 | return self
--------------------------------------------------------------------------------
/pyhealth/calib/calibration/__init__.py:
--------------------------------------------------------------------------------
1 | """Model calibration methods"""
2 | from pyhealth.calib.calibration.dircal import DirichletCalibration
3 | from pyhealth.calib.calibration.hb import HistogramBinning
4 | from pyhealth.calib.calibration.kcal import KCal
5 | from pyhealth.calib.calibration.temperature_scale import TemperatureScaling
6 |
7 | __all__ = ['DirichletCalibration', 'HistogramBinning', 'KCal', 'TemperatureScaling']
8 |
--------------------------------------------------------------------------------
/pyhealth/calib/predictionset/__init__.py:
--------------------------------------------------------------------------------
1 | """Prediction set construction methods"""
2 | from pyhealth.calib.predictionset.favmac import FavMac
3 | from pyhealth.calib.predictionset.label import LABEL
4 | from pyhealth.calib.predictionset.scrib import SCRIB
5 |
6 | __all__ = ['LABEL', 'SCRIB', 'FavMac']
7 |
--------------------------------------------------------------------------------
/pyhealth/calib/utils.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Optional
3 |
4 | import numpy as np
5 | import torch
6 | import tqdm
7 | from torch import Tensor
8 |
9 | from pyhealth.datasets import utils as datautils
10 |
11 |
12 | def agg_loss(loss:torch.Tensor, reduction: str):
13 | if reduction == 'mean':
14 | return loss.mean()
15 | if reduction == 'sum':
16 | return loss.sum()
17 | return loss
18 |
19 |
20 | def one_hot_np(labels, K):
21 | new_labels = np.zeros((len(labels), K))
22 | new_labels[np.arange(len(labels)), labels] = 1
23 | return new_labels
24 |
25 | class LogLoss(torch.nn.Module):
26 | """Cross entropy, but takes in the probability instead of the logits"""
27 | reduction: str
28 | def __init__(self, weight: Optional[Tensor] = None, ignore_index: int = -100, reduction: str = 'mean', clip=1e-10) -> None:
29 | super(LogLoss, self).__init__()
30 | self.register_buffer('weight', weight)
31 | self.ignore_index = ignore_index
32 | self.reduction = reduction
33 | self.clip = clip
34 |
35 | def forward(self, input: Tensor, target: Tensor) -> Tensor:
36 | assert self.weight is None or isinstance(self.weight, Tensor)
37 | dim = input.dim()
38 | assert dim == 2, f"Expected 2 dimensions (got {dim})"
39 |         input = input.clip(self.clip)  # avoid log(0); the weight is assumed near-trivial, so it is not normalized
40 | input = -torch.log(input)
41 | if self.weight is not None:
42 | input = input * self.weight.unsqueeze(0)
43 | loss = torch.gather(input, -1, target.unsqueeze(-1)).squeeze(-1)
44 | return agg_loss(loss, self.reduction)
45 |
46 |
47 | def prepare_numpy_dataset(model, dataset, keys, forward_kwargs=None,
48 | incl_data_keys=None, debug=False, batch_size=32):
49 | if forward_kwargs is None:
50 | forward_kwargs = {}
51 | if incl_data_keys is None:
52 | incl_data_keys = []
53 | loader = datautils.get_dataloader(dataset, batch_size, shuffle=False)
54 |
55 | ret = defaultdict(list)
56 | with torch.no_grad():
57 | for _i, data in tqdm.tqdm(enumerate(loader), desc=f"retrieving {keys}", total=len(loader)):
58 | if debug and _i % 10 != 0:
59 | continue
60 | data.update(forward_kwargs)
61 | res = model(**data)
62 | for key in keys:
63 | ret[key].append(res[key].detach().cpu().numpy())
64 | for key in incl_data_keys:
65 | ret[key].extend(data[key])
66 | for key in incl_data_keys:
67 | ret[key] = np.asarray(ret[key])
68 | for key in keys:
69 | ret[key] = np.concatenate(ret[key])
70 | return ret
71 |
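A quick sanity check of `LogLoss` (a sketch, not part of the module): on a probability matrix it should agree with `torch.nn.NLLLoss` applied to log-probabilities, since the clipping is a no-op away from zero.

    import torch
    probs = torch.tensor([[0.7, 0.3], [0.2, 0.8]])
    target = torch.tensor([0, 1])
    assert torch.isclose(LogLoss()(probs, target),
                         torch.nn.NLLLoss()(torch.log(probs), target))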
--------------------------------------------------------------------------------
/pyhealth/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import Event, Patient
2 |
3 |
4 | class Visit:
5 | """This class is deprecated and should not be used."""
6 | def __init__(self, *args, **kwargs):
7 | import warnings
8 | warnings.warn("The Visit class is deprecated and will be removed in a future version.", DeprecationWarning)
--------------------------------------------------------------------------------
/pyhealth/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | class BaseEHRDataset:
2 | """This class is deprecated and should not be used."""
3 | def __init__(self, *args, **kwargs):
4 | import warnings
5 | warnings.warn("The BaseEHRDataset class is deprecated and will be removed in a future version.", DeprecationWarning)
6 |
7 | class BaseSignalDataset:
8 | """This class is deprecated and should not be used."""
9 | def __init__(self, *args, **kwargs):
10 | import warnings
11 | warnings.warn("The BaseSignalDataset class is deprecated and will be removed in a future version.", DeprecationWarning)
12 |
13 |
14 | class SampleEHRDataset:
15 | """This class is deprecated and should not be used."""
16 | def __init__(self, *args, **kwargs):
17 | import warnings
18 | warnings.warn("The SampleEHRDataset class is deprecated and will be removed in a future version.", DeprecationWarning)
19 |
20 |
21 | class SampleSignalDataset:
22 | """This class is deprecated and should not be used."""
23 | def __init__(self, *args, **kwargs):
24 | import warnings
25 | warnings.warn("The SampleSignalDataset class is deprecated and will be removed in a future version.", DeprecationWarning)
26 |
27 |
28 | from .base_dataset import BaseDataset
29 | from .cardiology import CardiologyDataset
30 | from .covid19_cxr import COVID19CXRDataset
31 | from .ehrshot import EHRShotDataset
32 | from .eicu import eICUDataset
33 | from .isruc import ISRUCDataset
34 | from .medical_transcriptions import MedicalTranscriptionsDataset
35 | from .mimic3 import MIMIC3Dataset
36 | from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset
37 | from .mimicextract import MIMICExtractDataset
38 | from .omop import OMOPDataset
39 | from .sample_dataset import SampleDataset
40 | from .shhs import SHHSDataset
41 | from .sleepedf import SleepEDFDataset
42 | from .splitter import split_by_patient, split_by_sample, split_by_visit
43 | from .tuab import TUABDataset
44 | from .tuev import TUEVDataset
45 | from .utils import collate_fn_dict, collate_fn_dict_with_padding, get_dataloader
46 |
--------------------------------------------------------------------------------
/pyhealth/datasets/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import DatasetConfig, load_yaml_config
2 |
--------------------------------------------------------------------------------
/pyhealth/datasets/configs/covid19_cxr.yaml:
--------------------------------------------------------------------------------
1 | version: "5.0"
2 | tables:
3 | covid19_cxr:
4 | file_path: "covid19_cxr-metadata-pyhealth.csv"
5 | patient_id: null
6 | timestamp: null
7 | attributes:
8 | - "path"
9 | - "url"
10 | - "label"
11 |
--------------------------------------------------------------------------------
/pyhealth/datasets/configs/medical_transcriptions.yaml:
--------------------------------------------------------------------------------
1 | version: "1.0"
2 | tables:
3 | mtsamples:
4 | file_path: "mtsamples.csv"
5 | patient_id: null
6 | timestamp: null
7 | attributes:
8 | - "description"
9 | - "medical_specialty"
10 | - "sample_name"
11 | - "transcription"
12 | - "keywords"
--------------------------------------------------------------------------------
/pyhealth/datasets/configs/mimic4_ehr.yaml:
--------------------------------------------------------------------------------
1 | version: "2.2"
2 | tables:
3 | patients:
4 | file_path: "hosp/patients.csv.gz"
5 | patient_id: "subject_id"
6 | timestamp: null
7 | attributes:
8 | - "gender"
9 | - "anchor_age"
10 | - "anchor_year"
11 | - "anchor_year_group"
12 | - "dod"
13 |
14 | admissions:
15 | file_path: "hosp/admissions.csv.gz"
16 | patient_id: "subject_id"
17 | timestamp: "admittime"
18 | attributes:
19 | - "hadm_id"
20 | - "admission_type"
21 | - "admission_location"
22 | - "insurance"
23 | - "language"
24 | - "marital_status"
25 | - "race"
26 | - "discharge_location"
27 | - "dischtime"
28 | - "hospital_expire_flag"
29 |
30 | icustays:
31 | file_path: "icu/icustays.csv.gz"
32 | patient_id: "subject_id"
33 | timestamp: "intime"
34 | attributes:
35 | - "stay_id"
36 | - "first_careunit"
37 | - "last_careunit"
38 | - "outtime"
39 |
40 | diagnoses_icd:
41 | file_path: "hosp/diagnoses_icd.csv.gz"
42 | patient_id: "subject_id"
43 | join:
44 | - file_path: "hosp/admissions.csv.gz"
45 | "on": "hadm_id"
46 | how: "inner"
47 | columns:
48 | - "dischtime"
49 | timestamp: "dischtime"
50 | attributes:
51 | - "icd_code"
52 | - "icd_version"
53 | - "seq_num"
54 |
55 | procedures_icd:
56 | file_path: "hosp/procedures_icd.csv.gz"
57 | patient_id: "subject_id"
58 | join:
59 | - file_path: "hosp/admissions.csv.gz"
60 | "on": "hadm_id"
61 | how: "inner"
62 | columns:
63 | - "dischtime"
64 | timestamp: "dischtime"
65 | attributes:
66 | - "icd_code"
67 | - "icd_version"
68 | - "seq_num"
69 |
70 | prescriptions:
71 | file_path: "hosp/prescriptions.csv.gz"
72 | patient_id: "subject_id"
73 | timestamp: "starttime"
74 | attributes:
75 | - "drug"
76 | - "ndc"
77 | - "prod_strength"
78 | - "dose_val_rx"
79 | - "dose_unit_rx"
80 | - "route"
81 | - "stoptime"
82 |
83 | labevents:
84 | file_path: "hosp/labevents.csv.gz"
85 | patient_id: "subject_id"
86 | join:
87 | - file_path: "hosp/d_labitems.csv.gz"
88 | "on": "itemid"
89 | how: "inner"
90 | columns:
91 | - "label"
92 | - "fluid"
93 | - "category"
94 | timestamp: "charttime"
95 | attributes:
96 | - "itemid"
97 | - "label"
98 | - "fluid"
99 | - "category"
100 | - "value"
101 | - "valuenum"
102 | - "valueuom"
103 | - "flag"
104 | - "storetime"
105 |
106 | hcpcsevents:
107 | file_path: "hosp/hcpcsevents.csv.gz"
108 | patient_id: "subject_id"
109 | timestamp: "chartdate"
110 | attributes:
111 | - "hcpcs_cd"
112 | - "seq_num"
113 | - "short_description"
114 |
--------------------------------------------------------------------------------
/pyhealth/datasets/configs/mimic4_note.yaml:
--------------------------------------------------------------------------------
1 | version: "2.2"
2 | tables:
3 | discharge:
4 | file_path: "note/discharge.csv.gz"
5 | patient_id: "subject_id"
6 | timestamp: "charttime"
7 | attributes:
8 | - "note_id"
9 | - "hadm_id"
10 | - "note_type"
11 | - "note_seq"
12 | - "storetime"
13 | - "text"
14 |
15 | discharge_detail:
16 | file_path: "note/discharge_detail.csv.gz"
17 | patient_id: "subject_id"
18 | join:
19 | - file_path: "note/discharge.csv.gz"
20 | "on": "note_id"
21 | how: "inner"
22 | columns:
23 | - "charttime"
24 | timestamp: "charttime"
25 | attributes:
26 | - "note_id"
27 | - "field_name"
28 | - "field_value"
29 | - "field_ordinal"
30 |
31 | radiology:
32 | file_path: "note/radiology.csv.gz"
33 | patient_id: "subject_id"
34 | timestamp: "charttime"
35 | attributes:
36 | - "note_id"
37 | - "hadm_id"
38 | - "note_type"
39 | - "note_seq"
40 | - "storetime"
41 | - "text"
42 |
43 | radiology_detail:
44 | file_path: "note/radiology_detail.csv.gz"
45 | patient_id: "subject_id"
46 | join:
47 | - file_path: "note/radiology.csv.gz"
48 | "on": "note_id"
49 | how: "inner"
50 | columns:
51 | - "charttime"
52 | timestamp: "charttime"
53 | attributes:
54 | - "note_id"
55 | - "field_name"
56 | - "field_value"
57 | - "field_ordinal"
58 |
--------------------------------------------------------------------------------
/pyhealth/datasets/ehrshot.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | from typing import List, Optional
4 |
5 | from .base_dataset import BaseDataset
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 |
10 | class EHRShotDataset(BaseDataset):
11 | """
12 | A dataset class for handling EHRShot data.
13 |
14 | This class is responsible for loading and managing the EHRShot dataset.
15 |
16 | Website: https://som-shahlab.github.io/ehrshot-website/
17 |
18 | Attributes:
19 | root (str): The root directory where the dataset is stored.
20 | tables (List[str]): A list of tables to be included in the dataset.
21 | dataset_name (Optional[str]): The name of the dataset.
22 | config_path (Optional[str]): The path to the configuration file.
23 | """
24 |
25 | def __init__(
26 | self,
27 | root: str,
28 | tables: List[str],
29 | dataset_name: Optional[str] = None,
30 | config_path: Optional[str] = None,
31 | **kwargs
32 | ) -> None:
33 | if config_path is None:
34 | logger.info("No config path provided, using default config")
35 | config_path = Path(__file__).parent / "configs" / "ehrshot.yaml"
36 | super().__init__(
37 | root=root,
38 | tables=tables,
39 | dataset_name=dataset_name or "ehrshot",
40 | config_path=config_path,
41 | **kwargs
42 | )
43 | return
44 |
--------------------------------------------------------------------------------
/pyhealth/datasets/featurizers/signal.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.signal import resample, butter, filtfilt
3 |
4 | class BiosignalFeaturizer:
5 |
6 | def __init__(self, target_length=256, lowcut=0.5, highcut=40.0, sampling_rate=100):
7 | self.target_length = target_length
8 | self.lowcut = lowcut
9 | self.highcut = highcut
10 | self.sampling_rate = sampling_rate
11 |
12 | def _bandpass_filter(self, signal):
13 | """Apply a bandpass filter to the signal."""
14 | nyquist = 0.5 * self.sampling_rate
15 | low = self.lowcut / nyquist
16 | high = self.highcut / nyquist
17 | b, a = butter(1, [low, high], btype="band")
18 | return filtfilt(b, a, signal)
19 |
20 | def encode(self, signal):
21 | """
22 | Encodes the biosignal into a fixed-length feature representation.
23 |
24 | Args:
25 | signal (np.array): 1D array of raw biosignal data.
26 |
27 | Returns:
28 | np.array: Processed and fixed-length representation of the biosignal.
29 | """
30 | # Apply bandpass filtering
31 | filtered_signal = self._bandpass_filter(signal)
32 |
33 | # Resample to target length for uniformity
34 | resampled_signal = resample(filtered_signal, self.target_length)
35 |
36 | # Normalize the signal
37 | normalized_signal = (resampled_signal - np.mean(resampled_signal)) / np.std(resampled_signal)
38 |
39 | return normalized_signal
40 |
41 | if __name__ == "__main__":
42 | # Example biosignal (ECG or other) with 1000 sample points
43 | sample_signal = np.sin(np.linspace(0, 10, 1000)) + 0.1 * np.random.randn(1000)
44 | featurizer = BiosignalFeaturizer()
45 | print(featurizer)
46 | print(type(featurizer))
47 | print(featurizer.encode(sample_signal))
48 |
--------------------------------------------------------------------------------
/pyhealth/datasets/featurizers/text.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModel
2 | import torch
3 |
4 |
5 | class TextFeaturizer:
6 |
7 | def __init__(self, model_name="bert-base-uncased"):
8 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9 | self.model = AutoModel.from_pretrained(model_name)
10 |
11 | def encode(self, value):
12 | # Tokenize and process the input text
13 | inputs = self.tokenizer(value, return_tensors="pt", truncation=True,
14 | padding=True)
15 |
16 | # Generate embeddings
17 | with torch.no_grad():
18 | outputs = self.model(**inputs)
19 |
20 | # Use the CLS token embedding as the representation (can also use pooling for sentence embeddings)
21 | embedding = outputs.last_hidden_state[:, 0, :] # CLS token embedding
22 |
23 | return embedding
24 |
25 |
26 | if __name__ == "__main__":
27 | sample_text = "This is a sample text input for the TextFeaturizer."
28 | featurizer = TextFeaturizer()
29 | print(featurizer)
30 | print(type(featurizer))
31 | print(featurizer.encode(sample_text))
32 |
--------------------------------------------------------------------------------
/pyhealth/datasets/medical_transcriptions.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | from typing import Optional
4 |
5 | from ..tasks import MedicalTranscriptionsClassification
6 | from .base_dataset import BaseDataset
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class MedicalTranscriptionsDataset(BaseDataset):
12 | """Medical transcription data scraped from mtsamples.com.
13 |
14 | Dataset is available at:
15 | https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions
16 |
17 | Args:
18 | root: Root directory of the raw data.
19 | dataset_name: Name of the dataset. Defaults to "medical_transcriptions".
20 | config_path: Path to the configuration file. If None, uses default config.
21 |
22 | Attributes:
23 |         root: Root directory of the raw data (should contain mtsamples.csv).
24 | dataset_name: Name of the dataset.
25 | config_path: Path to the configuration file.
26 |
27 | Examples:
28 | >>> from pyhealth.datasets import MedicalTranscriptionsDataset
29 | >>> dataset = MedicalTranscriptionsDataset(
30 | ... root="path/to/medical_transcriptions",
31 | ... )
32 | >>> dataset.stats()
33 | >>> samples = dataset.set_task()
34 | >>> print(samples[0])
35 | """
36 |
37 | def __init__(
38 | self,
39 | root: str,
40 | dataset_name: Optional[str] = None,
41 | config_path: Optional[str] = None,
42 | ) -> None:
43 | if config_path is None:
44 | logger.info("No config path provided, using default config")
45 | config_path = (
46 | Path(__file__).parent / "configs" / "medical_transcriptions.yaml"
47 | )
48 | default_tables = ["mtsamples"]
49 | super().__init__(
50 | root=root,
51 | tables=default_tables,
52 | dataset_name=dataset_name or "medical_transcriptions",
53 | config_path=config_path,
54 | )
55 | return
56 |
57 | @property
58 | def default_task(self) -> MedicalTranscriptionsClassification:
59 | """Returns the default task for this dataset."""
60 | return MedicalTranscriptionsClassification()
61 |
--------------------------------------------------------------------------------
/pyhealth/interpret/__init__.py:
--------------------------------------------------------------------------------
1 | from pyhealth.interpret import methods
--------------------------------------------------------------------------------
/pyhealth/interpret/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from pyhealth.interpret.methods.chefer import CheferRelevance
--------------------------------------------------------------------------------
/pyhealth/medcode/__init__.py:
--------------------------------------------------------------------------------
1 | from .codes.atc import ATC
2 | from .codes.ccscm import CCSCM
3 | from .codes.ccsproc import CCSPROC
4 | from .codes.icd10cm import ICD10CM
5 | from .codes.icd10proc import ICD10PROC
6 | from .codes.icd9cm import ICD9CM
7 | from .codes.icd9proc import ICD9PROC
8 | from .codes.ndc import NDC
9 | from .codes.rxnorm import RxNorm
10 | from .codes.umls import UMLS
11 | from .cross_map import CrossMap
12 | from .inner_map import InnerMap
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/pyhealth/medcode/codes/__init__.py
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/atc.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from pyhealth.medcode.inner_map import InnerMap
4 | from pyhealth.medcode.utils import download_and_read_csv
5 |
6 |
7 | class ATC(InnerMap):
8 | """Anatomical Therapeutic Chemical."""
9 |
10 | def __init__(self, **kwargs):
11 | super(ATC, self).__init__(vocabulary="ATC", **kwargs)
12 | self.ddi = dict()
13 |
14 | @staticmethod
15 | def convert(code: str, level=5):
16 | """Convert ATC code to a specific level."""
17 | if type(level) is str:
18 | level = int(level)
19 | assert level in [1, 2, 3, 4, 5]
20 | if level == 1:
21 | return code[:1]
22 | elif level == 2:
23 | return code[:3]
24 | elif level == 3:
25 | return code[:4]
26 | elif level == 4:
27 | return code[:5]
28 | else:
29 | return code
30 |
31 | def get_ddi(
32 | self, gamenet_ddi: bool = False, refresh_cache: bool = False
33 | ) -> List[str]:
34 | """Gets the drug-drug interactions (DDI).
35 |
36 | Args:
37 | gamenet_ddi: Whether to use the DDI from the GAMENet paper,
38 | which is a subset of the DDI from the ATC.
39 | refresh_cache: Whether to refresh the cache. Default is False.
40 | """
41 | filename = "DDI_GAMENet.csv" if gamenet_ddi else "DDI.csv"
42 | if filename not in self.ddi or refresh_cache:
43 | df = download_and_read_csv(filename, refresh_cache)
44 | ddi = []
45 | for idx, row in df.iterrows():
46 | ddi.append([row["ATC i"], row["ATC j"]])
47 | self.ddi[filename] = ddi
48 | return self.ddi[filename]
49 |
50 |
51 | if __name__ == "__main__":
52 | code_sys = ATC(refresh_cache=True)
53 | code_sys.stat()
54 | print(code_sys.lookup("N01AB07"))
55 | print(code_sys.lookup("N01AB07", attribute="level"))
56 | print(code_sys.lookup("N01AB07", attribute="description"))
57 | print(code_sys.lookup("N01AB07", attribute="indication"))
58 | print(code_sys.lookup("N01AB07", attribute="smiles"))
59 | print(code_sys.lookup("N01AB07", attribute="drugbank_id"))
60 | print(code_sys.get_ancestors("N01AB07"))
61 | print(code_sys.get_descendants("N01AB"))
62 | print(len(code_sys.get_ddi(gamenet_ddi=True)))
63 | print(len(code_sys.get_ddi(gamenet_ddi=False)))
64 |
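A quick illustration of `convert`, derived from the level logic above (levels 1-4 keep the first 1, 3, 4, and 5 characters; level 5 returns the full code):

    assert ATC.convert("N01AB07", level=1) == "N"
    assert ATC.convert("N01AB07", level=3) == "N01A"
    assert ATC.convert("N01AB07", level=5) == "N01AB07"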
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/ccscm.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 |
3 |
4 | class CCSCM(InnerMap):
5 |     """Clinical Classifications Software (CCS) for diagnoses."""
6 |
7 | def __init__(self, **kwargs):
8 | super(CCSCM, self).__init__(vocabulary="CCSCM", **kwargs)
9 |
10 |
11 | if __name__ == "__main__":
12 | code_sys = CCSCM(refresh_cache=True)
13 | code_sys.stat()
14 | print("20" in code_sys)
15 | print(code_sys.lookup("10"))
16 | print(code_sys.get_ancestors("10"))
17 | print(code_sys.get_descendants("10"))
18 |
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/ccsproc.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 |
3 |
4 | class CCSPROC(InnerMap):
5 |     """Clinical Classifications Software (CCS) for procedures."""
6 |
7 | def __init__(self, **kwargs):
8 | super(CCSPROC, self).__init__(vocabulary="CCSPROC", **kwargs)
9 |
10 |
11 | if __name__ == "__main__":
12 | code_sys = CCSPROC(refresh_cache=True)
13 | code_sys.stat()
14 | print("1" in code_sys)
15 | print(code_sys.lookup("20"))
16 | print(code_sys.get_ancestors("20"))
17 | print(code_sys.get_descendants("20"))
18 |
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/icd10cm.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 |
3 |
4 | # TODO: add convert
5 |
6 |
7 | class ICD10CM(InnerMap):
8 |     """10th International Classification of Diseases, Clinical Modification."""
9 |
10 | def __init__(self, **kwargs):
11 | super(ICD10CM, self).__init__(vocabulary="ICD10CM", **kwargs)
12 |
13 | @staticmethod
14 | def standardize(code: str):
15 | """Standardizes ICD10CM code."""
16 | if "." in code:
17 | return code
18 | if len(code) <= 3:
19 | return code
20 | return code[:3] + "." + code[3:]
21 |
22 |
23 | if __name__ == "__main__":
24 | code_sys = ICD10CM(refresh_cache=True)
25 | code_sys.stat()
26 | print("A00.0" in code_sys)
27 | print(code_sys.lookup("D50.0"))
28 | print(code_sys.get_ancestors("D50.0"))
29 | print(code_sys.get_descendants("D50"))
30 |
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/icd10proc.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 |
3 |
4 | # TODO: add convert
5 |
6 |
7 | class ICD10PROC(InnerMap):
8 |     """10th International Classification of Diseases, Procedure."""
9 |
10 | def __init__(self, **kwargs):
11 | super(ICD10PROC, self).__init__(vocabulary="ICD10PROC", **kwargs)
12 |
13 |
14 | if __name__ == "__main__":
15 | code_sys = ICD10PROC(refresh_cache=True)
16 | code_sys.stat()
17 | print("0LBG0ZZ" in code_sys)
18 | print(code_sys.lookup("0LBG0ZZ"))
19 | print(code_sys.get_ancestors("0LBG0ZZ"))
20 | print(code_sys.get_descendants("0LBG0"))
21 |
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/icd9cm.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 |
3 |
4 | # TODO: add convert
5 |
6 |
7 | class ICD9CM(InnerMap):
8 |     """9th International Classification of Diseases, Clinical Modification."""
9 |
10 | def __init__(self, **kwargs):
11 | super(ICD9CM, self).__init__(vocabulary="ICD9CM", **kwargs)
12 |
13 | @staticmethod
14 | def standardize(code: str):
15 | """Standardizes ICD9CM code."""
16 | if "." in code:
17 | return code
18 | if code.startswith("E"):
19 | if len(code) <= 4:
20 | return code
21 | return code[:4] + "." + code[4:]
22 | else:
23 | if len(code) <= 3:
24 | return code
25 | return code[:3] + "." + code[3:]
26 |
27 |
28 | if __name__ == "__main__":
29 | code_sys = ICD9CM(refresh_cache=True)
30 | code_sys.stat()
31 | print("821.01" in code_sys)
32 | print(code_sys.lookup("82101"))
33 | print(code_sys.get_ancestors("821.01"))
34 | print(code_sys.get_descendants("821"))
35 |
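Two concrete checks of `standardize`, derived from the branches above:

    assert ICD9CM.standardize("82101") == "821.01"  # plain codes keep a 3-char stem
    assert ICD9CM.standardize("E8210") == "E821.0"  # E-codes keep a 4-char stem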
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/icd9proc.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 |
3 |
4 | # TODO: add convert
5 |
6 |
7 | class ICD9PROC(InnerMap):
8 |     """9th International Classification of Diseases, Procedure."""
9 |
10 | def __init__(self, **kwargs):
11 | super(ICD9PROC, self).__init__(vocabulary="ICD9PROC", **kwargs)
12 |
13 | @staticmethod
14 | def standardize(code: str):
15 | """Standardizes ICD9PROC code."""
16 | if "." in code:
17 | return code
18 | if len(code) <= 2:
19 | return code
20 | return code[:2] + "." + code[2:]
21 |
22 |
23 | if __name__ == "__main__":
24 | code_sys = ICD9PROC(refresh_cache=True)
25 | code_sys.stat()
26 | print("81.01" in code_sys)
27 | print(code_sys.lookup("01.31"))
28 | print(code_sys.get_ancestors("01.31"))
29 | print(code_sys.get_descendants("01"))
30 |
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/ndc.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 |
3 |
4 | # TODO: add standardize for different formats of NDC codes
5 |
6 |
7 | class NDC(InnerMap):
8 | """National Drug Code."""
9 |
10 | def __init__(self, **kwargs):
11 | super(NDC, self).__init__(vocabulary="NDC", **kwargs)
12 |
13 |
14 | if __name__ == "__main__":
15 | code_sys = NDC(refresh_cache=True)
16 | code_sys.stat()
17 | print("00527051210" in code_sys)
18 | print(code_sys.lookup("00527051210"))
19 | print(code_sys.get_ancestors("00527051210"))
20 | print(code_sys.get_descendants("00527051210"))
21 |
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/rxnorm.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 |
3 |
4 | class RxNorm(InnerMap):
5 | """RxNorm."""
6 |
7 | def __init__(self, **kwargs):
8 | super(RxNorm, self).__init__(vocabulary="RxNorm", **kwargs)
9 |
10 |
11 | if __name__ == "__main__":
12 | code_sys = RxNorm(refresh_cache=True)
13 | code_sys.stat()
14 | print("21914" in code_sys)
15 | print(code_sys.graph.nodes["21914"])
16 | print(code_sys.get_ancestors("21914"))
17 | print(code_sys.get_descendants("21914"))
18 |
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/umls.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 |
3 |
4 | class UMLS(InnerMap):
5 | """UMLS."""
6 |
7 | def __init__(self, **kwargs):
8 | super(UMLS, self).__init__(vocabulary="UMLS", **kwargs)
9 |
10 |
11 | if __name__ == "__main__":
12 | code_sys = UMLS(refresh_cache=True)
13 | code_sys.stat()
14 | print(code_sys.lookup("C0000768"))
15 |
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | from .kg_emb import *
2 | from .lm_emb import *
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import TransE
2 | from .models import ComplEx
3 | from .models import RotatE
4 | from .models import DistMult
5 | from .models import kg_base
6 | from .datasets import *
7 | from .tasks import *
8 |
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .sample_kg_dataset import SampleKGDataset
2 | from .base_kg_dataset import BaseKGDataset
3 | from .umls import UMLSDataset
4 | from .splitter import split
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import Optional, Tuple, Union, List
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from pyhealth.datasets import SampleBaseDataset
8 |
9 |
10 | def split(
11 | dataset: SampleBaseDataset,
12 | ratios: Union[Tuple[float, float, float], List[float]],
13 | seed: Optional[int] = None,
14 | ):
15 | """Splits the dataset by its outermost indexed items
16 |
17 | Args:
18 | dataset: a `SampleBaseDataset` object
19 | ratios: a list/tuple of ratios for train / val / test
20 | seed: random seed for shuffling the dataset
21 |
22 | Returns:
23 | train_dataset, val_dataset, test_dataset: three subsets of the dataset of
24 | type `torch.utils.data.Subset`.
25 |
26 | Note:
27 | The original dataset can be accessed by `train_dataset.dataset`,
28 | `val_dataset.dataset`, and `test_dataset.dataset`.
29 | """
30 |
31 | if seed is not None:
32 | np.random.seed(seed)
33 |     assert np.isclose(sum(ratios), 1.0), "ratios must sum to 1.0"
34 | index = np.arange(len(dataset))
35 | np.random.shuffle(index)
36 | train_index = index[: int(len(dataset) * ratios[0])]
37 | val_index = index[
38 | int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1]))
39 | ]
40 | test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :]
41 | train_dataset = torch.utils.data.Subset(dataset, train_index)
42 |     train_dataset = [{**train_dataset[i], 'train': True, 'hyperparameters': dataset.task_spec_param} for i in range(len(train_dataset))]
43 |
44 | val_dataset = torch.utils.data.Subset(dataset, val_index)
45 | val_dataset = [{**val_dataset[i], 'train': False} for i in range(len(val_dataset))]
46 | test_dataset = torch.utils.data.Subset(dataset, test_index)
47 | test_dataset = [{**test_dataset[i], 'train': False} for i in range(len(test_dataset))]
48 | return train_dataset, val_dataset, test_dataset
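
To illustrate what split returns, a minimal sketch with a toy dataset (ToyKGDataset is hypothetical; split only needs len(), indexing, and the task_spec_param attribute it attaches to train samples):

from torch.utils.data import Dataset

from pyhealth.medcode.pretrained_embeddings.kg_emb.datasets import split

class ToyKGDataset(Dataset):
    """Hypothetical stand-in: split() only needs len(), indexing, and a
    `task_spec_param` attribute on the dataset object."""
    task_spec_param = {"negative_sample_size": 4}  # made-up hyperparameters

    def __init__(self, n):
        self.samples = [{"triple": (i, 0, i + 1)} for i in range(n)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

train, val, test = split(ToyKGDataset(100), ratios=[0.8, 0.1, 0.1], seed=42)
print(len(train), len(val), len(test))     # 80 10 10
print(train[0]["train"], val[0]["train"])  # True False; only train samples carry hyperparameters
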
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/umls.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from tqdm import tqdm
4 | import numpy as np
5 | import pandas as pd
6 | from pandarallel import pandarallel
7 | from pyhealth.medcode.pretrained_embeddings.kg_emb.datasets import BaseKGDataset
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 | class UMLSDataset(BaseKGDataset):
12 | """Base UMLS knowleddge graph dataset
13 |
14 | Dataset is available at https://www.nlm.nih.gov/research/umls/index.html
15 |
16 | Args:
17 | dataset_name: name of the dataset.
18 | root: root directory of the raw data (should contain many csv files).
19 | dev: whether to enable dev mode (only use a small subset of the data).
20 | Default is False.
21 | refresh_cache: whether to refresh the cache; if true, the dataset will
22 | be processed from scratch and the cache will be updated. Default is False.
23 |
24 | """
25 |
26 | def raw_graph_process(self):
27 | pandarallel.initialize(progress_bar=False)
28 |         if not self.dev:
29 | self.graph_path = os.path.join(self.root, "graph.txt")
30 | else:
31 | self.graph_path = os.path.join(self.root, "graph_filtered.txt")
32 |
33 | if os.path.exists(self.graph_path):
34 | logger.debug("umls knowledge graph exists and load umls")
35 | else:
36 | logger.debug("umls does not exist")
37 |
38 | print("Loading UMLS knowledge graph...")
39 | graph_df = pd.read_csv(
40 | self.graph_path,
41 | sep='\t',
42 | names=['e1', 'r', 'e2']
43 | )
44 |
45 | print("Processing UMLS knowledge graph...")
46 | entity_list = pd.unique(graph_df[['e1', 'e2']].values.ravel('K'))
47 | relation_list = pd.unique(graph_df['r'].values)
48 |
49 | self.entity2id = {val: i for i, val in enumerate(entity_list)}
50 | self.relation2id = {val: i for i, val in enumerate(relation_list)}
51 | self.entity_num = len(self.entity2id)
52 | self.relation_num = len(self.relation2id)
53 |
54 | print("Building UMLS knowledge graph...")
55 | self.triples = [(self.entity2id[e1], self.relation2id[r], self.entity2id[e2])
56 | for e1, r, e2 in tqdm(zip(graph_df['e1'], graph_df['r'], graph_df['e2']), total=graph_df.shape[0])]
57 |
58 |
59 | return
60 |
61 |
62 | if __name__ == "__main__":
63 | dataset = UMLSDataset(
64 | root="https://storage.googleapis.com/pyhealth/umls/",
65 | dev=True,
66 | refresh_cache=True,
67 | )
68 | dataset.stat()
69 | dataset.info()
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .kg_base import KGEBaseModel
2 | from .transe import TransE
3 | from .rotate import RotatE
4 | from .distmult import DistMult
5 | from .complex import ComplEx
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .link_prediction import link_prediction_fn
2 |
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/tasks/link_prediction.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from tqdm import tqdm
4 | from typing import Tuple, List
5 | from collections import defaultdict
6 |
7 | def link_prediction_fn(
8 | triples: List[Tuple]
9 | ):
10 |
11 | """Process a triple list for the link prediction task
12 |
13 |     Link prediction covers two query types:
14 |         tail prediction: predict the tail entity t given a query (h, r, ?);
15 |         head prediction: predict the head entity h given a query (?, r, t).
16 |
17 | Args:
18 | triples: a list of triples (indexed) from the knowledge graph
19 |
20 | Returns:
21 |         samples: a list of dicts, one per triple, each containing the triple,
22 |             its ground-truth heads and tails, and a subsampling weight
22 | """
23 | samples = []
24 | count = count_frequency(triples)
25 | triple_set = set(triples)
26 | gt_head, gt_tail = ground_truth_for_query(triple_set)
27 |
28 | for positive_sample in tqdm(triples):
29 | head, relation, tail = positive_sample
30 | subsampling_weight = torch.sqrt(1 / torch.Tensor([count[(head, relation)] + count[(tail, -relation-1)]]))
31 | gt_h = gt_head[(relation, tail)]
32 | gt_t = gt_tail[(head, relation)]
33 |
34 | tmp = {
35 | "triple": positive_sample,
36 | "ground_truth_head": gt_h,
37 | "ground_truth_tail": gt_t,
38 | "subsampling_weight": subsampling_weight
39 | }
40 |
41 | samples.append(tmp)
42 |
43 | return samples
44 |
45 |
46 | def ground_truth_for_query(triple_set):
47 | """
48 | Search ground truth of either query (h, r, ?) or (?, r, t) in the dataset
49 | """
50 | gt_head = defaultdict(list)
51 | gt_tail = defaultdict(list)
52 |
53 | for triple in triple_set:
54 | head, relation, tail = triple
55 | gt_head[(relation, tail)].append(head)
56 | gt_tail[(head, relation)].append(tail)
57 |
58 | return gt_head, gt_tail
59 |
60 |
61 | def count_frequency(triples, start=4):
62 | '''
63 | Get frequency of a partial triple like (head, relation) or (relation, tail)
64 | The frequency will be used for subsampling like word2vec
65 | '''
66 | count = {}
67 | for head, relation, tail in triples:
68 | if (head, relation) not in count:
69 | count[(head, relation)] = start
70 | else:
71 | count[(head, relation)] += 1
72 |
73 | if (tail, -relation-1) not in count:
74 | count[(tail, -relation-1)] = start
75 | else:
76 | count[(tail, -relation-1)] += 1
77 | return count
78 |
79 |
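A toy run of link_prediction_fn, with entities and relations already mapped to integer ids (as UMLSDataset does above). The printed weight follows from count_frequency's start=4 smoothing:

from pyhealth.medcode.pretrained_embeddings.kg_emb.tasks import link_prediction_fn

# Toy indexed triples (head, relation, tail).
triples = [(0, 0, 1), (1, 0, 2), (0, 1, 2)]
samples = link_prediction_fn(triples)

s = samples[0]                  # the sample built for (0, 0, 1)
print(s["triple"])              # (0, 0, 1)
print(s["ground_truth_head"])   # heads answering (?, r=0, t=1) -> [0]
print(s["ground_truth_tail"])   # tails answering (h=0, r=0, ?) -> [1]
print(s["subsampling_weight"])  # tensor([0.3536]), i.e. sqrt(1 / (4 + 4)) with start=4
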
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/lm_emb/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/pyhealth/medcode/pretrained_embeddings/lm_emb/__init__.py
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/lm_emb/huggingface_retriever.py:
--------------------------------------------------------------------------------
1 | def embedding_retrieve(model, tokenizer, phrase):
2 | # Encode the sentence
3 | inputs = tokenizer(phrase, return_tensors='pt')
4 |
5 | # Get the model's output
6 | outputs = model(**inputs)
7 |
8 | # Extract the embeddings
9 | embedding = outputs.last_hidden_state.mean(dim=1)
10 |
11 | # Now, `embedding` is a tensor that contains the embedding for your sentence.
12 | # You can convert it to a numpy array if needed:
13 | embedding = embedding.detach().numpy().tolist()[0]
14 |
15 | return embedding
16 |
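Typical usage of embedding_retrieve, assuming an encoder-style Hugging Face checkpoint (Bio_ClinicalBERT, which PyHealth uses elsewhere, is just an example):

from transformers import AutoModel, AutoTokenizer

# Any encoder checkpoint works; this one is illustrative.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

vector = embedding_retrieve(model, tokenizer, "congestive heart failure")
print(len(vector))  # hidden size of the model; 768 for this checkpoint
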
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/lm_emb/openai_retriever.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import json
3 | import os
4 |
5 | key = os.environ['OPENAI_API_KEY']
6 |
7 | def embedding_retrieve(term):
8 | # Set up the API endpoint URL and request headers
9 | url = "https://api.openai.com/v1/embeddings"
10 | headers = {
11 | "Content-Type": "application/json",
12 | "Authorization": f"Bearer {key}"
13 | }
14 |
15 | # Set up the request payload with the text string to embed and the model to use
16 | payload = {
17 | "input": term,
18 | "model": "text-embedding-ada-002"
19 | }
20 |
21 | # Send the request and retrieve the response
22 | response = requests.post(url, headers=headers, data=json.dumps(payload))
23 |
24 | # Extract the text embeddings from the response JSON
25 | embedding = response.json()["data"][0]['embedding']
26 |
27 | return embedding
--------------------------------------------------------------------------------
/pyhealth/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .binary import binary_metrics_fn
2 | from .drug_recommendation import ddi_rate_score
3 | from .multiclass import multiclass_metrics_fn
4 | from .multilabel import multilabel_metrics_fn
5 |
6 | # from .fairness import fairness_metrics_fn
7 | from .ranking import ranking_metrics_fn
8 | from .regression import regression_metrics_fn
9 |
--------------------------------------------------------------------------------
/pyhealth/metrics/drug_recommendation.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 |
5 |
6 | # TODO: this metric is very ad-hoc, need to be improved
7 |
8 |
9 | def ddi_rate_score(medications: List[np.ndarray], ddi_matrix: np.ndarray) -> float:
10 | """DDI rate score.
11 |
12 | Args:
13 | medications: list of medications for each patient, where each medication
14 | is represented by the corresponding index in the ddi matrix.
15 | ddi_matrix: array-like of shape (n_classes, n_classes).
16 |
17 | Returns:
18 | result: DDI rate score.
19 | """
20 | all_cnt = 0
21 | ddi_cnt = 0
22 | for sample in medications:
23 | for i, med_i in enumerate(sample):
24 | for j, med_j in enumerate(sample):
25 | if j <= i: continue
26 | all_cnt += 1
27 | if ddi_matrix[med_i, med_j] == 1 or ddi_matrix[med_j, med_i] == 1:
28 | ddi_cnt += 1
29 | if all_cnt == 0:
30 | return 0
31 | return ddi_cnt / all_cnt
32 |
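A concrete check of the pairwise counting: with three drugs and a single interacting pair, one of the three unordered pairs is a DDI:

import numpy as np
from pyhealth.metrics import ddi_rate_score

# 3 drug classes; classes 0 and 1 interact.
ddi_matrix = np.zeros((3, 3), dtype=int)
ddi_matrix[0, 1] = 1

# One patient on drugs {0, 1, 2}: pairs (0,1), (0,2), (1,2);
# only (0,1) is a DDI, so the rate is 1/3.
print(ddi_rate_score([np.array([0, 1, 2])], ddi_matrix))  # 0.333...
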
--------------------------------------------------------------------------------
/pyhealth/metrics/fairness.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 |
3 | import numpy as np
4 |
5 | from pyhealth.metrics.fairness_utils import disparate_impact, statistical_parity_difference
6 |
7 | def fairness_metrics_fn(
8 | y_true: np.ndarray,
9 | y_prob: np.ndarray,
10 | sensitive_attributes: np.ndarray,
11 | favorable_outcome: int = 1,
12 | metrics: Optional[List[str]] = None,
13 | threshold: float = 0.5,
14 | ) -> Dict[str, float]:
15 | """Computes metrics for binary classification.
16 |
17 | User can specify which metrics to compute by passing a list of metric names.
18 | The accepted metric names are:
19 | - disparate_impact:
20 | - statistical_parity_difference:
21 |
22 |     If no metrics are specified, disparate_impact and statistical_parity_difference are computed by default.
23 |
24 | Args:
25 | y_true: True target values of shape (n_samples,).
26 | y_prob: Predicted probabilities of shape (n_samples,).
27 | sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
28 | favorable_outcome: Label value which is considered favorable (i.e. "positive").
29 | metrics: List of metrics to compute. Default is ["disparate_impact", "statistical_parity_difference"].
30 | threshold: Threshold for binary classification. Default is 0.5.
31 |
32 | Returns:
33 | Dictionary of metrics whose keys are the metric names and values are
34 | the metric values.
35 | """
36 | if metrics is None:
37 | metrics = ["disparate_impact", "statistical_parity_difference"]
38 |
39 | y_pred = y_prob.copy()
40 | y_pred[y_pred >= threshold] = 1
41 | y_pred[y_pred < threshold] = 0
42 |
43 | output = {}
44 | for metric in metrics:
45 | if metric == "disparate_impact":
46 | output[metric] = disparate_impact(sensitive_attributes, y_pred, favorable_outcome)
47 | elif metric == "statistical_parity_difference":
48 | output[metric] = statistical_parity_difference(sensitive_attributes, y_pred, favorable_outcome)
49 | else:
50 | raise ValueError(f"Unknown metric for fairness: {metric}")
51 | return output
52 |
53 |
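A small worked example (numbers chosen so the probabilities come out to simple fractions):

import numpy as np
from pyhealth.metrics.fairness import fairness_metrics_fn

y_true = np.array([1, 0, 1, 0, 1, 0])
y_prob = np.array([0.9, 0.8, 0.2, 0.7, 0.3, 0.1])
sensitive = np.array([1, 1, 1, 0, 0, 0])  # first three samples are the protected group

# With threshold 0.5: y_pred = [1, 1, 0, 1, 0, 0].
# P(favorable | protected) = 2/3, P(favorable | unprotected) = 1/3.
print(fairness_metrics_fn(y_true, y_prob, sensitive))
# {'disparate_impact': 2.0, 'statistical_parity_difference': 0.333...}
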
--------------------------------------------------------------------------------
/pyhealth/metrics/fairness_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .group import disparate_impact, statistical_parity_difference
2 | from .utils import sensitive_attributes_from_patient_ids
--------------------------------------------------------------------------------
/pyhealth/metrics/fairness_utils/group.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | """
4 | Notation:
5 | - Protected group: P
6 | - Unprotected group: U
7 | """
8 |
9 | def disparate_impact(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1, allow_zero_division = False, epsilon: float = 1e-8) -> float:
10 | """
11 |     Computes the disparate impact between the protected and unprotected groups.
12 |
13 | disparate_impact = P(y_pred = favorable_outcome | P) / P(y_pred = favorable_outcome | U)
14 |
15 | Args:
16 | sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
17 | y_pred: Predicted target values of shape (n_samples,).
18 | favorable_outcome: Label value which is considered favorable (i.e. "positive").
19 | allow_zero_division: If True, use epsilon instead of 0 in the denominator if the denominator is 0. Otherwise, raise a ValueError.
20 |
21 | Returns:
22 | The disparate impact between the protected and unprotected group.
23 | """
24 |
25 | p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0])
26 | p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1])
27 |
28 | if p_fav_unpr == 0:
29 | if allow_zero_division:
30 | p_fav_unpr = epsilon
31 | else:
32 | raise ValueError("Unprotected group has no instances with a favorable outcome. Disparate impact is undefined.")
33 |
34 | disparate_impact_value = p_fav_prot / p_fav_unpr
35 |
36 | return disparate_impact_value
37 |
38 | def statistical_parity_difference(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1) -> float:
39 | """
40 |     Computes the statistical parity difference between the protected and unprotected groups.
41 |
42 | statistical_parity_difference = P(y_pred = favorable_outcome | P) - P(y_pred = favorable_outcome | U)
43 | Args:
44 | sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group.
45 | y_pred: Predicted target values of shape (n_samples,).
46 | favorable_outcome: Label value which is considered favorable (i.e. "positive").
47 | Returns:
48 | The statistical parity difference between the protected and unprotected group.
49 | """
50 |
51 | p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0])
52 | p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1])
53 |
54 | statistical_parity_difference_value = p_fav_prot - p_fav_unpr
55 |
56 | return statistical_parity_difference_value
57 |
58 |
59 |
60 |
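The allow_zero_division flag matters when the unprotected group has no favorable outcomes; a short sketch of both behaviors:

import numpy as np
from pyhealth.metrics.fairness_utils import disparate_impact

y_pred = np.array([1, 0, 0, 0])
sensitive = np.array([1, 1, 0, 0])  # the unprotected half has no favorable outcomes

# Default: raises ValueError because P(favorable | U) = 0.
# With allow_zero_division=True, the denominator becomes epsilon (1e-8):
print(disparate_impact(sensitive, y_pred, allow_zero_division=True))  # 0.5 / 1e-8 = 50000000.0
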
--------------------------------------------------------------------------------
/pyhealth/metrics/fairness_utils/utils.py:
--------------------------------------------------------------------------------
1 |
2 | from typing import List
3 | import numpy as np
4 |
5 | from pyhealth.datasets import BaseEHRDataset
6 |
7 | def sensitive_attributes_from_patient_ids(dataset: BaseEHRDataset,
8 | patient_ids: List[str],
9 | sensitive_attribute: str,
10 | protected_group: str) -> np.ndarray:
11 | """
12 | Returns the desired sensitive attribute array from patient_ids.
13 |
14 | Args:
15 | dataset: Dataset object.
16 | patient_ids: List of patient IDs.
17 | sensitive_attribute: Sensitive attribute to extract.
18 | protected_group: Value of the protected group.
19 |
20 | Returns:
21 | Sensitive attribute array of shape (n_samples,).
22 | """
23 |
24 | sensitive_attribute_array = np.zeros(len(patient_ids))
25 | for idx, patient_id in enumerate(patient_ids):
26 | sensitive_attribute_value = getattr(dataset.patients[patient_id], sensitive_attribute)
27 | if sensitive_attribute_value == protected_group:
28 | sensitive_attribute_array[idx] = 1
29 | return sensitive_attribute_array
30 |
31 |
--------------------------------------------------------------------------------
/pyhealth/metrics/regression.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 |
3 | import numpy as np
4 | import sklearn.metrics as sklearn_metrics
5 |
6 |
7 | def regression_metrics_fn(
8 | x: np.ndarray,
9 | x_rec: np.ndarray,
10 | metrics: Optional[List[str]] = None,
11 | ) -> Dict[str, float]:
12 | """Computes metrics for regression.
13 |
14 | User can specify which metrics to compute by passing a list of metric names.
15 | The accepted metric names are:
16 | - kl_divergence: KL divergence
17 | - mse: mean squared error
18 | - mae: mean absolute error
19 |     If no metrics are specified, kl_divergence, mse, and mae are computed by default.
20 |
21 |     The mse and mae metrics are computed by calling the corresponding
22 |     sklearn.metrics functions (see their documentation for details);
23 |     kl_divergence is computed directly in this function.
24 |
25 | Args:
26 | x: True target data sample.
27 | x_rec: reconstructed data sample.
28 | metrics: List of metrics to compute. Default is ["kl_divergence", "mse", "mae"].
29 |
30 | Returns:
31 | Dictionary of metrics whose keys are the metric names and values are
32 | the metric values.
33 |
34 |     Examples:
35 |         >>> from pyhealth.metrics import regression_metrics_fn
36 |         >>> x = np.array([0.25, 0.25, 0.25, 0.25])
37 |         >>> x_rec = np.array([0.25, 0.25, 0.25, 0.25])
38 |         >>> regression_metrics_fn(x, x_rec, metrics=["mse"])
39 |         {'mse': 0.0}
40 | """
41 | if metrics is None:
42 | metrics = ["kl_divergence", "mse", "mae"]
43 |
44 | x = x.flatten()
45 | x_rec = x_rec.flatten()
46 |
47 | if x.shape != x_rec.shape:
48 | raise ValueError("x and x_rec must have the same shape.")
49 |
50 | output = {}
51 | for metric in metrics:
52 | if metric == "kl_divergence":
53 | x[x < 1e-6] = 1e-6
54 | x_rec[x_rec < 1e-6] = 1e-6
55 | x = x / np.sum(x)
56 | x_rec = x_rec / np.sum(x_rec)
57 | kl_divergence = np.sum(x_rec * np.log(x_rec / x))
58 | output["kl_divergence"] = kl_divergence
59 | elif metric == "mse":
60 | mse = sklearn_metrics.mean_squared_error(x, x_rec)
61 | output["mse"] = mse
62 | elif metric == "mae":
63 | mae = sklearn_metrics.mean_absolute_error(x, x_rec)
64 | output["mae"] = mae
65 | else:
66 | raise ValueError(f"Unknown metric for regression task: {metric}")
67 | return output
68 |
69 |
70 | if __name__ == "__main__":
71 | x = np.random.random(size=10000)
72 | x_rec = np.random.random(size=10000)
73 | print(regression_metrics_fn(x, x_rec))
74 |
--------------------------------------------------------------------------------
/pyhealth/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .adacare import AdaCare, AdaCareLayer
2 | from .agent import Agent, AgentLayer
3 | from .base_model import BaseModel
4 | from .cnn import CNN, CNNLayer
5 | from .concare import ConCare, ConCareLayer
6 | from .contrawr import ContraWR, ResBlock2D
7 | from .deepr import Deepr, DeeprLayer
8 | from .embedding import EmbeddingModel
9 | from .gamenet import GAMENet, GAMENetLayer
10 | from .gan import GAN
11 | from .gnn import GAT, GCN
12 | from .graph_torchvision_model import Graph_TorchvisionModel
13 | from .grasp import GRASP, GRASPLayer
14 | from .medlink import MedLink
15 | from .micron import MICRON, MICRONLayer
16 | from .mlp import MLP
17 | from .molerec import MoleRec, MoleRecLayer
18 | from .retain import RETAIN, RETAINLayer
19 | from .rnn import RNN, RNNLayer
20 | from .safedrug import SafeDrug, SafeDrugLayer
21 | from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer
22 | from .stagenet import StageNet, StageNetLayer
23 | from .tcn import TCN, TCNLayer
24 | from .torchvision_model import TorchvisionModel
25 | from .transformer import Transformer, TransformerLayer
26 | from .transformers_model import TransformersModel
27 | from .vae import VAE
28 |
--------------------------------------------------------------------------------
/pyhealth/models/embedding.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ..datasets import SampleDataset
7 | from ..processors import SequenceProcessor, TimeseriesProcessor
8 | from .base_model import BaseModel
9 |
10 |
11 | class EmbeddingModel(BaseModel):
12 | """
13 | EmbeddingModel is responsible for creating embedding layers for different types of input data.
14 |
15 | Attributes:
16 | dataset (SampleDataset): The dataset containing input processors.
17 | embedding_layers (nn.ModuleDict): A dictionary of embedding layers for each input field.
18 | """
19 |
20 | def __init__(self, dataset: SampleDataset, embedding_dim: int = 128):
21 | """
22 | Initializes the EmbeddingModel with the given dataset and embedding dimension.
23 |
24 | Args:
25 | dataset (SampleDataset): The dataset containing input processors.
26 | embedding_dim (int): The dimension of the embedding space. Default is 128.
27 | """
28 | super().__init__(dataset)
29 | self.embedding_layers = nn.ModuleDict()
30 | for field_name, processor in self.dataset.input_processors.items():
31 | if isinstance(processor, SequenceProcessor):
32 | vocab_size = len(processor.code_vocab)
33 | self.embedding_layers[field_name] = nn.Embedding(
34 | num_embeddings=vocab_size,
35 | embedding_dim=embedding_dim,
36 | padding_idx=0
37 | )
38 | elif isinstance(processor, TimeseriesProcessor):
39 | self.embedding_layers[field_name] = nn.Linear(
40 | in_features=processor.size,
41 | out_features=embedding_dim
42 | )
43 |
44 | def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
45 | """
46 | Forward pass to compute embeddings for the input data.
47 |
48 | Args:
49 | inputs (Dict[str, torch.Tensor]): A dictionary of input tensors.
50 |
51 | Returns:
52 | Dict[str, torch.Tensor]: A dictionary of embedded tensors.
53 | """
54 | embedded = {}
55 | for field_name, tensor in inputs.items():
56 | if field_name in self.embedding_layers:
57 | tensor = tensor.to(self.device)
58 | embedded[field_name] = self.embedding_layers[field_name](tensor)
59 | else:
60 | embedded[field_name] = tensor # passthrough for continuous features
61 | return embedded
62 |
63 | def __repr__(self) -> str:
64 | """
65 | Returns a string representation of the EmbeddingModel.
66 |
67 | Returns:
68 | str: A string representation of the model.
69 | """
70 | return f"EmbeddingModel(embedding_layers={self.embedding_layers})"
71 |
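The per-field ModuleDict dispatch above, reduced to plain torch for illustration (the field names and sizes here are made up): coded sequences go through nn.Embedding, dense timeseries through nn.Linear, and both land in the same 128-dim space:

import torch
import torch.nn as nn

# Hypothetical fields: a coded sequence and a 12-channel timeseries.
embedding_layers = nn.ModuleDict({
    "conditions": nn.Embedding(num_embeddings=500, embedding_dim=128, padding_idx=0),
    "vitals": nn.Linear(in_features=12, out_features=128),
})

inputs = {
    "conditions": torch.randint(1, 500, (2, 10)),  # (batch, seq len) of code indices
    "vitals": torch.randn(2, 24, 12),              # (batch, time, channels)
}
embedded = {k: embedding_layers[k](v) for k, v in inputs.items()}
print(embedded["conditions"].shape)  # torch.Size([2, 10, 128])
print(embedded["vitals"].shape)      # torch.Size([2, 24, 128])
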
--------------------------------------------------------------------------------
/pyhealth/models/medlink/__init__.py:
--------------------------------------------------------------------------------
1 | from .bm25 import BM25Okapi
2 | from .model import MedLink
3 | from .utils import convert_to_ir_format
4 | from .utils import filter_by_candidates
5 | from .utils import generate_candidates
6 | from .utils import get_bm25_hard_negatives
7 | from .utils import get_eval_dataloader
8 | from .utils import get_train_dataloader
9 | from .utils import tvt_split
10 |
--------------------------------------------------------------------------------
/pyhealth/models/transformers_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | import torch
4 | import torch.nn as nn
5 | from transformers import AutoModel, AutoTokenizer
6 |
7 | from ..datasets import SampleDataset
8 | from .base_model import BaseModel
9 |
10 |
11 | class TransformersModel(BaseModel):
12 | """Transformers class for Huggingface models.
13 | """
14 |
15 | def __init__(
16 | self,
17 | dataset: SampleDataset,
18 | model_name: str,
19 | ):
20 | super(TransformersModel, self).__init__(
21 | dataset=dataset,
22 | )
23 | self.model_name = model_name
24 | self.model = AutoModel.from_pretrained(model_name)
25 |         assert len(self.feature_keys) == 1, "Only one feature key is supported for TransformersModel"
26 |         self.feature_key = self.feature_keys[0]
27 |         assert len(self.label_keys) == 1, "Only one label key is supported for TransformersModel"
28 | self.label_key = self.label_keys[0]
29 | self.mode = self.dataset.output_schema[self.label_key]
30 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
31 | output_size = self.get_output_size()
32 | hidden_dim = self.model.config.hidden_size
33 | self.fc = nn.Linear(hidden_dim, output_size)
34 |
35 | def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
36 | """Forward propagation."""
37 | # concat the info within one batch (batch, channel, length)
38 | x = kwargs[self.feature_key]
39 | # TODO: max_length should be a parameter
40 | x = self.tokenizer(
41 | x, return_tensors="pt", padding=True, truncation=True, max_length=256
42 | )
43 | x = x.to(self.device)
44 | # TODO: should not use pooler_output, but use the last hidden state
45 | embeddings = self.model(**x).pooler_output
46 | logits = self.fc(embeddings)
47 | y_true = kwargs[self.label_key].to(self.device)
48 | loss = self.get_loss_function()(logits, y_true)
49 | y_prob = self.prepare_y_prob(logits)
50 | return {
51 | "loss": loss,
52 | "y_prob": y_prob,
53 | "y_true": y_true,
54 | }
55 |
56 |
57 | if __name__ == "__main__":
58 | from pyhealth.datasets import MedicalTranscriptionsDataset, get_dataloader
59 |
60 | base_dataset = MedicalTranscriptionsDataset(
61 | root="/srv/local/data/zw12/raw_data/MedicalTranscriptions"
62 | )
63 |
64 | sample_dataset = base_dataset.set_task()
65 |
66 | train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True)
67 |
68 | model = TransformersModel(
69 | dataset=sample_dataset,
70 |         # feature keys, label key, and mode are inferred from the dataset
71 |         # schema; the constructor accepts only `dataset` and `model_name`
72 |         # (passing feature_keys/label_key/mode here would raise a TypeError)
73 | model_name="emilyalsentzer/Bio_ClinicalBERT",
74 | )
75 |
76 | # data batch
77 | data_batch = next(iter(train_loader))
78 |
79 | # try the model
80 | ret = model(**data_batch)
81 | print(ret)
82 |
83 | # try loss backward
84 | ret["loss"].backward()
--------------------------------------------------------------------------------
/pyhealth/models/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 |
5 |
6 | def batch_to_multihot(label: List[List[int]], num_labels: int) -> torch.Tensor:
7 | """Converts label to multihot format.
8 |
9 | Args:
10 | label: [batch size, *]
11 | num_labels: total number of labels
12 |
13 | Returns:
14 | multihot: [batch size, num_labels]
15 | """
16 | multihot = torch.zeros((len(label), num_labels))
17 | for i, l in enumerate(label):
18 | multihot[i, l] = 1
19 | return multihot
20 |
21 |
22 | def get_last_visit(hidden_states, mask):
23 | """Gets the last visit from the sequence model.
24 |
25 | Args:
26 | hidden_states: [batch size, seq len, hidden_size]
27 | mask: [batch size, seq len]
28 |
29 | Returns:
30 | last_visit: [batch size, hidden_size]
31 | """
32 | if mask is None:
33 | return hidden_states[:, -1, :]
34 | else:
35 | mask = mask.long()
36 | last_visit = torch.sum(mask, 1) - 1
37 | last_visit = last_visit.unsqueeze(-1)
38 | last_visit = last_visit.expand(-1, hidden_states.shape[1] * hidden_states.shape[2])
39 | last_visit = torch.reshape(last_visit, hidden_states.shape)
40 | last_hidden_states = torch.gather(hidden_states, 1, last_visit)
41 | last_hidden_state = last_hidden_states[:, 0, :]
42 | return last_hidden_state
43 |
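A worked example of the gather-based lookup in get_last_visit (sequence 0 has two valid visits, sequence 1 has three):

import torch

from pyhealth.models.utils import get_last_visit

hidden_states = torch.arange(2 * 3 * 4, dtype=torch.float).reshape(2, 3, 4)
mask = torch.tensor([[1, 1, 0],
                     [1, 1, 1]])

# Sequence 0: last valid visit is index 1 -> hidden_states[0, 1]
# Sequence 1: last valid visit is index 2 -> hidden_states[1, 2]
print(get_last_visit(hidden_states, mask))
# tensor([[ 4.,  5.,  6.,  7.],
#         [20., 21., 22., 23.]])
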
--------------------------------------------------------------------------------
/pyhealth/processors/__init__.py:
--------------------------------------------------------------------------------
1 | PROCESSOR_REGISTRY = {}
2 |
3 |
4 | def register_processor(name: str):
5 | def decorator(cls):
6 | if name in PROCESSOR_REGISTRY:
7 | raise ValueError(f"Processor '{name}' already registered.")
8 | PROCESSOR_REGISTRY[name] = cls
9 | return cls
10 | return decorator
11 |
12 |
13 | def get_processor(name: str):
14 | if name not in PROCESSOR_REGISTRY:
15 | raise ValueError(f"Unknown processor: {name}")
16 | return PROCESSOR_REGISTRY[name]
17 |
18 |
19 | # Import all processors so they register themselves
20 | from .image_processor import ImageProcessor
21 | from .label_processor import (
22 | BinaryLabelProcessor,
23 | MultiClassLabelProcessor,
24 | MultiLabelProcessor,
25 | RegressionLabelProcessor,
26 | )
27 | from .raw_processor import RawProcessor
28 | from .sequence_processor import SequenceProcessor
29 | from .signal_processor import SignalProcessor
30 | from .text_processor import TextProcessor
31 | from .timeseries_processor import TimeseriesProcessor
32 |
33 | # Expose public API
34 | __all__ = [
35 | "get_processor",
36 | "ImageProcessor",
37 | "SequenceProcessor",
38 | "TimeseriesProcessor",
39 | "SignalProcessor",
40 | "BinaryLabelProcessor",
41 | "MultiClassLabelProcessor",
42 | "MultiLabelProcessor",
43 | "RegressionLabelProcessor",
44 | "RawProcessor",
45 | "TextProcessor",
46 | ]
47 |
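The registry above makes third-party processors pluggable by name; a minimal sketch (IdentityProcessor and the "identity" key are hypothetical, for illustration only):

from pyhealth.processors import get_processor, register_processor
from pyhealth.processors.base_processor import FeatureProcessor

@register_processor("identity")  # registers the class under a custom name
class IdentityProcessor(FeatureProcessor):
    def process(self, value):
        return value

# Lookup by name returns the registered class itself.
assert get_processor("identity") is IdentityProcessor
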
--------------------------------------------------------------------------------
/pyhealth/processors/base_processor.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, Dict, List, Optional
3 |
4 |
5 | class Processor(ABC):
6 | """
7 | Abstract base processor class.
8 |
9 | Defines optional hooks for saving/loading state to/from disk.
10 | """
11 |
12 | def save(self, path: str) -> None:
13 | """Optional: Save processor state to disk.
14 |
15 | Args:
16 | path: File path to save processor state.
17 | """
18 | pass
19 |
20 | def load(self, path: str) -> None:
21 | """Optional: Load processor state from disk.
22 |
23 | Args:
24 | path: File path to load processor state from.
25 | """
26 | pass
27 |
28 |
29 | class FeatureProcessor(Processor):
30 | """
31 | Processor for individual fields (features).
32 |
33 | Example: Tokenization, image loading, normalization.
34 | """
35 |
36 | def fit(self, samples: List[Dict[str, Any]], field: str) -> None:
37 | """Fit the processor to the samples.
38 |
39 | Args:
40 | samples: List of sample dictionaries.
41 | """
42 | pass
43 |
44 | @abstractmethod
45 | def process(self, value: Any) -> Any:
46 | """Process an individual field value.
47 |
48 | Args:
49 | value: Raw field value.
50 |
51 | Returns:
52 | Processed value.
53 | """
54 | pass
55 |
56 |
57 | class SampleProcessor(Processor):
58 | """
59 | Processor for individual samples (dict of fields).
60 |
61 | Example: Imputation, sample-level augmentation, label smoothing.
62 | """
63 |
64 | @abstractmethod
65 | def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
66 | """Process a single sample dictionary.
67 |
68 | Args:
69 | sample: Sample dictionary.
70 |
71 | Returns:
72 | Processed sample dictionary.
73 | """
74 | pass
75 |
76 |
77 | class DatasetProcessor(Processor):
78 | """
79 | Processor for the full dataset.
80 |
81 | Example: Global normalization, train/val/test splitting, dataset caching.
82 | """
83 |
84 | @abstractmethod
85 | def process(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
86 | """Process the entire dataset.
87 |
88 | Args:
89 | samples: List of sample dictionaries.
90 |
91 | Returns:
92 | List of processed sample dictionaries.
93 | """
94 | pass
95 |
--------------------------------------------------------------------------------
/pyhealth/processors/raw_processor.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from . import register_processor
4 | from .base_processor import FeatureProcessor
5 |
6 |
7 | @register_processor("raw")
8 | class RawProcessor(FeatureProcessor):
9 | """
10 | Processor that returns the raw value.
11 | """
12 |
13 |     def process(self, value: Any) -> Any:
14 | return value
15 |
16 | def size(self):
17 | return None
18 |
19 | def __repr__(self):
20 | return "RawProcessor()"
21 |
--------------------------------------------------------------------------------
/pyhealth/processors/sequence_processor.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | import torch
4 |
5 | from . import register_processor
6 | from .base_processor import FeatureProcessor
7 |
8 |
9 | @register_processor("sequence")
10 | class SequenceProcessor(FeatureProcessor):
11 | """
12 | Feature processor for encoding categorical sequences (e.g., medical codes) into numerical indices.
13 |
14 | Supports single or multiple tokens (e.g., single diagnosis or list of procedures).
15 | Can build vocabulary on the fly if not provided.
16 | """
17 |
18 | def __init__(self):
19 |         # -1 for ease of boolean arithmetic (> 0, > -1, etc.)
20 |         # TODO: this can be a problem if we pass -1 into nn.Embedding
21 |         self.code_vocab: Dict[Any, int] = {"<unk>": -1, "<pad>": 0}
22 | self._next_index = 1
23 |
24 | def process(self, value: Any) -> torch.Tensor:
25 | """Process token value(s) into tensor of indices.
26 |
27 | Args:
28 | value: Raw token string or list of token strings.
29 |
30 | Returns:
31 | Tensor of indices.
32 | """
33 | indices = []
34 | for token in value:
35 | if token is None: # missing values
36 |                 indices.append(self.code_vocab["<unk>"])
37 | else:
38 | if token not in self.code_vocab:
39 | self.code_vocab[token] = self._next_index
40 | self._next_index += 1
41 | indices.append(self.code_vocab[token])
42 | return torch.tensor(indices, dtype=torch.long)
43 |
44 | def size(self):
45 | return len(self.code_vocab)
46 |
47 | def __repr__(self):
48 | return (
49 | f"SequenceProcessor(code_vocab_size={len(self.code_vocab)})"
50 | )
51 |
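A quick demonstration of the on-the-fly vocabulary: indices are handed out in first-seen order starting at 1, and missing tokens fall back to the <unk> entry:

from pyhealth.processors import SequenceProcessor

proc = SequenceProcessor()
print(proc.process(["428.0", "401.9", "428.0"]))  # tensor([1, 2, 1]); repeated codes share an index
print(proc.process(["401.9", None]))              # tensor([ 2, -1]); None maps to the <unk> index
print(proc.size())                                # 4: two codes plus the two special tokens
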
--------------------------------------------------------------------------------
/pyhealth/processors/signal_processor.py:
--------------------------------------------------------------------------------
1 | from .base_processor import FeatureProcessor
2 |
3 |
4 | class SignalProcessor(FeatureProcessor):
5 | pass
--------------------------------------------------------------------------------
/pyhealth/processors/text_processor.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from . import register_processor
4 | from .base_processor import FeatureProcessor
5 |
6 |
7 | @register_processor("text")
8 | class TextProcessor(FeatureProcessor):
9 | """
10 | Processor for text data.
11 | """
12 |
13 | def process(self, value: Any) -> str:
14 | return value
15 |
16 | def size(self):
17 | return None
18 |
19 | def __repr__(self):
20 | return "TextProcessor()"
21 |
--------------------------------------------------------------------------------
/pyhealth/sampler/__init__.py:
--------------------------------------------------------------------------------
1 | from .sage_sampler import NeighborSampler
--------------------------------------------------------------------------------
/pyhealth/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_task import BaseTask
2 | from .benchmark_ehrshot import BenchmarkEHRShot
3 | from .cardiology_detect import (
4 | cardiology_isAD_fn,
5 | cardiology_isAR_fn,
6 | cardiology_isBBBFB_fn,
7 | cardiology_isCD_fn,
8 | cardiology_isWA_fn,
9 | )
10 | from .covid19_cxr_classification import COVID19CXRClassification
11 | from .drug_recommendation import (
12 | drug_recommendation_eicu_fn,
13 | drug_recommendation_mimic3_fn,
14 | drug_recommendation_mimic4_fn,
15 | drug_recommendation_omop_fn,
16 | )
17 | from .EEG_abnormal import EEG_isAbnormal_fn
18 | from .EEG_events import EEG_events_fn
19 | from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4
20 | from .length_of_stay_prediction import (
21 | length_of_stay_prediction_eicu_fn,
22 | length_of_stay_prediction_mimic3_fn,
23 | length_of_stay_prediction_mimic4_fn,
24 | length_of_stay_prediction_omop_fn,
25 | )
26 | from .medical_coding import MIMIC3ICD9Coding
27 | from .medical_transcriptions_classification import MedicalTranscriptionsClassification
28 | from .mortality_prediction import (
29 | MortalityPredictionEICU,
30 | MortalityPredictionEICU2,
31 | MortalityPredictionMIMIC3,
32 | MortalityPredictionMIMIC4,
33 | MortalityPredictionOMOP,
34 | )
35 | from .patient_linkage import patient_linkage_mimic3_fn
36 | from .readmission_30days_mimic4 import Readmission30DaysMIMIC4
37 | from .readmission_prediction import (
38 | readmission_prediction_eicu_fn,
39 | readmission_prediction_eicu_fn2,
40 | readmission_prediction_mimic3_fn,
41 | readmission_prediction_mimic4_fn,
42 | readmission_prediction_omop_fn,
43 | )
44 | from .sleep_staging import (
45 | sleep_staging_isruc_fn,
46 | sleep_staging_shhs_fn,
47 | sleep_staging_sleepedf_fn,
48 | )
49 | from .sleep_staging_v2 import SleepStagingSleepEDF
50 | from .temple_university_EEG_tasks import EEG_events_fn, EEG_isAbnormal_fn  # re-imports the EEG task fns; these later imports take precedence
51 |
--------------------------------------------------------------------------------
/pyhealth/tasks/base_task.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, List
3 |
4 | import polars as pl
5 |
6 |
7 | class BaseTask(ABC):
8 | task_name: str
9 | input_schema: Dict[str, str]
10 | output_schema: Dict[str, str]
11 |
12 | def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
13 | return df
14 |
15 | @abstractmethod
16 | def __call__(self, patient) -> List[Dict]:
17 | raise NotImplementedError
--------------------------------------------------------------------------------
/pyhealth/tasks/covid19_cxr_classification.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | from .base_task import BaseTask
4 |
5 |
6 | class COVID19CXRClassification(BaseTask):
7 | """A task for classifying chest disease from chest X-ray images.
8 |
9 | This task classifies chest X-ray images into different disease categories.
10 | It expects a single chest X-ray image per patient and returns the
11 | corresponding disease label.
12 |
13 | Attributes:
14 | task_name (str): The name of the task, set to
15 | "COVID19CXRClassification".
16 | input_schema (Dict[str, str]): The input schema specifying the required
17 | input format. Contains a single key "image" with value "image".
18 | output_schema (Dict[str, str]): The output schema specifying the output
19 | format. Contains a single key "disease" with value "multiclass".
20 | """
21 |
22 | task_name: str = "COVID19CXRClassification"
23 | input_schema: Dict[str, str] = {"image": "image"}
24 | output_schema: Dict[str, str] = {"disease": "multiclass"}
25 |
26 | def __call__(self, patient: Any) -> List[Dict[str, Any]]:
27 | """Process a patient's chest X-ray data to classify COVID-19 status.
28 |
29 | Args:
30 | patient: A patient object containing chest X-ray data.
31 |
32 | Returns:
33 | List[Dict[str, Any]]: A list containing a single dictionary with:
34 | - "image": Path to the chest X-ray image
35 | - "disease": The disease classification label
36 |
37 | Raises:
38 | AssertionError: If the patient has more than one chest X-ray event.
39 | """
40 | event = patient.get_events(event_type="covid19_cxr")
41 | # There should be only one event
42 | assert len(event) == 1
43 | event = event[0]
44 | image = event.path
45 | disease = event.label
46 | return [{"image": image, "disease": disease}]
47 |
--------------------------------------------------------------------------------
/pyhealth/tasks/medical_transcriptions_classification.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | from ..data import Patient
4 | from .base_task import BaseTask
5 |
6 |
7 | class MedicalTranscriptionsClassification(BaseTask):
8 | """Task for classifying medical transcriptions into medical specialties.
9 |
10 | This task takes medical transcription text as input and predicts the
11 | corresponding medical specialty. It processes patient records containing
12 | mtsamples events and extracts transcription and medical specialty
13 | information.
14 |
15 | Attributes:
16 | task_name (str): Name of the task
17 | input_schema (Dict[str, str]): Schema defining input features
18 | output_schema (Dict[str, str]): Schema defining output features
19 | """
20 | task_name: str = "MedicalTranscriptionsClassification"
21 | input_schema: Dict[str, str] = {"transcription": "text"}
22 | output_schema: Dict[str, str] = {"medical_specialty": "multiclass"}
23 |
24 | def __call__(self, patient: Patient) -> List[Dict[str, Any]]:
25 | """Process a patient record to extract medical transcription samples.
26 |
27 | Args:
28 | patient (Patient): Patient record containing medical
29 | transcription events
30 |
31 | Returns:
32 | List[Dict[str, Any]]: List of samples containing transcription
33 | and medical specialty
34 | """
35 | event = patient.get_events(event_type="mtsamples")
36 | # There should be only one event
37 | assert len(event) == 1
38 | event = event[0]
39 |
40 | transcription_valid = isinstance(event.transcription, str)
41 | specialty_valid = isinstance(event.medical_specialty, str)
42 |
43 | if transcription_valid and specialty_valid:
44 | sample = {
45 | "id": patient.patient_id,
46 | "transcription": event.transcription,
47 | "medical_specialty": event.medical_specialty,
48 | }
49 | return [sample]
50 | else:
51 | return []
52 |
--------------------------------------------------------------------------------
/pyhealth/tasks/patient_linkage.py:
--------------------------------------------------------------------------------
1 | from pyhealth.data import Patient
2 |
3 |
4 | def patient_linkage_mimic3_fn(patient: Patient):
5 | """ Patient linkage task for the mimic3 dataset. """
6 |
7 | # exclude patients with less than two visits
8 | if len(patient) < 2:
9 | return []
10 |
11 | samples = []
12 | q_visit = patient.get_visit_by_index(len(patient) - 1)
13 | d_visit = patient.get_visit_by_index(len(patient) - 2)
14 |
15 | q_age = (q_visit.encounter_time - patient.birth_datetime).days // 365.25
16 | d_age = (d_visit.encounter_time - patient.birth_datetime).days // 365.25
17 |
18 | # exclude patients under 18
19 | if (q_age < 18) or (d_age < 18):
20 | return []
21 |
22 | q_conditions = q_visit.get_code_list(table="DIAGNOSES_ICD")
23 | d_conditions = d_visit.get_code_list(table="DIAGNOSES_ICD")
24 |
25 | # exclude patients without conditions
26 | if len(q_conditions) * len(d_conditions) == 0:
27 | return []
28 |
29 | # identifiers
30 | gender = patient.gender
31 | insurance = q_visit.attr_dict["insurance"]
32 | language = q_visit.attr_dict["language"]
33 | religion = q_visit.attr_dict["religion"]
34 | marital_status = q_visit.attr_dict["marital_status"]
35 | ethnicity = q_visit.attr_dict["ethnicity"]
36 | insurance = "" if insurance != insurance else insurance
37 | language = "" if language != language else language
38 | religion = "" if religion != religion else religion
39 | marital_status = "" if marital_status != marital_status else marital_status
40 | ethnicity = "" if ethnicity != ethnicity else ethnicity
41 | q_identifiers = "+".join(
42 | [gender, insurance, language, religion, marital_status, ethnicity]
43 | )
44 |
45 | insurance = d_visit.attr_dict["insurance"]
46 | language = d_visit.attr_dict["language"]
47 | religion = d_visit.attr_dict["religion"]
48 | marital_status = d_visit.attr_dict["marital_status"]
49 | ethnicity = d_visit.attr_dict["ethnicity"]
50 | insurance = "" if insurance != insurance else insurance
51 | language = "" if language != language else language
52 | religion = "" if religion != religion else religion
53 | marital_status = "" if marital_status != marital_status else marital_status
54 | ethnicity = "" if ethnicity != ethnicity else ethnicity
55 | d_identifiers = "+".join(
56 | [gender, insurance, language, religion, marital_status, ethnicity]
57 | )
58 |
59 | samples.append({
60 | "patient_id": patient.patient_id,
61 | "visit_id": q_visit.visit_id,
62 | "conditions": [""] + q_conditions,
63 | "age": q_age,
64 | "identifiers": q_identifiers,
65 | "d_visit_id": d_visit.visit_id,
66 | "d_conditions": [""] + d_conditions,
67 | "d_age": d_age,
68 | "d_identifiers": d_identifiers,
69 | })
70 |
71 | return samples
72 |
--------------------------------------------------------------------------------
/pyhealth/unittests/test_data/test_data.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import unittest
3 |
4 | from pyhealth.data import Event, Patient
5 |
6 |
7 | class TestEvent(unittest.TestCase):
8 | def setUp(self):
9 | self.event = Event(
10 | type="DIAGNOSES_ICD",
11 | timestamp=datetime(2012, 1, 1, 0, 0),
12 | attr_dict={"code": "428.0", "vocabulary": "ICD9CM"}
13 | )
14 |
15 | def test_type(self):
16 | self.assertIsInstance(self.event.type, str)
17 | self.assertIsInstance(self.event.timestamp, datetime)
18 |
19 | def test_attr(self):
20 | self.assertEqual(self.event.type, "DIAGNOSES_ICD")
21 | self.assertEqual(self.event.timestamp, datetime(2012, 1, 1, 0, 0))
22 | attr_dict = self.event.attr_dict
23 | self.assertEqual(attr_dict["code"], "428.0")
24 | self.assertEqual(attr_dict["vocabulary"], "ICD9CM")
25 |
26 | def test_repr_and_str(self):
27 | print(repr(self.event))
28 | print(str(self.event))
29 |
30 |
31 | class TestPatient(unittest.TestCase):
32 | def setUp(self):
33 | self.event1 = Event(type="diagnosis", timestamp=datetime(2023, 5, 17))
34 | self.event2 = Event(type="prescription", timestamp=datetime(2023, 5, 18))
35 | self.event3 = Event(type="lab_test", timestamp=None)
36 | self.patient = Patient(patient_id="12345",
37 | attr_dict={"name": "John Doe", "age": 45})
38 |
39 | def test_attr(self):
40 | self.assertEqual(self.patient.patient_id, "12345")
41 | self.assertEqual(self.patient.attr_dict["name"], "John Doe")
42 | self.assertEqual(self.patient.attr_dict["age"], 45)
43 |
44 | def test_methods(self):
45 | self.patient.add_event(self.event1)
46 | self.patient.add_event(self.event2)
47 | self.patient.add_event(self.event3)
48 | self.assertEqual(len(self.patient.events), 3)
49 | self.assertEqual(self.patient.events[0], self.event1)
50 | self.assertEqual(self.patient.events[1], self.event2)
51 | self.assertEqual(self.patient.events[2], self.event3)
52 |
53 | def test_repr_and_str(self):
54 | self.patient.add_event(self.event1)
55 | self.patient.add_event(self.event2)
56 | self.patient.add_event(self.event3)
57 | print(repr(self.patient))
58 | print(str(self.patient))
59 |
60 |
61 | if __name__ == "__main__":
62 | unittest.main()
63 |
--------------------------------------------------------------------------------
/pyhealth/unittests/test_datasets/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from pyhealth.datasets import BaseEHRDataset
3 |
4 | class EHRDatasetStatAssertion:
5 |
6 | def __init__(self, dataset: BaseEHRDataset, eps: float):
7 | self.dataset = dataset
8 | self.eps = eps
9 | # return self # builder
10 |
11 | def assertEHRStats(
12 | self,
13 | expected_num_patients: int,
14 | expected_num_visits: int,
15 | expected_num_visits_per_patient: float,
16 | expected_events_per_visit_per_table: List[float],
17 | ):
18 | self.assertNumPatients(expected_num_patients)
19 | self.assertNumVisits(expected_num_visits)
20 | self.assertMeanVisitsPerPatient(expected_num_visits_per_patient)
21 | self.assertTableMeans(expected_events_per_visit_per_table)
22 |
23 | def assertNumPatients(self, expected: int):
24 | actual = len(self.dataset.patients)
25 | if expected != actual:
26 | raise AssertionError(f"Expected {expected} patients got {actual}")
27 |
28 | def assertNumVisits(self, expected: int):
29 | actual = sum([len(patient) for patient in self.dataset.patients.values()])
30 | if expected != actual:
31 | raise AssertionError(f"Expected {expected} num visits got {actual}")
32 |
33 |     def assertMeanVisitsPerPatient(self, expected: float):
34 | actual_visits = [len(patient) for patient in self.dataset.patients.values()]
35 | actual = sum(actual_visits) / len(actual_visits)
36 | if abs(expected - actual) > self.eps:
37 | raise AssertionError(f"Expected {expected} mean num visits got {actual}")
38 |
39 |
40 | # expected list must be ordered by tables
41 | def assertTableMeans(self, expected_per_table: List[float]):
42 | for expected_value, table in zip(expected_per_table, self.dataset.tables):
43 | actual_num_events = [
44 | len(v.get_event_list(table))
45 | for p in self.dataset.patients.values()
46 | for v in p
47 | ]
48 |
49 | actual_value = sum(actual_num_events) / len(actual_num_events)
50 |
51 | if abs(expected_value - actual_value) > self.eps:
52 | raise AssertionError(f"Expected {expected_value} mean for events in {table} got {actual_value}")
--------------------------------------------------------------------------------
/pyhealth/unittests/test_medcode.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import sys
3 | import os
4 | current = os.path.dirname(os.path.realpath(__file__))
5 | sys.path.append(os.path.dirname(os.path.dirname(current)))
6 |
7 | from pyhealth.medcode import InnerMap, CrossMap
8 |
9 | class TestInnerMap(unittest.TestCase):
10 |
11 | def setUp(self):
12 | map_name = "ICD9CM"
13 | self.inner_map = InnerMap.load(map_name)
14 |
15 | def test_contain(self):
16 | self.assertTrue(
17 | "428.0" in self.inner_map,
18 | msg="contain function of InnerMap failed"
19 | )
20 |
21 | def test_lookup(self):
22 | self.assertEqual(
23 | self.inner_map.lookup("428.0"),
24 | 'Congestive heart failure, unspecified',
25 | msg="lookup function of InnerMap failed"
26 | )
27 |
28 | def test_get_ancestors(self):
29 | self.assertEqual(
30 | self.inner_map.get_ancestors("428.0"),
31 | ['428', '420-429.99', '390-459.99', '001-999.99'],
32 | msg="get_ancestors function of InnerMap failed"
33 | )
34 |
35 | def test_get_descendants(self):
36 | self.assertEqual(
37 | self.inner_map.get_descendants("428"),
38 | ['428.0', '428.1', '428.2', '428.3', '428.4', '428.9', '428.20', '428.21', '428.22', '428.23', '428.30', '428.31', '428.32', '428.33', '428.40', '428.41', '428.42', '428.43'],
39 | msg="get_descendants function of InnerMap failed"
40 | )
41 |
42 |
43 | class TestInnerMapATC(unittest.TestCase):
44 | def setUp(self):
45 | self.inner_map = InnerMap.load("ATC")
46 |
47 | def test_lookup(self):
48 | self.assertEqual(
49 | self.inner_map.lookup("M01AE51"),
50 | 'ibuprofen, combinations',
51 | msg="lookup function of InnerMap (ATC) failed"
52 | )
53 | self.assertEqual(
54 | self.inner_map.lookup("M01AE51", "drugbank_id"),
55 | 'DB01050',
56 | msg="lookup function of InnerMap (ATC) failed"
57 | )
58 | self.assertEqual(
59 | self.inner_map.lookup("M01AE51", "smiles"),
60 | 'CC(C)CC1=CC=C(C=C1)C(C)C(O)=O',
61 | msg="lookup function of InnerMap (ATC) failed"
62 | )
63 |
64 |
65 | def test_convert(self):
66 | self.assertEqual(
67 | self.inner_map.convert("A12CE02", level=3),
68 | "A12C",
69 | msg="convert function of InnerMap (ATC) failed"
70 | )
71 |
72 |
73 | class TestCrossMap(unittest.TestCase):
74 | def setUp(self):
75 | self.cross_map = CrossMap.load(source_vocabulary="ICD9CM", target_vocabulary="CCSCM")
76 |
77 | def test_map(self):
78 | self.assertEqual(
79 | self.cross_map.map("428.0"),
80 | ["108"],
81 | msg="map function of CrossMap failed"
82 | )
83 |
84 |
85 | if __name__ == "__main__":
86 | unittest.main()
87 |
--------------------------------------------------------------------------------
/pyhealth/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 | import random
5 |
6 | import numpy as np
7 | import torch
8 |
9 |
10 | def set_seed(seed):
11 | random.seed(seed)
12 | np.random.seed(seed)
13 | torch.manual_seed(seed)
14 | if torch.cuda.is_available():
15 | torch.cuda.manual_seed(seed)
16 | torch.cuda.manual_seed_all(seed)
17 | torch.backends.cudnn.deterministic = True
18 | torch.backends.cudnn.benchmark = False
19 | os.environ["PYTHONHASHSEED"] = str(seed)
20 |
21 |
22 | def create_directory(directory):
23 | if not os.path.exists(directory):
24 | os.makedirs(directory)
25 |
26 |
27 | def load_pickle(filename):
28 | with open(filename, "rb") as f:
29 | return pickle.load(f)
30 |
31 |
32 | def save_pickle(data, filename):
33 | with open(filename, "wb") as f:
34 | pickle.dump(data, f)
35 |
36 |
37 | def load_json(filename):
38 | with open(filename, "r") as f:
39 | return json.load(f)
40 |
41 |
42 | def save_json(data, filename):
43 | with open(filename, "w") as f:
44 | json.dump(data, f)
45 |
46 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-22.04
5 | tools:
6 | python: "3.8"
7 |
8 | sphinx:
9 | configuration: docs/conf.py
10 |
11 | formats:
12 | - epub
13 | - htmlzip
14 |
15 | python:
16 | install:
17 | - method: pip
18 | path: .
19 | - requirements: docs/requirements.txt
20 |
21 | submodules:
22 | include: all
--------------------------------------------------------------------------------
/requirements-nlp.txt:
--------------------------------------------------------------------------------
1 | numpy~=1.26.4
2 | pandas~=2.1.4
3 | scikit-learn~=1.3.2
4 | torch~=2.2.2
5 | transformers~=4.48.3
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.8.0
2 | torchvision>=0.9.0
3 | rdkit>=2022.03.4
4 | scikit-learn>=0.24.2
5 | networkx>=2.6.3
6 | pandas>=1.3.2,<2
7 | pandarallel>=1.5.3
8 | mne>=1.0.3
9 | urllib3<=1.26.15
10 | numpy
11 | tqdm
12 | polars
13 | transformers
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description_file = README.rst
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | # read the contents of README file
4 | from os import path
5 | from io import open # for Python 2 and 3 compatibility
6 |
7 | # # get __version__ from _version.py
8 | # ver_file = path.join('pyhealth', 'version.py')
9 | # with open(ver_file) as f:
10 | # exec(f.read())
11 |
12 | this_directory = path.abspath(path.dirname(__file__))
13 |
14 |
15 | # read the contents of README.rst
16 | def readme():
17 | with open(path.join(this_directory, "README.rst"), encoding="utf-8") as f:
18 | return f.read()
19 |
20 |
21 | # read the contents of requirements.txt
22 | with open(path.join(this_directory, "requirements.txt"), encoding="utf-8") as f:
23 | requirements = f.read().splitlines()
24 |
25 | VERSION = "1.1.4"
26 |
27 | setup(
28 | name="pyhealth",
29 | version=VERSION,
30 | description="A Python library for healthcare AI",
31 | long_description=readme(),
32 | long_description_content_type="text/x-rst",
33 | author="Chaoqi Yang, Zhenbang Wu, Patrick Jiang, Zhen Lin, Benjamin Danek, Junyi Gao, Jimeng Sun",
34 | author_email="chaoqiy2@illinois.edu",
35 | url="https://github.com/sunlabuiuc/pyhealth",
36 | keywords=[
37 | "heathcare AI",
38 | "healthcare",
39 | "electronic health records",
40 | "EHRs",
41 | "machine learning",
42 | "data mining",
43 | "neural networks",
44 | "deep learning",
45 | ],
46 | packages=find_packages(exclude=["test"]),
47 | include_package_data=True,
48 | install_requires=requirements,
49 | setup_requires=["setuptools>=38.6.0"],
50 | classifiers=[
51 | "Development Status :: 3 - Alpha",
52 | "Intended Audience :: Education",
53 | "Intended Audience :: Financial and Insurance Industry",
54 | "Intended Audience :: Science/Research",
55 | "Intended Audience :: Developers",
56 | "Intended Audience :: Information Technology",
57 | "License :: OSI Approved :: BSD License",
58 | "Programming Language :: Python :: 3.8",
59 | ],
60 | )
61 |
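
After installing the package (pip install . or pip install -e .), a quick sanity check is that the installed metadata reports the VERSION set above; this uses the standard-library importlib.metadata available on Python 3.8+:

    from importlib.metadata import PackageNotFoundError, version

    try:
        print(version("pyhealth"))  # expected to print "1.1.4"
    except PackageNotFoundError:
        print("pyhealth is not installed in this environment")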
--------------------------------------------------------------------------------
/test-resources/nlp/metrics.csv:
--------------------------------------------------------------------------------
1 | bleu,levenshtein,rouge1_f_score,rouge1_precision,rouge1_recall,rouge2_f_score,rouge2_precision,rouge2_recall,rouge3_f_score,rouge3_precision,rouge3_recall,rouge4_f_score,rouge4_precision,rouge4_recall,rouge5_f_score,rouge5_precision,rouge5_recall,rouge6_f_score,rouge6_precision,rouge6_recall,rouge7_f_score,rouge7_precision,rouge7_recall,rouge8_f_score,rouge8_precision,rouge8_recall,rouge9_f_score,rouge9_precision,rouge9_recall,rougeL_f_score,rougeL_precision,rougeL_recall
2 | 1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
3 | 0.658,0.9592,0.9474,0.9,1.0,0.8235,0.7778,0.875,0.6667,0.625,0.7143,0.4615,0.4286,0.5,0.3636,0.3333,0.4,0.2222,0.2,0.25,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.9474,0.9,1.0
4 |
--------------------------------------------------------------------------------
/tests/base.py:
--------------------------------------------------------------------------------
1 | from typing import Type
2 | import unittest
3 | import logging
4 |
5 |
6 | class BaseTestCase(unittest.TestCase):
7 | @staticmethod
8 | def _setup_logging(level: int = logging.INFO):
9 | logging.basicConfig(level=level)
10 |
11 | @classmethod
12 | def _set_debug(cls: Type):
13 | cls._setup_logging(logging.DEBUG)
14 | print()
15 | print('_' * 80)
16 |
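
For reference, a minimal sketch of a test case built on BaseTestCase, run from the tests/ directory the same way test_metrics.py is; the test body itself is illustrative only.

    import unittest
    from base import BaseTestCase

    class ExampleTest(BaseTestCase):
        def test_example(self):
            self._set_debug()  # raise logging to DEBUG and print a visual separator
            self.assertTrue(True)

    if __name__ == "__main__":
        unittest.main()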
--------------------------------------------------------------------------------
/tests/nlp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/tests/nlp/__init__.py
--------------------------------------------------------------------------------
/tests/nlp/test_metrics.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import logging
3 | from base import BaseTestCase
4 | from pathlib import Path
5 | import pandas as pd
6 | from pyhealth.nlp.metrics import (
7 | LevenshteinDistanceScoreMethod,
8 | ScoreContext, ScoreSet, ScoreResult, Scorer
9 | )
10 |
11 |
12 | class TestMetrics(BaseTestCase):
13 | def __init__(self, *args, **kwargs):
14 | super().__init__(*args, **kwargs)
15 | # silence info messages from rouge package
16 | self._setup_logging(logging.WARNING)
17 |
18 | def setUp(self):
19 | s1: str = 'The boy threw the ball. He practiced every day.'
20 | s2: str = 'The boy threw X the ball. He practiced every day.'
21 | self.pairs: List[List[str]] = [[s1, s1], [s1, s2]]
22 |
23 | def test_object_graph(self):
24 | # configure only the edit distance method
25 | scorer = Scorer(
26 | methods={'editdistance': LevenshteinDistanceScoreMethod()},
27 | )
28 | ss: ScoreSet = scorer.score(ScoreContext(self.pairs))
29 | self.assertEqual(ScoreSet, type(ss))
30 | self.assertEqual(2, len(ss.results))
31 | res1: ScoreResult = ss.results[0]
32 | self.assertEqual(ScoreResult, type(res1))
33 | self.assertEqual(1, len(res1.scores))
34 | self.assertTrue('editdistance' in res1.scores)
35 | self.assertEqual(1., res1.scores['editdistance'].value)
36 |
37 | def test_pandas(self):
38 | WRITE: bool = False
39 | should_file: Path = Path('test-resources/nlp/metrics.csv')
40 | self._set_debug()
41 | scorer = Scorer()
42 | ss: ScoreSet = scorer.score(ScoreContext(self.pairs))
43 | df: pd.DataFrame = ss.as_dataframe()
44 | # allow tolerance for architecture-dependent significant digits that might be off by epsilon
45 | df = df.round(4)
46 | if WRITE:
47 | should_file.parent.mkdir(parents=True, exist_ok=True)
48 | df.to_csv(should_file, index=False)
49 | should: pd.DataFrame = pd.read_csv(should_file)
50 | self.assertEqual(should.to_string(), df.to_string())
51 |
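
For reference, a standalone sketch of the pyhealth.nlp.metrics API this test exercises, mirroring test_object_graph with an illustrative sentence pair.

    from pyhealth.nlp.metrics import (
        LevenshteinDistanceScoreMethod, ScoreContext, Scorer,
    )

    # Restrict scoring to the edit-distance method, as in test_object_graph.
    scorer = Scorer(methods={"editdistance": LevenshteinDistanceScoreMethod()})
    pairs = [["the boy threw the ball", "the boy threw a ball"]]
    score_set = scorer.score(ScoreContext(pairs))
    print(score_set.results[0].scores["editdistance"].value)
    print(score_set.as_dataframe())  # one row per pair, one column per metric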
--------------------------------------------------------------------------------