├── .github └── workflows │ └── test.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.rst ├── chat-assistant ├── Dockerfile ├── OPENAI_API_KEY.txt ├── README.md ├── code-corpus-generation.py ├── corpus │ ├── pyhealth-code.pkl │ ├── pyhealth-code.txt │ ├── pyhealth-text.pkl │ ├── pyhealth-text.txt │ └── pyhealth.txt ├── ingest.py ├── main.py ├── prompts │ ├── __init__.py │ ├── introduction_prompt.py │ ├── qa_prompt.py │ └── summary_prompt.py ├── qa_chain.py ├── query_data.py ├── requirements.txt └── text-corpus-generation.py ├── docs ├── Makefile ├── _static │ ├── css │ │ ├── overwrite.css │ │ └── sphinx_gallery.css │ └── pyhealth_logos │ │ └── pyhealth-logo.png ├── about.rst ├── advance_tutorials.rst ├── api │ ├── calib.rst │ ├── calib │ │ ├── pyhealth.calib.calibration.rst │ │ └── pyhealth.calib.predictionset.rst │ ├── data.rst │ ├── data │ │ ├── pyhealth.data.Event.rst │ │ └── pyhealth.data.Patient.rst │ ├── datasets.rst │ ├── datasets │ │ ├── pyhealth.datasets.BaseDataset.rst │ │ ├── pyhealth.datasets.CardiologyDataset.rst │ │ ├── pyhealth.datasets.ISRUCDataset.rst │ │ ├── pyhealth.datasets.MIMIC3Dataset.rst │ │ ├── pyhealth.datasets.MIMIC4Dataset.rst │ │ ├── pyhealth.datasets.MIMICExtractDataset.rst │ │ ├── pyhealth.datasets.MedicalTranscriptionsDataset.rst │ │ ├── pyhealth.datasets.OMOPDataset.rst │ │ ├── pyhealth.datasets.SHHSDataset.rst │ │ ├── pyhealth.datasets.SampleDataset.rst │ │ ├── pyhealth.datasets.SleepEDFDataset.rst │ │ ├── pyhealth.datasets.TUABDataset.rst │ │ ├── pyhealth.datasets.TUEVDataset.rst │ │ ├── pyhealth.datasets.eICUDataset.rst │ │ ├── pyhealth.datasets.splitter.rst │ │ └── pyhealth.datasets.utils.rst │ ├── medcode.rst │ ├── metrics.rst │ ├── metrics │ │ ├── pyhealth.metrics.binary.rst │ │ ├── pyhealth.metrics.calibration.rst │ │ ├── pyhealth.metrics.fairness.rst │ │ ├── pyhealth.metrics.multiclass.rst │ │ ├── pyhealth.metrics.multilabel.rst │ │ └── pyhealth.metrics.prediction_set.rst │ ├── models.rst │ ├── models │ │ ├── pyhealth.models.AdaCare.rst │ │ ├── pyhealth.models.Agent.rst │ │ ├── pyhealth.models.BaseModel.rst │ │ ├── pyhealth.models.CNN.rst │ │ ├── pyhealth.models.ConCare.rst │ │ ├── pyhealth.models.ContraWR.rst │ │ ├── pyhealth.models.Deepr.rst │ │ ├── pyhealth.models.GAMENet.rst │ │ ├── pyhealth.models.GAN.rst │ │ ├── pyhealth.models.GNN.rst │ │ ├── pyhealth.models.GRASP.rst │ │ ├── pyhealth.models.MICRON.rst │ │ ├── pyhealth.models.MLP.rst │ │ ├── pyhealth.models.MedLink.rst │ │ ├── pyhealth.models.MoleRec.rst │ │ ├── pyhealth.models.RETAIN.rst │ │ ├── pyhealth.models.RNN.rst │ │ ├── pyhealth.models.SafeDrug.rst │ │ ├── pyhealth.models.SparcNet.rst │ │ ├── pyhealth.models.StageNet.rst │ │ ├── pyhealth.models.TCN.rst │ │ ├── pyhealth.models.Transformer.rst │ │ ├── pyhealth.models.TransformersModel.rst │ │ └── pyhealth.models.VAE.rst │ ├── tasks.rst │ ├── tasks │ │ ├── pyhealth.tasks.BaseTask.rst │ │ ├── pyhealth.tasks.COVID19CXRClassification.rst │ │ ├── pyhealth.tasks.EEG_abnormal.rst │ │ ├── pyhealth.tasks.EEG_events.rst │ │ ├── pyhealth.tasks.InHospitalMortalityMIMIC4.rst │ │ ├── pyhealth.tasks.MIMIC3ICD9Coding.rst │ │ ├── pyhealth.tasks.MedicalTranscriptionsClassification.rst │ │ ├── pyhealth.tasks.Readmission30DaysMIMIC4.rst │ │ ├── pyhealth.tasks.SleepStagingSleepEDF.rst │ │ ├── pyhealth.tasks.cardiology_detect.rst │ │ ├── pyhealth.tasks.drug_recommendation.rst │ │ ├── pyhealth.tasks.length_of_stay_prediction.rst │ │ ├── pyhealth.tasks.mortality_prediction.rst │ │ ├── pyhealth.tasks.patient_linkage_mimic3_fn.rst │ │ ├── 
pyhealth.tasks.readmission_prediction.rst │ │ ├── pyhealth.tasks.sleep_staging.rst │ │ └── pyhealth.tasks.temple_university_EEG_tasks.rst │ ├── tokenizer.rst │ └── trainer.rst ├── conf.py ├── developer.rst ├── examples.rst ├── examples_bak.rst ├── faq.rst ├── how_to_get_started.rst ├── index.rst ├── install.rst ├── live.rst ├── log.rst ├── make.bat ├── requirements.txt ├── tutorials.rst └── use_cases.rst ├── examples ├── ChestXray-image-generation-GAN.ipynb ├── EEG_events_SparcNet.py ├── EEG_isAbnormal_SparcNet.py ├── benchmark_ehrshot_xgboost.ipynb ├── cardiology_detection_isAR_SparcNet.py ├── chextXray_image_generation_VAE.py ├── drug_recommendation_eICU_transformer.py ├── drug_recommendation_mimic3_gamenet.py ├── drug_recommendation_mimic3_micron.py ├── drug_recommendation_mimic3_molerec.py ├── drug_recommendation_mimic3_safedrug.py ├── drug_recommendation_mimic3_transformer.py ├── drug_recommendation_mimic4_gamenet.py ├── graph_torchvision_model.ipynb ├── interpret_demo.ipynb ├── kg_embedding.ipynb ├── length_of_stay_mimic3_rnn.py ├── lm_embedding_huggingface.ipynb ├── lm_embedding_openai.ipynb ├── medcode.py ├── mortality_mimic3_adacare.py ├── mortality_mimic3_agent.py ├── mortality_mimic3_concare.py ├── mortality_mimic3_grasp.py ├── mortality_mimic3_rnn.py ├── mortality_mimic3_stagenet.py ├── mortality_mimic3_tcn.py ├── patient_linkage_mimic3_medlink.py ├── prepare_mapping.ipynb ├── readmission_mimic3_fairness.py ├── readmission_mimic3_rnn.py ├── sleep_staging_ISRUC_SparcNet.py ├── sleep_staging_shhs_contrawr.py ├── sleep_staging_sleepEDF_contrawr.py ├── test.py ├── test_dataloader.sh └── test_eICU_addition.py ├── figure ├── five-stage-pipeline.png ├── overview.png ├── poster.png └── structured-dataset.png ├── hackthon └── 202410-sunlab-hackthon.md ├── leaderboard ├── README.rst ├── __init__.py ├── ext_plot.py ├── leaderboard_gen.py ├── rtd_build_trigger.py └── utils.py ├── makefile ├── pyhealth ├── __init__.py ├── calib │ ├── README.md │ ├── __init__.py │ ├── base_classes.py │ ├── calibration │ │ ├── __init__.py │ │ ├── dircal.py │ │ ├── hb.py │ │ ├── kcal │ │ │ ├── __init__.py │ │ │ ├── bw.py │ │ │ ├── embed_data.py │ │ │ └── kde.py │ │ └── temperature_scale.py │ ├── predictionset │ │ ├── __init__.py │ │ ├── favmac │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ └── quantiletree.py │ │ ├── label.py │ │ └── scrib │ │ │ ├── __init__.py │ │ │ ├── quicksearch.py │ │ │ └── quicksearch_cython.pyx │ └── utils.py ├── data │ ├── __init__.py │ └── data.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── cardiology.py │ ├── configs │ │ ├── __init__.py │ │ ├── config.py │ │ ├── covid19_cxr.yaml │ │ ├── ehrshot.yaml │ │ ├── medical_transcriptions.yaml │ │ ├── mimic3.yaml │ │ ├── mimic4_cxr.yaml │ │ ├── mimic4_ehr.yaml │ │ └── mimic4_note.yaml │ ├── covid19_cxr.py │ ├── ehrshot.py │ ├── eicu.py │ ├── featurizers │ │ ├── signal.py │ │ └── text.py │ ├── isruc.py │ ├── medical_transcriptions.py │ ├── mimic3.py │ ├── mimic4.py │ ├── mimicextract.py │ ├── omop.py │ ├── sample_dataset.py │ ├── shhs.py │ ├── sleepedf.py │ ├── splitter.py │ ├── tuab.py │ ├── tuev.py │ └── utils.py ├── interpret │ ├── __init__.py │ └── methods │ │ ├── __init__.py │ │ └── chefer.py ├── medcode │ ├── __init__.py │ ├── codes │ │ ├── __init__.py │ │ ├── atc.py │ │ ├── ccscm.py │ │ ├── ccsproc.py │ │ ├── icd10cm.py │ │ ├── icd10proc.py │ │ ├── icd9cm.py │ │ ├── icd9proc.py │ │ ├── ndc.py │ │ ├── rxnorm.py │ │ └── umls.py │ ├── cross_map.py │ ├── inner_map.py │ ├── pretrained_embeddings │ │ ├── __init__.py │ │ ├── 
kg_emb │ │ │ ├── __init__.py │ │ │ ├── datasets │ │ │ │ ├── __init__.py │ │ │ │ ├── base_kg_dataset.py │ │ │ │ ├── sample_kg_dataset.py │ │ │ │ ├── splitter.py │ │ │ │ └── umls.py │ │ │ ├── examples │ │ │ │ └── train_kge_model.py │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── complex.py │ │ │ │ ├── distmult.py │ │ │ │ ├── kg_base.py │ │ │ │ ├── rotate.py │ │ │ │ └── transe.py │ │ │ └── tasks │ │ │ │ ├── __init__.py │ │ │ │ └── link_prediction.py │ │ └── lm_emb │ │ │ ├── __init__.py │ │ │ ├── huggingface_retriever.py │ │ │ └── openai_retriever.py │ └── utils.py ├── metrics │ ├── __init__.py │ ├── binary.py │ ├── calibration.py │ ├── drug_recommendation.py │ ├── fairness.py │ ├── fairness_utils │ │ ├── __init__.py │ │ ├── group.py │ │ └── utils.py │ ├── multiclass.py │ ├── multilabel.py │ ├── prediction_set.py │ ├── ranking.py │ └── regression.py ├── models │ ├── __init__.py │ ├── adacare.py │ ├── agent.py │ ├── base_model.py │ ├── cnn.py │ ├── concare.py │ ├── contrawr.py │ ├── deepr.py │ ├── embedding.py │ ├── gamenet.py │ ├── gan.py │ ├── gnn.py │ ├── graph_torchvision_model.py │ ├── grasp.py │ ├── medlink │ │ ├── __init__.py │ │ ├── bm25.py │ │ ├── model.py │ │ └── utils.py │ ├── micron.py │ ├── mlp.py │ ├── molerec.py │ ├── retain.py │ ├── rnn.py │ ├── safedrug.py │ ├── sparcnet.py │ ├── stagenet.py │ ├── tcn.py │ ├── torchvision_model.py │ ├── transformer.py │ ├── transformers_model.py │ ├── utils.py │ └── vae.py ├── nlp │ └── metrics.py ├── processors │ ├── __init__.py │ ├── base_processor.py │ ├── image_processor.py │ ├── label_processor.py │ ├── raw_processor.py │ ├── sequence_processor.py │ ├── signal_processor.py │ ├── text_processor.py │ └── timeseries_processor.py ├── sampler │ ├── __init__.py │ └── sage_sampler.py ├── tasks │ ├── EEG_abnormal.py │ ├── EEG_events.py │ ├── __init__.py │ ├── base_task.py │ ├── benchmark_ehrshot.py │ ├── cardiology_detect.py │ ├── covid19_cxr_classification.py │ ├── drug_recommendation.py │ ├── in_hospital_mortality_mimic4.py │ ├── length_of_stay_prediction.py │ ├── medical_coding.py │ ├── medical_transcriptions_classification.py │ ├── mortality_prediction.py │ ├── patient_linkage.py │ ├── readmission_30days_mimic4.py │ ├── readmission_prediction.py │ ├── sleep_staging.py │ ├── sleep_staging_v2.py │ └── temple_university_EEG_tasks.py ├── tokenizer.py ├── trainer.py ├── unittests │ ├── test.ipynb │ ├── test.py │ ├── test_data │ │ └── test_data.py │ ├── test_datasets │ │ ├── test_eicu.py │ │ ├── test_mimic3.py │ │ ├── test_mimic4.py │ │ ├── test_omop.py │ │ └── utils.py │ ├── test_medcode.py │ ├── test_mortality_prediction.py │ └── test_tokenizer.py └── utils.py ├── readthedocs.yml ├── requirements-nlp.txt ├── requirements.txt ├── setup.cfg ├── setup.py ├── test-resources └── nlp │ └── metrics.csv └── tests ├── base.py └── nlp ├── __init__.py └── test_metrics.py
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | 
3 | on:
4 |   pull_request:
5 |   push:
6 |     paths-ignore:
7 |       - '**.md'
8 |       - 'docs/*'
9 |     branches:
10 |       - master
11 | 
12 | jobs:
13 |   build:
14 |     runs-on: ubuntu-latest
15 |     strategy:
16 |       matrix:
17 |         python-version: ['3.11']
18 |     steps:
19 |       - name: Checkout repository
20 |         uses: actions/checkout@v3
21 |         with:
22 |           submodules: recursive
23 | 
24 |       - name: Set up Python
25 |         uses: actions/setup-python@v3
26 |         with:
27 |           python-version: ${{ matrix.python-version }}
28 | 
29 |       - name: Display Python version
30 |         run: python -c "import sys; print(sys.version)"
31 | 
32
| - name: Install dependent packages 33 | run: 'make deps' 34 | 35 | - name: Run tests 36 | run: 'make test' 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *DS_Store* 3 | *output* 4 | resource* 5 | 6 | debug_entry* 7 | playground.ipynb 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | .pytest_cache 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | # Sphinx documentation 116 | docs/build/ 117 | # /docs/api/data/ 118 | # /docs/api/datasets/ 119 | # /docs/api/evaluator/ 120 | # /docs/api/models/ 121 | # /docs/api/tasks/ 122 | 123 | #docs/examples.rst 124 | #docs/examples_bak.rst 125 | #docs/faq.rst 126 | #docs/usecase.rst 127 | 128 | # leaderboard credentials 129 | leaderboard/credentials.json 130 | leaderboard/rtd_token.txt 131 | 132 | # locally pre-trained models 133 | pyhealth/medcode/pretrained_embeddings/kg_emb/examples/pretrained_model -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 SunLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/chat-assistant/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use an official Python runtime as a parent image
2 | FROM python:3.10
3 | 
4 | # Set the working directory in the container to /app
5 | WORKDIR /app
6 | 
7 | # Copy the current directory contents into the container at /app
8 | COPY . /app
9 | 
10 | # Install any needed packages specified in requirements.txt
11 | RUN pip install --upgrade pip
12 | RUN pip install --no-cache-dir -r requirements.txt
13 | 
14 | # Make port 7861 available to the world outside this container
15 | EXPOSE 7861
16 | 
17 | 
18 | # Define environment variables
19 | ENV LOG_PATH=/app/logs/
20 | ENV CORPUS_PATH=/app/corpus/
21 | ENV GRADIO_SERVER_NAME=0.0.0.0
22 | ENV GRADIO_SERVER_PORT=7861
23 | 
24 | # put your OpenAI API key here
25 | ENV OPENAI_API_KEY=sk-xxxx
26 | 
27 | # Run main.py when the container launches
28 | CMD ["python", "main.py"]
29 | 
--------------------------------------------------------------------------------
/chat-assistant/OPENAI_API_KEY.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/chat-assistant/OPENAI_API_KEY.txt
--------------------------------------------------------------------------------
/chat-assistant/README.md:
--------------------------------------------------------------------------------
1 | ## LLM for PyHealth
2 | The current LLM-for-PyHealth interface is deployed at http://35.208.88.194:7861/.
3 | 
4 | ### Step 1:
5 | Merge all PyHealth-related information (code and doc text) into `pyhealth.txt`
6 | - currently, we use https://github.com/mpoon/gpt-repository-loader
7 | 
8 | ### Step 2:
9 | Run `ingest.py` to transform `pyhealth.txt` into a FAISS vector database
10 | ```bash
11 | python ingest.py
12 | ```
13 | 
14 | ### Step 3:
15 | Run the retrieval-augmented generation (RAG) app for Q&A over the `pyhealth.txt` document.
16 | ```bash
17 | python main.py
18 | ```
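
The snippet below is a minimal sketch (not a file in this repo) of what happens at query time: it loads a pickled FAISS index produced by `ingest.py` and retrieves the chunks that the QA prompt later receives as context. It assumes the pinned `langchain==0.1.11` API; the actual retrieval logic lives in `main.py` / `qa_chain.py` / `query_data.py`.

```python
import pickle

# Load one of the vectorstores built by ingest.py (run `python ingest.py` first).
with open("corpus/pyhealth-text.pkl", "rb") as f:
    vectorstore = pickle.load(f)

# Retrieve the chunks most similar to a user question; chunks like these are
# what the RAG chain substitutes into the QA prompt as {ref_doc} / {source_code}.
docs = vectorstore.similarity_search("How do I load MIMIC-III in PyHealth?", k=4)
for doc in docs:
    print(doc.page_content[:200])
```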
19 | 
20 | 
21 | ### Launch in Docker
22 | 
23 | 1. Modify environment variables (OPENAI_API_KEY, server address, ...) in `Dockerfile`.
24 | 2. Build the image with `docker build -t chat-pyhealth .`.
25 | 3. Debug a container with `docker run -p 0.0.0.0:7861:7861 --name chat-pyhealth-c -v ./logs/:/app/logs/ chat-pyhealth`.
26 | 4. Run a container with `docker run -d -p 0.0.0.0:7861:7861 --name chat-pyhealth-c -v ./logs/:/app/logs/ chat-pyhealth`.
27 | 
28 | ```shell
29 | ## run a container (general form)
30 | docker run -d -p [host address and port]:[container port] --name [name] -v [host path]:[container path] [image]
31 | 
32 | # -d: detached
33 | # -p: port
34 | # --name: container name
35 | # -v: mount a host path into the container
36 | 
37 | ## check
38 | docker ps
39 | docker images
40 | 
41 | ## stop / remove
42 | docker stop chat-pyhealth-c # container (or: docker restart chat-pyhealth-c)
43 | docker rm chat-pyhealth-c # container
44 | docker rmi chat-pyhealth # image
45 | 
46 | ## modify directly
47 | docker cp [local file in host] chat-pyhealth-c:[container path]
48 | ```
49 | 
50 | 
51 | **Let me know if you want to join and help us improve the current interface.**
--------------------------------------------------------------------------------
/chat-assistant/code-corpus-generation.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | def merge_py_files(directory, output_file):
4 |     with open(output_file, 'a') as outfile:
5 |         for root, dirs, files in os.walk(directory):
6 |             for filename in files:
7 |                 if filename.endswith(".py"):
8 |                     with open(os.path.join(root, filename), 'r') as infile:
9 |                         outfile.write("Here is the code content for " + filename + ":\n")
10 |                         outfile.write(infile.read())
11 |                         outfile.write("\n\n")
12 | 
13 | # --------- code corpus ------------
14 | with open('pyhealth-code.txt', 'w') as f:
15 |     f.write("The code content for pyhealth\n\n")
16 | 
17 | merge_py_files('../pyhealth', 'pyhealth-code.txt')
18 | merge_py_files('../examples', 'pyhealth-code.txt')
19 | merge_py_files('/home/chaoqiy2/pyhealth-tutorial', 'pyhealth-code.txt')
--------------------------------------------------------------------------------
/chat-assistant/corpus/pyhealth-code.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/chat-assistant/corpus/pyhealth-code.pkl
--------------------------------------------------------------------------------
/chat-assistant/corpus/pyhealth-text.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/chat-assistant/corpus/pyhealth-text.pkl
--------------------------------------------------------------------------------
/chat-assistant/ingest.py:
--------------------------------------------------------------------------------
1 | from langchain.text_splitter import CharacterTextSplitter
2 | from langchain.document_loaders import UnstructuredFileLoader
3 | from langchain.vectorstores.faiss import FAISS
4 | from langchain.embeddings import OpenAIEmbeddings
5 | import pickle
6 | import os
7 | from env import OPENAI_API_KEY
8 | 
9 | os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
10 | 
11 | 
12 | def vectorize_corpus(corpus_path):
13 |     print(f"Vectorizing {corpus_path}...")
14 |     # load the raw corpus file
15 |     loader = UnstructuredFileLoader(corpus_path)
16 |     raw_documents = loader.load()
17 | 
18 |     # split the raw text into overlapping chunks for embedding
19 |     text_splitter = CharacterTextSplitter(
20 |         separator="\n\n",
21 |         chunk_size=600,
22 |         chunk_overlap=100,
23 |         length_function=len,
24 |     )
25 |     documents = text_splitter.split_documents(raw_documents)
26 | 
27 |     # embed the chunks and build the FAISS vectorstore
28 |     embeddings = OpenAIEmbeddings()
29 |     vectorstore = FAISS.from_documents(documents, embeddings)
30 | 
31 |     base_name, ext = os.path.splitext(corpus_path)
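    # Name the pickled index after its source corpus: 'corpus/pyhealth-text.txt'
    # becomes 'corpus/pyhealth-text.pkl'. The chat app (README Step 3) loads these
    # pickles as its retrieval indexes.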
32 |     vectorstore_path = base_name + '.pkl'
33 |     with open(vectorstore_path, "wb") as f:
34 |         pickle.dump(vectorstore, f)
35 | 
36 | 
37 | vectorize_corpus('corpus/pyhealth-text.txt')
38 | vectorize_corpus('corpus/pyhealth-code.txt')
--------------------------------------------------------------------------------
/chat-assistant/prompts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/chat-assistant/prompts/__init__.py
--------------------------------------------------------------------------------
/chat-assistant/prompts/introduction_prompt.py:
--------------------------------------------------------------------------------
1 | USER_INTRO = '''I am here for some help with PyHealth.'''
2 | 
3 | 
4 | 
5 | AI_INTRO = '''Hi, I am a chat assistant for PyHealth. I am here to help with all kinds of questions and problems you run into while using PyHealth. My specific features include:
6 | 
7 | 1. I can give you a brief introduction to PyHealth and its functions
8 | 2. I can provide tutorials and explain how to use different modules
9 | 3. I can even provide sample code for specific tasks using PyHealth
10 | 4. I can provide examples of mapping functions between different medical coding systems
11 | 5. I can tell you a little bit about our development team
12 | 
13 | How can I assist you with PyHealth today?'''
--------------------------------------------------------------------------------
/chat-assistant/prompts/qa_prompt.py:
--------------------------------------------------------------------------------
1 | QA_PROMPT_TEMPLATE = '''You are a very warm and helpful PyHealth usage assistant for users who want to use PyHealth or are already using PyHealth. You need to answer questions related to PyHealth.
2 | 
3 | PyHealth is a comprehensive deep learning toolkit for supporting clinical predictive modeling, designed for both ML researchers and medical practitioners. We can make your healthcare AI applications easier to deploy and more flexible and customizable.
4 | 
5 | You need to refer to the Reference Documents and the Source Code to get information about PyHealth and improve your answer quality.
6 | The Reference Documents are the PyHealth documentation, which contains information on how to use the package, sample code, etc.
7 | The Source Code is the source code of PyHealth, which contains implementation details.
8 | 
9 | When users describe a use case for PyHealth or want to write code using PyHealth, you should give a block of sample code with some modifications and comments based on the user's needs.
10 | 
11 | 
12 | ##### Reference Documents START #####
13 | {ref_doc}
14 | ###### Reference Documents END ######
15 | 
16 | ##### Source Code START ##### (All the code you generate should be based on this)
17 | {source_code}
18 | ###### Source Code END ######
19 | 
20 | ###### Chat History START ###### (This is the chat history between you and the user.)
21 | {chat_history}
22 | ###### Chat History END ######
23 | 
24 | 
25 | Notice: You should not express any subjective opinion.
26 | 
27 | When you generate PyHealth code for users, you must ensure all code is based on the source code or the sample code in the reference documents.
28 | The code you generate must be executable and cannot contain errors, like calls to functions that are undefined in PyHealth.
29 | 
30 | You cannot answer any questions that are not related to PyHealth!
31 | You cannot answer any questions that are not related to PyHealth!
32 | You cannot answer any questions that are not related to PyHealth!
33 | 
34 | 
35 | ##### Current Conversation #####
36 | User: {human_input}
37 | AI: (give your response in Markdown format.)'''
--------------------------------------------------------------------------------
/chat-assistant/prompts/summary_prompt.py:
--------------------------------------------------------------------------------
1 | SUMMARY_PROMPT_TEMPLATE = '''You are an AI who is responsible for summarizing the chat history between another AI and the user.
2 | 
3 | I will provide you with the summary of the previous chat history and the new message from the current round of conversation. Then you need to generate a new summary based on this information.
4 | 
5 | Summary of Previous Chat History: {previous_summary}
6 | New Message: {new_message}
7 | 
8 | The length of your new summary is limited to {summary_token_limitation} tokens.
9 | 
10 | Now, give the new summary of the chat history:'''
--------------------------------------------------------------------------------
/chat-assistant/requirements.txt:
--------------------------------------------------------------------------------
1 | langchain==0.1.11
2 | openai
3 | faiss-cpu
4 | unstructured
5 | tiktoken
6 | rich  # for console formatting
7 | gradio==4.19.2
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS    ?=
7 | SPHINXBUILD   ?= sphinx-build
8 | SOURCEDIR     = source
9 | BUILDDIR      = build
10 | 
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 | 
15 | .PHONY: help Makefile
16 | 
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 | 
--------------------------------------------------------------------------------
/docs/_static/css/overwrite.css:
--------------------------------------------------------------------------------
1 | /*
2 | Furo CSS variables
3 | https://github.com/pradyunsg/furo/blob/main/src/furo/assets/styles/variables/_index.scss
4 | https://github.com/pradyunsg/furo/blob/main/src/furo/theme/partials/_head_css_variables.html
5 | https://github.com/streamlink/streamlink/blob/17a4088c38709123c0bcab4a150549bd16d19e07/docs/_static/styles/custom.css
6 | */
7 | 
8 | /* for the sphinx design cards */
9 | body {
10 |     --sd-color-shadow: dimgrey;
11 | }
12 | 
13 | dt:target,
14 | span.highlighted {
15 |     background-color: #f0f0f0;
16 | }
17 | 
18 | dl.citation > dt {
19 |     float: left;
20 |     margin-right: 15px;
21 |     font-weight: bold;
22 | }
23 | 
24 | /* Parameter headings: normal size, no forced capitalization */
25 | dl .field-list dt {
26 |     font-size: var(--font-size--normal) !important;
27 |     text-transform: none !important;
28 | }
29 | 
30 | /* examples and headings in classes */
31 | p.rubric {
32 |     font-size: var(--font-size--normal);
33 |     text-transform: none;
34 |     font-weight: 500;
35 | }
36 | 
37 | /* adapted from https://github.com/dask/dask-sphinx-theme/blob/main/dask_sphinx_theme/static/css/nbsphinx.css */
38 | 
39 | .nbinput .prompt,
40 | .nboutput .prompt {
41 |     display: none;
42 | }
43 | .nboutput .stderr {
44 |     display: none;
45 | }
46 | 
47 | div.nblast.container {
48 |     padding-bottom: 10px !important;
49 |     padding-right: 0px;
50 |     padding-left: 0px;
51 | }
52 | 
53 | div.nbinput.container {
54 |     padding-top: 10px !important;
55 |     padding-right: 0px;
56 |     padding-left: 0px;
57 | }
58 | 
59 | div.nbinput.container div.input_area div[class*="highlight"] > pre {
60 |     padding: 10px !important;
61 |     margin: 0;
62 | }
63 | 
64 | p.topic-title {
65 |     margin-top: 0;
66 | }
67 | 
68 | /* so that api methods are small in sidebar */
69 | li.toctree-l3 {
70 |     font-size: 81.25% !important;
71 | }
72 | li.toctree-l4 {
73 |     font-size: 75% !important;
74 | }
75 | 
--------------------------------------------------------------------------------
/docs/_static/css/sphinx_gallery.css:
--------------------------------------------------------------------------------
1 | .sphx-glr-thumbcontainer {
2 |     background: inherit !important;
3 |     min-height: 250px !important;
4 |     margin: 10px !important;
5 | }
6 | 
7 | .sphx-glr-thumbcontainer .headerlink {
8 |     display: none !important;
9 | }
10 | 
11 | div.sphx-glr-thumbcontainer span {
12 |     font-style: normal !important;
13 | }
14 | 
15 | .sphx-glr-thumbcontainer a.internal {
16 |     padding: 140px 10px 0 !important;
17 | }
18 | 
19 | .sphx-glr-thumbcontainer .figure {
20 |     width: 200px !important;
21 | }
22 | 
23 | .sphx-glr-thumbcontainer .figure.align-center {
24 |     text-align: center;
25 |     margin-left: 0%;
26 |     transform: translate(0%);
27 | }
28 | 
--------------------------------------------------------------------------------
/docs/_static/pyhealth_logos/pyhealth-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/docs/_static/pyhealth_logos/pyhealth-logo.png
--------------------------------------------------------------------------------
/docs/about.rst:
--------------------------------------------------------------------------------
1 | About us
2 | ========
3 | 
4 | We are the
`SunLab `_ healthcare research team at UIUC.
5 | 
6 | Current Maintainers
7 | -------------------
8 | 
9 | `Zhenbang Wu `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
10 | 
11 | `John Wu `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
12 | 
13 | `Junyi Gao `_ (M.S. @ UIUC, Ph.D. Student @ University of Edinburgh)
14 | 
15 | `Jimeng Sun `_ (Professor @ University of Illinois Urbana-Champaign)
16 | 
17 | Alumni
18 | ------
19 | 
20 | `Chaoqi Yang `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
21 | 
22 | `Zhen Lin `_ (Ph.D. Student @ University of Illinois Urbana-Champaign)
23 | 
24 | `Benjamin Danek `_ (M.S. Student @ University of Illinois Urbana-Champaign)
25 | 
26 | `Patrick Jiang `_ (M.S. Student @ University of Illinois Urbana-Champaign)
--------------------------------------------------------------------------------
/docs/advance_tutorials.rst:
--------------------------------------------------------------------------------
1 | Advanced Tutorials
2 | ========================
3 | 
4 | We provide advanced tutorials to support various needs.
5 | 
6 | `Advanced Tutorial 1: Fit your dataset into our pipeline `_ `[Video] `_
7 | 
8 | `Advanced Tutorial 2: Define your own healthcare task `_
9 | 
10 | `Advanced Tutorial 3: Adopt customized model into pyhealth `_ `[Video] `_
11 | 
12 | `Advanced Tutorial 4: Load your own processed data into pyhealth and try out our ML models `_ `[Video] `_
13 | 
14 | ----------
15 | 
--------------------------------------------------------------------------------
/docs/api/calib.rst:
--------------------------------------------------------------------------------
1 | Calibration and Uncertainty Quantification
2 | ==========================================
3 | 
4 | In this module, we implement the following prediction
5 | set constructors and model calibration methods, which can be
6 | combined with any PyHealth model.
7 | 
8 | .. toctree::
9 |     :maxdepth: 3
10 | 
11 |     calib/pyhealth.calib.calibration
12 | 
13 |     calib/pyhealth.calib.predictionset
14 | 
15 | 
--------------------------------------------------------------------------------
/docs/api/calib/pyhealth.calib.calibration.rst:
--------------------------------------------------------------------------------
1 | pyhealth.calib.calibration
2 | ===================================
3 | 
4 | 
5 | .. automodule:: pyhealth.calib.calibration
6 |     :members:
7 |     :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/api/calib/pyhealth.calib.predictionset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.calib.predictionset
2 | ===================================
3 | 
4 | 
5 | .. automodule:: pyhealth.calib.predictionset
6 |     :members:
7 |     :show-inheritance:
8 | 
--------------------------------------------------------------------------------
/docs/api/data.rst:
--------------------------------------------------------------------------------
1 | Data
2 | ===============
3 | 
4 | **pyhealth.data** defines the atomic data structures of this package.
5 | 
6 | .. toctree::
7 |     :maxdepth: 3
8 | 
9 |     data/pyhealth.data.Event
10 |     data/pyhealth.data.Patient
11 | 
12 | 
--------------------------------------------------------------------------------
/docs/api/data/pyhealth.data.Event.rst:
--------------------------------------------------------------------------------
1 | pyhealth.data.Event
2 | =========================
3 | 
4 | One basic data structure in the package. It is a simple container for a single event.
5 | It contains all the attributes needed to support various healthcare tasks.
6 | 
7 | .. autoclass:: pyhealth.data.Event
8 |     :members:
9 |     :undoc-members:
10 |     :show-inheritance:
11 | 
--------------------------------------------------------------------------------
/docs/api/data/pyhealth.data.Patient.rst:
--------------------------------------------------------------------------------
1 | pyhealth.data.Patient
2 | =========================
3 | 
4 | Another basic data structure in the package. A Patient is a collection of Events for
5 | a single patient. It contains all necessary attributes of a patient, such as
6 | ethnicity, mortality status, gender, etc., and can support various healthcare tasks.
7 | 
8 | .. autoclass:: pyhealth.data.Patient
9 |     :members:
10 |     :undoc-members:
11 |     :show-inheritance:
12 | 
--------------------------------------------------------------------------------
/docs/api/datasets.rst:
--------------------------------------------------------------------------------
1 | Datasets
2 | ===============
3 | 
4 | .. toctree::
5 |     :maxdepth: 3
6 | 
7 |     datasets/pyhealth.datasets.BaseDataset
8 |     datasets/pyhealth.datasets.SampleDataset
9 |     datasets/pyhealth.datasets.MIMIC3Dataset
10 |     datasets/pyhealth.datasets.MIMIC4Dataset
11 |     datasets/pyhealth.datasets.MedicalTranscriptionsDataset
12 |     datasets/pyhealth.datasets.CardiologyDataset
13 |     datasets/pyhealth.datasets.eICUDataset
14 |     datasets/pyhealth.datasets.ISRUCDataset
15 |     datasets/pyhealth.datasets.MIMICExtractDataset
16 |     datasets/pyhealth.datasets.OMOPDataset
17 |     datasets/pyhealth.datasets.SHHSDataset
18 |     datasets/pyhealth.datasets.SleepEDFDataset
19 |     datasets/pyhealth.datasets.TUABDataset
20 |     datasets/pyhealth.datasets.TUEVDataset
21 |     datasets/pyhealth.datasets.splitter
22 |     datasets/pyhealth.datasets.utils
23 | 
24 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.BaseDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.BaseDataset
2 | ===================================
3 | 
4 | This is the base dataset class. All specific datasets inherit from this class.
5 | 
6 | .. autoclass:: pyhealth.datasets.BaseDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.CardiologyDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.CardiologyDataset
2 | ===================================
3 | 
4 | The Cardiology dataset includes six subsets: "cpsc_2018", "cpsc_2018_extra", "georgia", "ptb", "ptb-xl", "st_petersburg_incart"; refer to `doc `_ for more information.
5 | 
6 | .. autoclass:: pyhealth.datasets.CardiologyDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.ISRUCDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.ISRUCDataset
2 | ===================================
3 | 
4 | The open ISRUC EEG database, refer to `doc `_ for more information.
5 | 
6 | .. autoclass:: pyhealth.datasets.ISRUCDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.MIMIC3Dataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.MIMIC3Dataset
2 | ===================================
3 | 
4 | The open Medical Information Mart for Intensive Care (MIMIC-III) database, refer to `doc `_ for more information. We process this database into a well-structured dataset object and give users the **best flexibility and convenience** for modeling and analysis.
5 | 
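A minimal usage sketch in the style of the PyHealth tutorials (the ``root``,
``tables``, and ``code_mapping`` arguments follow the documented interface; the
synthetic MIMIC-III URL comes from the PyHealth examples):

.. code-block:: python

    from pyhealth.datasets import MIMIC3Dataset

    dataset = MIMIC3Dataset(
        root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/",
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
        code_mapping={"NDC": "ATC"},  # map drug codes at load time
    )
    dataset.stat()  # print basic statistics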
6 | .. autoclass:: pyhealth.datasets.MIMIC3Dataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.MIMIC4Dataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.MIMIC4Dataset
2 | ===================================
3 | 
4 | The open Medical Information Mart for Intensive Care (MIMIC-IV) database, refer to `doc `_ for more information. We process this database into a well-structured dataset object and give users the **best flexibility and convenience** for modeling and analysis.
5 | 
6 | .. autoclass:: pyhealth.datasets.MIMIC4Dataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.MIMICExtractDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.MIMICExtractDataset
2 | ======================================
3 | 
4 | MIMIC-Extract is an open-source preprocessing pipeline for the MIMIC-III database, refer to `doc `_ for more information. We process its output into a well-structured dataset object and give users the **best flexibility and convenience** for modeling and analysis.
5 | 
6 | .. autoclass:: pyhealth.datasets.MIMICExtractDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.MedicalTranscriptionsDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.MedicalTranscriptionsDataset
2 | ==============================================
3 | 
4 | The Medical Transcriptions dataset, refer to `doc `_ for more information.
5 | 
6 | .. autoclass:: pyhealth.datasets.MedicalTranscriptionsDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.OMOPDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.OMOPDataset
2 | ===================================
3 | 
4 | We can process any OMOP-CDM formatted database, refer to `doc `_ for more information. The raw data is processed into a well-structured dataset object, giving users the **best flexibility and convenience** for modeling and analysis.
5 | 
6 | .. autoclass:: pyhealth.datasets.OMOPDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.SHHSDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.SHHSDataset
2 | ===================================
3 | 
4 | The open Sleep Heart Health Study (SHHS) database, refer to `doc `_ for more information.
5 | 
6 | .. autoclass:: pyhealth.datasets.SHHSDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.SampleDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.SampleDataset
2 | ===================================
3 | 
4 | This is the base sample dataset class. All sample datasets inherit from this class.
5 | 
6 | .. autoclass:: pyhealth.datasets.SampleDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.SleepEDFDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.SleepEDFDataset
2 | ===================================
3 | 
4 | The open Sleep-EDF Database Expanded, refer to `doc `_ for more information.
5 | 
6 | .. autoclass:: pyhealth.datasets.SleepEDFDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.TUABDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.TUABDataset
2 | ===================================
3 | 
4 | Dataset is available at https://isip.piconepress.com/projects/nedc/html/tuh_eeg/#c_tuab
5 | 
6 | The TUAB dataset (or Temple University Hospital EEG Abnormal Corpus) is a collection of EEG data acquired at the Temple University Hospital.
7 | 
8 | The dataset contains both normal and abnormal EEG readings.
9 | 
10 | .. autoclass:: pyhealth.datasets.TUABDataset
11 |     :members:
12 |     :undoc-members:
13 |     :show-inheritance:
14 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.TUEVDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.TUEVDataset
2 | ===================================
3 | 
4 | Dataset is available at https://isip.piconepress.com/projects/tuh_eeg/html/downloads.shtml
5 | 
6 | This corpus is a subset of TUEG that contains annotations of EEG segments as one of six classes: (1) spike and sharp wave (SPSW), (2) generalized periodic epileptiform discharges (GPED), (3) periodic lateralized epileptiform discharges (PLED), (4) eye movement (EYEM), (5) artifact (ARTF) and (6) background (BCKG).
7 | 
8 | .. autoclass:: pyhealth.datasets.TUEVDataset
9 |     :members:
10 |     :undoc-members:
11 |     :show-inheritance:
12 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.eICUDataset.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.eICUDataset
2 | ===================================
3 | 
4 | The open eICU Collaborative Research Database, refer to `doc `_ for more information. We process this database into a well-structured dataset object and give users the **best flexibility and convenience** for modeling and analysis.
5 | 
6 | .. autoclass:: pyhealth.datasets.eICUDataset
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.splitter.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.splitter
2 | ===================================
3 | 
4 | Several data splitting functions for the `pyhealth.datasets` module to obtain training / validation / test sets.
5 | 
6 | .. automodule:: pyhealth.datasets.splitter
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/datasets/pyhealth.datasets.utils.rst:
--------------------------------------------------------------------------------
1 | pyhealth.datasets.utils
2 | ===================================
3 | 
4 | Several utility functions.
5 | 
6 | .. automodule:: pyhealth.datasets.utils
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/medcode.rst:
--------------------------------------------------------------------------------
1 | MedCode
2 | ===================================
3 | 
4 | We provide medical code mapping tools for (i) ontology mapping within one coding system and
5 | (ii) mapping the same concept across different coding systems.
6 | 
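A minimal usage sketch of both directions (the code strings are illustrative;
``InnerMap.load`` and ``CrossMap.load`` follow the interfaces documented below):

.. code-block:: python

    from pyhealth.medcode import InnerMap, CrossMap

    # (i) ontology mapping within one coding system
    icd9cm = InnerMap.load("ICD9CM")
    print(icd9cm.lookup("428.0"))         # name of the code
    print(icd9cm.get_ancestors("428.0"))  # ancestor codes in the ontology

    # (ii) mapping a concept across coding systems (ICD-9-CM -> CCS)
    codemap = CrossMap.load("ICD9CM", "CCSCM")
    print(codemap.map("428.0"))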
7 | 
8 | .. autoclass:: pyhealth.medcode.InnerMap
9 |     :members:
10 |     :undoc-members:
11 |     :show-inheritance:
12 | 
13 | .. autoclass:: pyhealth.medcode.CrossMap
14 |     :members:
15 |     :undoc-members:
16 |     :show-inheritance:
17 | 
18 | Diagnosis codes:
19 | ----------------------
20 | 
21 | .. autoclass:: pyhealth.medcode.ICD9CM
22 |     :members:
23 |     :undoc-members:
24 |     :show-inheritance:
25 | 
26 | .. autoclass:: pyhealth.medcode.ICD10CM
27 |     :members:
28 |     :undoc-members:
29 |     :show-inheritance:
30 | 
31 | .. autoclass:: pyhealth.medcode.CCSCM
32 |     :members:
33 |     :undoc-members:
34 |     :show-inheritance:
35 | 
36 | Procedure codes:
37 | ----------------------
38 | 
39 | .. autoclass:: pyhealth.medcode.ICD9PROC
40 |     :members:
41 |     :undoc-members:
42 |     :show-inheritance:
43 | 
44 | .. autoclass:: pyhealth.medcode.ICD10PROC
45 |     :members:
46 |     :undoc-members:
47 |     :show-inheritance:
48 | 
49 | .. autoclass:: pyhealth.medcode.CCSPROC
50 |     :members:
51 |     :undoc-members:
52 |     :show-inheritance:
53 | 
54 | Medication codes:
55 | -----------------------
56 | 
57 | .. autoclass:: pyhealth.medcode.NDC
58 |     :members:
59 |     :undoc-members:
60 |     :show-inheritance:
61 | 
62 | .. autoclass:: pyhealth.medcode.RxNorm
63 |     :members:
64 |     :undoc-members:
65 |     :show-inheritance:
66 | 
67 | .. autoclass:: pyhealth.medcode.ATC
68 |     :members:
69 |     :undoc-members:
70 |     :show-inheritance:
71 | 
--------------------------------------------------------------------------------
/docs/api/metrics.rst:
--------------------------------------------------------------------------------
1 | Metrics
2 | ===============
3 | 
4 | We provide easy-to-use metrics (the same style and args as `sklearn.metrics`) for binary,
5 | multiclass, and multilabel classification.
6 | For applicable tasks, we also provide metrics for model calibration and uncertainty
7 | quantification, as well as metrics that measure the quality of prediction sets.
8 | We also provide other metrics specifically for healthcare
9 | tasks, such as the drug-drug interaction (DDI) rate.
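A quick sketch of the shared calling convention (shown for the binary case; the
multiclass and multilabel functions listed below work the same way):

.. code-block:: python

    import numpy as np
    from pyhealth.metrics import binary_metrics_fn

    y_true = np.array([0, 1, 1, 0, 1])
    y_prob = np.array([0.1, 0.8, 0.6, 0.3, 0.9])

    # returns a dict, e.g. {"accuracy": ..., "roc_auc": ..., "pr_auc": ...}
    print(binary_metrics_fn(y_true, y_prob, metrics=["accuracy", "roc_auc", "pr_auc"]))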
10 | 
11 | .. toctree::
12 |     :maxdepth: 3
13 | 
14 |     metrics/pyhealth.metrics.multiclass
15 |     metrics/pyhealth.metrics.multilabel
16 |     metrics/pyhealth.metrics.binary
17 |     metrics/pyhealth.metrics.calibration
18 |     metrics/pyhealth.metrics.prediction_set
19 |     metrics/pyhealth.metrics.fairness
20 | 
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.binary.rst:
--------------------------------------------------------------------------------
1 | pyhealth.metrics.binary
2 | ===================================
3 | 
4 | .. currentmodule:: pyhealth.metrics.binary
5 | 
6 | .. autofunction:: binary_metrics_fn
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.calibration.rst:
--------------------------------------------------------------------------------
1 | [core] calibration
2 | ===================================
3 | 
4 | .. automodule:: pyhealth.metrics.calibration
5 |     :members:
6 | 
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.fairness.rst:
--------------------------------------------------------------------------------
1 | pyhealth.metrics.fairness
2 | ===================================
3 | 
4 | .. currentmodule:: pyhealth.metrics.fairness
5 | 
6 | .. autofunction:: fairness_metrics_fn
7 | 
8 | .. currentmodule:: pyhealth.metrics.fairness_utils
9 | 
10 | .. autofunction:: disparate_impact
11 | 
12 | .. autofunction:: statistical_parity_difference
13 | 
14 | .. autofunction:: sensitive_attributes_from_patient_ids
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.multiclass.rst:
--------------------------------------------------------------------------------
1 | pyhealth.metrics.multiclass
2 | ===================================
3 | 
4 | .. currentmodule:: pyhealth.metrics.multiclass
5 | 
6 | .. autofunction:: multiclass_metrics_fn
7 | 
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.multilabel.rst:
--------------------------------------------------------------------------------
1 | pyhealth.metrics.multilabel
2 | ===================================
3 | 
4 | .. currentmodule:: pyhealth.metrics.multilabel
5 | 
6 | .. autofunction:: multilabel_metrics_fn
--------------------------------------------------------------------------------
/docs/api/metrics/pyhealth.metrics.prediction_set.rst:
--------------------------------------------------------------------------------
1 | [core] prediction_set
2 | ===================================
3 | 
4 | .. automodule:: pyhealth.metrics.prediction_set
5 |     :members:
6 | 
--------------------------------------------------------------------------------
/docs/api/models.rst:
--------------------------------------------------------------------------------
1 | Models
2 | ===============
3 | 
4 | We implement the following models to support a variety of healthcare prediction tasks.
5 | 
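Most models below follow the same pattern: construct the model on a sample dataset,
then fit it with the PyHealth trainer. A hedged sketch (``sample_dataset`` and the
dataloaders are assumed to come from the dataset/task stages of the pipeline; the
argument names follow the documented model and ``Trainer`` interfaces):

.. code-block:: python

    from pyhealth.models import Transformer
    from pyhealth.trainer import Trainer

    model = Transformer(
        dataset=sample_dataset,  # a SampleDataset produced by a task function
        feature_keys=["conditions", "procedures"],
        label_key="label",
        mode="binary",
    )

    trainer = Trainer(model=model)
    trainer.train(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        epochs=5,
        monitor="roc_auc",
    )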
6 | 
7 | .. toctree::
8 |     :maxdepth: 3
9 | 
10 |     models/pyhealth.models.BaseModel
11 |     models/pyhealth.models.MLP
12 |     models/pyhealth.models.CNN
13 |     models/pyhealth.models.RNN
14 |     models/pyhealth.models.GNN
15 |     models/pyhealth.models.Transformer
16 |     models/pyhealth.models.TransformersModel
17 |     models/pyhealth.models.RETAIN
18 |     models/pyhealth.models.GAMENet
19 |     models/pyhealth.models.MICRON
20 |     models/pyhealth.models.SafeDrug
21 |     models/pyhealth.models.MoleRec
22 |     models/pyhealth.models.Deepr
23 |     models/pyhealth.models.ContraWR
24 |     models/pyhealth.models.SparcNet
25 |     models/pyhealth.models.StageNet
26 |     models/pyhealth.models.AdaCare
27 |     models/pyhealth.models.ConCare
28 |     models/pyhealth.models.Agent
29 |     models/pyhealth.models.GRASP
30 |     models/pyhealth.models.MedLink
31 |     models/pyhealth.models.TCN
32 |     models/pyhealth.models.GAN
33 |     models/pyhealth.models.VAE
34 | 
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.AdaCare.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.AdaCare
2 | ===================================
3 | 
4 | The separate callable AdaCareLayer and the complete AdaCare model.
5 | 
6 | .. autoclass:: pyhealth.models.AdaCareLayer
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
11 | .. autoclass:: pyhealth.models.AdaCare
12 |     :members:
13 |     :undoc-members:
14 |     :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.Agent.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.Agent
2 | ===================================
3 | 
4 | The separate callable AgentLayer and the complete Agent model.
5 | 
6 | .. autoclass:: pyhealth.models.AgentLayer
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
11 | .. autoclass:: pyhealth.models.Agent
12 |     :members:
13 |     :undoc-members:
14 |     :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.BaseModel.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.BaseModel
2 | ===================================
3 | 
4 | The base class for all models.
5 | 
6 | .. autoclass:: pyhealth.models.BaseModel
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.CNN.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.CNN
2 | ===================================
3 | 
4 | The separate callable CNNLayer and the complete CNN model.
5 | 
6 | .. autoclass:: pyhealth.models.CNNLayer
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
11 | .. autoclass:: pyhealth.models.CNN
12 |     :members:
13 |     :undoc-members:
14 |     :show-inheritance:
15 | 
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.ConCare.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.ConCare
2 | ===================================
3 | 
4 | The separate callable ConCareLayer and the complete ConCare model.
5 | 
6 | .. autoclass:: pyhealth.models.ConCareLayer
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
11 | .. autoclass:: pyhealth.models.ConCare
12 |     :members:
13 |     :undoc-members:
14 |     :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.ContraWR.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.ContraWR
2 | ===================================
3 | 
4 | 
5 | The separate callable ResBlock2D and the complete ContraWR model.
6 | 
7 | .. autoclass:: pyhealth.models.ResBlock2D
8 |     :members:
9 |     :undoc-members:
10 |     :show-inheritance:
11 | 
12 | .. autoclass:: pyhealth.models.ContraWR
13 |     :members:
14 |     :undoc-members:
15 |     :show-inheritance:
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.Deepr.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.Deepr
2 | ===================================
3 | 
4 | The separate callable DeeprLayer and the complete Deepr model.
5 | 
6 | .. autoclass:: pyhealth.models.DeeprLayer
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
11 | .. autoclass:: pyhealth.models.Deepr
12 |     :members:
13 |     :undoc-members:
14 |     :show-inheritance:
15 | 
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.GAMENet.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.GAMENet
2 | ===================================
3 | 
4 | The separate callable GAMENetLayer and the complete GAMENet model.
5 | 
6 | .. autoclass:: pyhealth.models.GAMENetLayer
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
11 | .. autoclass:: pyhealth.models.GAMENet
12 |     :members:
13 |     :undoc-members:
14 |     :show-inheritance:
15 | 
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.GAN.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.GAN
2 | ===================================
3 | 
4 | The GAN model (the pyhealth trainer does not apply to GAN; refer to examples/ChestXray-image-generation-GAN.ipynb for an example of using the GAN model).
5 | 
6 | .. autoclass:: pyhealth.models.GAN
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
--------------------------------------------------------------------------------
/docs/api/models/pyhealth.models.GNN.rst:
--------------------------------------------------------------------------------
1 | pyhealth.models.GNN
2 | ===================================
3 | 
4 | The GNN models (the pyhealth trainer does not apply to them; refer to the notebooks under examples/, e.g. graph_torchvision_model.ipynb, for examples of using the GNN models).
5 | 
6 | .. autoclass:: pyhealth.models.GAT
7 |     :members:
8 |     :undoc-members:
9 |     :show-inheritance:
10 | 
11 | ..
autoclass:: pyhealth.models.GCN 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: 15 | -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.GRASP.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.GRASP 2 | =================================== 3 | 4 | The separate callable GRASPLayer and the complete GRASP model. 5 | 6 | .. autoclass:: pyhealth.models.GRASPLayer 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | .. autoclass:: pyhealth.models.GRASP 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.MICRON.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.MICRON 2 | =================================== 3 | 4 | The separate callable MICRONLayer and the complete MICRON model. 5 | 6 | .. autoclass:: pyhealth.models.MICRONLayer 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | .. autoclass:: pyhealth.models.MICRON 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.MLP.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.MLP 2 | =================================== 3 | 4 | The complete MLP model. 5 | 6 | .. autoclass:: pyhealth.models.MLP 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.MedLink.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.MedLink 2 | =================================== 3 | 4 | The complete MedLink model. 5 | 6 | 7 | .. autoclass:: pyhealth.models.MedLink 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.MoleRec.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.MoleRec 2 | =================================== 3 | 4 | The separate callable MoleRecLayer and the complete MoleRec model. 5 | 6 | .. autoclass:: pyhealth.models.MoleRecLayer 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | .. autoclass:: pyhealth.models.MoleRec 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.RETAIN.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.RETAIN 2 | =================================== 3 | 4 | The separate callable RETAINLayer and the complete RETAIN model. 5 | 6 | .. autoclass:: pyhealth.models.RETAINLayer 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | .. autoclass:: pyhealth.models.RETAIN 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.RNN.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.RNN 2 | =================================== 3 | 4 | 5 | The separate callable RNNLayer and the complete RNN model. 6 | 7 | ..
autoclass:: pyhealth.models.RNNLayer 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | .. autoclass:: pyhealth.models.RNN 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.SafeDrug.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.SafeDrug 2 | =================================== 3 | 4 | 5 | The separate callable SafeDrugLayer and the complete SafeDrug model. 6 | 7 | .. autoclass:: pyhealth.models.SafeDrugLayer 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | .. autoclass:: pyhealth.models.SafeDrug 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.SparcNet.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.SparcNet 2 | =================================== 3 | 4 | 5 | The SparcNet Model: Jin Jing, et al. Development of Expert-level Classification of Seizures and Rhythmic and Periodic Patterns During EEG Interpretation. Neurology 2023. 6 | 7 | .. autoclass:: pyhealth.models.DenseLayer 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | .. autoclass:: pyhealth.models.DenseBlock 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: 16 | 17 | .. autoclass:: pyhealth.models.TransitionLayer 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | .. autoclass:: pyhealth.models.SparcNet 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.StageNet.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.StageNet 2 | =================================== 3 | 4 | The separate callable StageNetLayer and the complete StageNet model. 5 | 6 | .. autoclass:: pyhealth.models.StageNetLayer 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | .. autoclass:: pyhealth.models.StageNet 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.TCN.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.TCN 2 | =================================== 3 | 4 | The separate callable TCNLayer and the complete TCN model. 5 | 6 | .. autoclass:: pyhealth.models.TCNLayer 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | .. autoclass:: pyhealth.models.TCN 12 | :members: 13 | :undoc-members: 14 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.Transformer.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.Transformer 2 | =================================== 3 | 4 | 5 | The separate callable TransformerLayer and the complete Transformer model. 6 | 7 | .. autoclass:: pyhealth.models.TransformerLayer 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | .. 
autoclass:: pyhealth.models.Transformer 13 | :members: 14 | :undoc-members: 15 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.TransformersModel.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.TransformersModel 2 | =================================== 3 | 4 | 5 | The complete TransformersModel model. 6 | 7 | .. autoclass:: pyhealth.models.TransformersModel 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/models/pyhealth.models.VAE.rst: -------------------------------------------------------------------------------- 1 | pyhealth.models.VAE 2 | =================================== 3 | 4 | The VAE model (treated as a regression task). 5 | 6 | .. autoclass:: pyhealth.models.VAE 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | -------------------------------------------------------------------------------- /docs/api/tasks.rst: -------------------------------------------------------------------------------- 1 | Tasks 2 | =============== 3 | 4 | We support various real-world healthcare predictive tasks defined by **function calls**. The following example tasks are collected from top AI/Medical venues, such as: 5 | 6 | (i) Drug Recommendation [Yang et al. IJCAI 2021a, Yang et al. IJCAI 2021b, Shang et al. AAAI 2020] 7 | 8 | (ii) Readmission Prediction [Choi et al. AAAI 2021] 9 | 10 | (iii) Mortality Prediction [Choi et al. AAAI 2021] 11 | 12 | (iv) Length of Stay Prediction 13 | 14 | (v) Sleep Staging [Yang et al. ArXiv 2021] 15 | 16 | .. toctree:: 17 | :maxdepth: 3 18 | 19 | tasks/pyhealth.tasks.BaseTask 20 | tasks/pyhealth.tasks.Readmission30DaysMIMIC4 21 | tasks/pyhealth.tasks.InHospitalMortalityMIMIC4 22 | tasks/pyhealth.tasks.MIMIC3ICD9Coding 23 | tasks/pyhealth.tasks.cardiology_detect 24 | tasks/pyhealth.tasks.COVID19CXRClassification 25 | tasks/pyhealth.tasks.drug_recommendation 26 | tasks/pyhealth.tasks.EEG_abnormal 27 | tasks/pyhealth.tasks.EEG_events 28 | tasks/pyhealth.tasks.length_of_stay_prediction 29 | tasks/pyhealth.tasks.MedicalTranscriptionsClassification 30 | tasks/pyhealth.tasks.mortality_prediction 31 | tasks/pyhealth.tasks.patient_linkage_mimic3_fn 32 | tasks/pyhealth.tasks.readmission_prediction 33 | tasks/pyhealth.tasks.sleep_staging 34 | tasks/pyhealth.tasks.SleepStagingSleepEDF 35 | tasks/pyhealth.tasks.temple_university_EEG_tasks 36 | 37 | -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.BaseTask.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.BaseTask 2 | ======================================= 3 | 4 | 5 | .. automodule:: pyhealth.tasks.BaseTask 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.COVID19CXRClassification.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.COVID19CXRClassification 2 | ======================================= 3 | 4 | 5 | .. 
automodule:: pyhealth.tasks.COVID19CXRClassification 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.EEG_abnormal.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.EEG_abnormal 2 | ======================================= 3 | 4 | 5 | .. automodule:: pyhealth.tasks.EEG_abnormal 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.EEG_events.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.EEG_events 2 | ======================================= 3 | 4 | 5 | .. automodule:: pyhealth.tasks.EEG_events 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.InHospitalMortalityMIMIC4.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.InHospitalMortalityMIMIC4 2 | ======================================= 3 | 4 | 5 | .. automodule:: pyhealth.tasks.InHospitalMortalityMIMIC4 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.MIMIC3ICD9Coding.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.MIMIC3ICD9Coding 2 | =========================================== 3 | 4 | 5 | .. automodule:: pyhealth.tasks.MIMIC3ICD9Coding 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.MedicalTranscriptionsClassification.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.MedicalTranscriptionsClassification 2 | =========================================== 3 | 4 | 5 | .. automodule:: pyhealth.tasks.MedicalTranscriptionsClassification 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.Readmission30DaysMIMIC4.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.Readmission30DaysMIMIC4 2 | ======================================= 3 | 4 | 5 | .. automodule:: pyhealth.tasks.Readmission30DaysMIMIC4 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.SleepStagingSleepEDF.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.SleepStagingSleepEDF 2 | ======================================= 3 | 4 | 5 | .. automodule:: pyhealth.tasks.SleepStagingSleepEDF 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.cardiology_detect.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.cardiology_detect 2 | ======================================= 3 | 4 | 5 | ..
automodule:: pyhealth.tasks.cardiology_detect 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.drug_recommendation.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.drug_recommendation 2 | =================================== 3 | 4 | 5 | .. automodule:: pyhealth.tasks.drug_recommendation 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.length_of_stay_prediction.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.length_of_stay_prediction 2 | =========================================== 3 | 4 | 5 | .. automodule:: pyhealth.tasks.length_of_stay_prediction 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.mortality_prediction.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.mortality_prediction 2 | ======================================= 3 | 4 | 5 | .. automodule:: pyhealth.tasks.mortality_prediction 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.patient_linkage_mimic3_fn.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.patient_linkage_mimic3_fn 2 | =========================================== 3 | 4 | 5 | .. automodule:: pyhealth.tasks.patient_linkage_mimic3_fn 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.readmission_prediction.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.readmission_prediction 2 | ======================================= 3 | 4 | 5 | .. automodule:: pyhealth.tasks.readmission_prediction 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.sleep_staging.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.sleep_staging 2 | ======================================= 3 | 4 | 5 | .. automodule:: pyhealth.tasks.sleep_staging 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tasks/pyhealth.tasks.temple_university_EEG_tasks.rst: -------------------------------------------------------------------------------- 1 | pyhealth.tasks.temple_university_EEG_tasks 2 | =========================================== 3 | 4 | 5 | .. automodule:: pyhealth.tasks.temple_university_EEG_tasks 6 | :members: 7 | :undoc-members: 8 | :show-inheritance: -------------------------------------------------------------------------------- /docs/api/tokenizer.rst: -------------------------------------------------------------------------------- 1 | Tokenizer 2 | =============== 3 | 4 | The tokenizer functionality supports token-to-index and index-to-token mapping in a general ML setting. 5 | 6 | ..
automodule:: pyhealth.tokenizer 7 | :members: 8 | :undoc-members: 9 | :show-inheritance: 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /docs/api/trainer.rst: -------------------------------------------------------------------------------- 1 | Trainer 2 | =================================== 3 | 4 | .. autoclass:: pyhealth.trainer.Trainer 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: -------------------------------------------------------------------------------- /docs/developer.rst: -------------------------------------------------------------------------------- 1 | .. _how_to_contribute: 2 | 3 | ===================== 4 | How to Contribute 5 | ===================== 6 | 7 | Thank you for your interest in contributing to PyHealth! We welcome contributions of all types, including bug fixes, new features, documentation improvements, and more. 8 | 9 | Contribution Guidelines 10 | ======================= 11 | 12 | - For **bug fixes** and **documentation improvements**, you can get started on your own. 13 | - For **major framework updates, new datasets, or new models**, it's best to **open an issue and discuss with the PyHealth team first** before proceeding. 14 | 15 | Getting Started 16 | =============== 17 | 18 | 1. **Fork the Repository**: Navigate to the PyHealth GitHub repository and click the **Fork** button. 19 | 2. **Clone Your Fork**: 20 | .. code-block:: bash 21 | 22 | git clone https://github.com/your-username/PyHealth.git 23 | cd PyHealth 24 | 3. **Create a New Branch**: 25 | .. code-block:: bash 26 | 27 | git checkout -b feature-branch 28 | 29 | Setting Up the Development Environment 30 | ====================================== 31 | 32 | To avoid interfering with an existing PyHealth installation via `pip`, it is recommended to create a new Python environment specifically for development. 33 | You can use `conda` or `virtualenv` to do so. 34 | 35 | Making Changes 36 | ============== 37 | 38 | 1. **Write Code**: Implement your changes in a clean and modular way. 39 | 2. **Follow Code Style**: Ensure your code follows `black` formatting: 40 | .. code-block:: bash 41 | 42 | black . 43 | 3. **Write Tests**: Add unit tests for your changes in `unittests/`. 44 | 4. **Run Tests**: Before submitting, make sure all tests pass: 45 | .. code-block:: bash 46 | 47 | pytest 48 | 49 | Submitting Your Changes 50 | ======================= 51 | 52 | 1. **Commit Changes**: 53 | .. code-block:: bash 54 | 55 | git add . 56 | git commit -m "Description of your changes" 57 | 2. **Push to Your Fork**: 58 | .. code-block:: bash 59 | 60 | git push origin feature-branch 61 | 3. **Open a Pull Request (PR)**: 62 | - Go to the original PyHealth repository. 63 | - Click **New Pull Request**. 64 | - Select your fork and branch. 65 | - Provide a clear description of your changes. 66 | - Submit the PR. 67 | 68 | Review Process 69 | ============== 70 | 71 | - The PyHealth maintainers will review your PR and provide feedback. 72 | - Make any requested changes and push updates to your PR. 73 | - Once approved, your changes will be merged into the main branch! 74 | 75 | Need Help? 76 | ========== 77 | 78 | - Check the `Issues` tab on GitHub. 79 | - Reach out via `email` if applicable. 80 | 81 | We appreciate your contributions and look forward to working with you! 
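For a concrete starting point on the testing step above, a minimal unit test could look like the sketch below (the file name and assertions are illustrative only, not an existing repository test; it exercises the ``InnerMap`` lookup API demonstrated in ``examples/medcode.py``):

.. code-block:: python

    # unittests/test_medcode_lookup.py -- illustrative sketch, not a repository file
    from pyhealth.medcode import InnerMap

    def test_ndc_lookup_returns_description():
        # Looking up a known NDC code should return a non-empty description.
        ndc = InnerMap.load("NDC")
        description = ndc.lookup("00597005801")
        assert isinstance(description, str)
        assert len(description) > 0

Running ``pytest unittests/`` should then pick up the test automatically.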
-------------------------------------------------------------------------------- /docs/faq.rst: -------------------------------------------------------------------------------- 1 | Frequently Asked Questions 2 | ========================== 3 | 4 | ---- 5 | 6 | 7 | Blueprint & Development Plan 8 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 9 | 10 | The long-term goal of PyHealth is to become a comprehensive healthcare AI toolkit that supports 11 | not only EHR data but also images and clinical notes. 12 | 13 | This is the central place to track important things to be fixed/added: 14 | 15 | - The support of image datasets and clinical notes 16 | - The compatibility and the support of OMOP format datasets 17 | - Model persistence (save, load, and portability) 18 | - The release of a benchmark paper with PyHealth 19 | - Add contact channel with `Gitter `_ 20 | - Support additional languages, see `Manage Translations `_ 21 | 22 | Feel free to open an issue report if needed. 23 | See `Issues `_. 24 | 25 | 26 | Inclusion Criteria 27 | ^^^^^^^^^^^^^^^^^^ 28 | 29 | Similar to scikit-learn, we mainly consider well-established algorithms for inclusion. 30 | A rule of thumb is at least two years since publication, 50+ citations, and demonstrated usefulness. 31 | 32 | However, we encourage the author(s) of newly proposed models to share and add their implementation into PyHealth 33 | to boost ML accessibility and reproducibility. 34 | This exception only applies if you can commit to maintaining your model for at least a two-year period. 35 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | You can install from PyPI: 5 | 6 | .. code-block:: bash 7 | 8 | pip install pyhealth 9 | 10 | 11 | or from the GitHub source: 12 | 13 | .. code-block:: bash 14 | 15 | git clone https://github.com/sunlabuiuc/PyHealth.git 16 | cd PyHealth 17 | pip install . 18 | 19 | 20 | **Required Dependencies**\ : 21 | 22 | .. code-block:: bash 23 | 24 | python>=3.8 25 | torch>=1.8.0 26 | rdkit>=2022.03.4 27 | scikit-learn>=0.24.2 28 | networkx>=2.6.3 29 | pandas>=1.3.2 30 | tqdm 31 | 32 | **Warning 1**\ : 33 | 34 | PyHealth has multiple neural-network-based models, e.g., LSTM, which are 35 | implemented in PyTorch. However, PyHealth does **NOT** install these DL libraries for you. 36 | This reduces the risk of interfering with your local copies. 37 | If you want to use neural-network-based models, please make sure PyTorch is installed. 38 | Similarly, models depending on **xgboost** do **NOT** enforce an xgboost installation by default. 39 | 40 | 41 | **CUDA Setting**\ : 42 | 43 | To run PyHealth on GPU, you also need CUDA and a cudatoolkit version that supports your GPU. `More info `_ 44 | 45 | For example, if you use an NVIDIA RTX A6000 as your GPU for training, you should install a compatible cudatoolkit using: 46 | 47 | .. code-block:: bash 48 | 49 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch 50 | 51 | ---- -------------------------------------------------------------------------------- /docs/live.rst: -------------------------------------------------------------------------------- 1 | PyHealth Webinar 4/9/25 2 | ======================= 3 | 4 | **Start Time**: 2 PM Central Time, Wednesday 5 | 6 | **Recurrence**: There are no repeat sessions for now.
7 | 8 | **Zoom**: `Join Link `_ 9 | 10 | **YouTube**: `Recorded Live Sessions `_ 11 | 12 | **User/Developer Discord**: `Click to join `_ 13 | 14 | Past Scheduled Videos 15 | ^^^^^^^^^^^^^^^^^^^^^ 16 | **(Dec 21, 2022)** Live 01 - What is PyHealth and How to Get Started? `[Recap] `_ 17 | 18 | **(Dec 28, 2022)** Live 02 - Data & Datasets & Tasks: store unstructured data in a structured way. `[Recap I] `_ `[II] `_ `[III] `_ `[IV] `_ 19 | 20 | **(Jan 4, 2023)** Live 03 - Models & Trainer & Metrics: initialize and train a deep learning model. `[Recap I] `_ `[II] `_ `[III] `_ 21 | 22 | **(Jan 11, 2023)** Live 04 - Tokenizer & Medcode: master the medical code lookup and mapping. `[Recap I] `_ `[II] `_ 23 | 24 | **(Jan 18, 2023)** Live 05 - PyHealth can support a complete healthcare ML pipeline `[Recap I] `_ `[II] `_ 25 | 26 | **(Jan 25, 2023)** Live 06 - Fit your own dataset into the pipeline and use our model `[Recap] `_ 27 | 28 | **(Feb 1, 2023)** Live 07 - Adopt your customized model and quickly try it on our data `[Recap] `_ 29 | 30 | **(Feb 8, 2023)** Live 08 - New feature: support for biosignal data (EEG, ECG, etc.) classification `[Recap I] `_ `[II] `_ 31 | 32 | **(Feb 15, 2023)** Live 09 - New feature: parallel and faster data loading architecture 33 | 34 | **(Feb 22, 2023)** Live 10 - Add a COVID prediction benchmark (new datasets, new models) -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx==5.2.3 2 | sphinx-automodapi 3 | sphinx-autodoc-annotation 4 | sphinx_last_updated_by_git 5 | sphinxcontrib-spelling 6 | sphinx-autodoc-typehints 7 | sphinx_gallery==0.11.1 8 | sphinxcontrib-bibtex==2.5.0 9 | nbsphinx==0.8.9 10 | nbsphinx-link==1.3.0 11 | sphinx-copybutton==0.5.0 12 | furo 13 | myst-parser==0.18.1 14 | sphinx-remove-toctrees==0.0.3 15 | sphinx-design==0.3.0 16 | sphinxext-opengraph==0.6.3 17 | sphinxcontrib-httpdomain==1.8.0 18 | sphinx_copybutton 19 | sphinx_toggleprompt 20 | pandas-bokeh==0.5.5 21 | bokeh==3.0.1 22 | gspread==5.6.2 23 | google-cloud-storage==2.6.0 24 | oauth2client==4.1.3 25 | Jinja2==3.1.4 26 | flask==2.2.5 -------------------------------------------------------------------------------- /docs/use_cases.rst: -------------------------------------------------------------------------------- 1 | ..
_use_cases: 2 | 3 | ===================== 4 | Use Cases of PyHealth 5 | ===================== 6 | 7 | PyHealth enables various healthcare machine learning applications. Below are some practical use cases, categorized by data modality, each linked to an interactive Google Colab notebook. 8 | 9 | Structured Data 10 | =============== 11 | 12 | Predicting Hospital Readmission 13 | ------------------------------- 14 | Hospital readmission prediction helps identify patients at high risk of returning to the hospital shortly after discharge. This can assist healthcare providers in taking preventive measures. 15 | 16 | - `Colab Notebook <https://colab.research.google.com/drive/1bhCwbXce1YFtVaQLsOt4FcyZJ1_my7Cs?usp=sharing>`_ 17 | 18 | Drug Recommendation System 19 | -------------------------- 20 | Personalized drug recommendation models can suggest appropriate medications based on a patient’s medical history, improving treatment outcomes. 21 | 22 | - `Colab Notebook <https://colab.research.google.com/drive/10CSb4F4llYJvv42yTUiRmvSZdoEsbmFF?usp=sharing>`_ 23 | 24 | Length of Stay Prediction 25 | -------------------------- 26 | Predicting hospital length of stay aids resource allocation, bed management, and patient care planning in hospitals. 27 | 28 | - `Colab Notebook <https://colab.research.google.com/drive/1JoPpXqqB1_lGF1XscBOsDHMLtgvlOYI1?usp=sharing>`_ 29 | 30 | Mortality Prediction from ICU Data 31 | ----------------------------------- 32 | Predicting ICU patient mortality using clinical data can help prioritize critical care and optimize resource usage. 33 | 34 | - `Colab Notebook <https://colab.research.google.com/drive/1Qblpcv4NWjrnADT66TjBcNwOe8x6wU4c?usp=sharing>`_ 35 | 36 | Time-Series Data 37 | ================ 38 | 39 | Sleep Staging 40 | ------------- 41 | Sleep staging classification uses EEG data to determine different sleep stages, aiding in the diagnosis and treatment of sleep disorders. 42 | 43 | - `Colab Notebook <https://colab.research.google.com/drive/1mpSeNCAthXG3cqROkdUcUdozIPIMTCuo?usp=sharing>`_ 44 | 45 | Imaging Data 46 | ============ 47 | 48 | X-ray Classification 49 | --------------------- 50 | X-ray classification models can assist radiologists by automatically detecting abnormalities in chest X-rays and other radiographic images. 51 | 52 | - `Colab Notebook <https://drive.google.com/file/d/1XokhV8dN3lis7gMdjpMhZBEGs03sus9R/view?usp=share_link>`_ 53 | 54 | Text Data 55 | ========= 56 | 57 | Medical Transcription Classification 58 | ------------------------------------- 59 | Classifying medical transcriptions enables automated processing of clinical notes, improving documentation efficiency and accessibility. 60 | 61 | - `Colab Notebook <https://drive.google.com/file/d/1JxQYEj94WjEsRifAOyEfrFWWIjqXxrqH/view?usp=share_link>`_ 62 | 63 | Each notebook provides step-by-step guidance on data processing, model training, and evaluation using PyHealth.
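All of these notebooks follow PyHealth's five-stage pipeline: load a dataset, define a task, split and wrap into dataloaders, train a model, and evaluate. A condensed sketch of that shared pattern, adapted from the repository's ``examples/mortality_mimic3_rnn.py`` (the dataset root below is a placeholder for your local MIMIC-III copy, and the epoch count is illustrative):

.. code-block:: python

    from pyhealth.datasets import MIMIC3Dataset, get_dataloader, split_by_patient
    from pyhealth.models import RNN
    from pyhealth.tasks import mortality_prediction_mimic3_fn
    from pyhealth.trainer import Trainer

    # 1) load data (replace root with your local MIMIC-III path)
    base_dataset = MIMIC3Dataset(
        root="/path/to/mimiciii/1.4",
        tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"],
    )
    # 2) define the prediction task
    sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn)
    # 3) split by patient and build dataloaders
    train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
    train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
    val_loader = get_dataloader(val_ds, batch_size=32, shuffle=False)
    test_loader = get_dataloader(test_ds, batch_size=32, shuffle=False)
    # 4) initialize and train a model
    model = RNN(
        dataset=sample_dataset,
        feature_keys=["conditions", "procedures", "drugs"],
        label_key="label",
        mode="binary",
    )
    trainer = Trainer(model=model)
    trainer.train(train_dataloader=train_loader, val_dataloader=val_loader, epochs=10)
    # 5) evaluate on the held-out split
    print(trainer.evaluate(test_loader))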
64 | -------------------------------------------------------------------------------- /examples/EEG_events_SparcNet.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import split_by_visit, get_dataloader 2 | from pyhealth.trainer import Trainer 3 | from pyhealth.datasets import TUEVDataset 4 | from pyhealth.tasks import EEG_events_fn 5 | from pyhealth.models import SparcNet 6 | 7 | # step 1: load signal data 8 | dataset = TUEVDataset(root="/srv/local/data/TUH/tuh_eeg_events/v2.0.0/edf/", 9 | dev=True, 10 | refresh_cache=True, 11 | ) 12 | 13 | # step 2: set task 14 | TUEV_ds = dataset.set_task(EEG_events_fn) 15 | TUEV_ds.stat() 16 | 17 | # split dataset 18 | train_dataset, val_dataset, test_dataset = split_by_visit( 19 | TUEV_ds, [0.6, 0.2, 0.2] 20 | ) 21 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 22 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 23 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 24 | print( 25 | "loader size: train/val/test", 26 | len(train_dataset), 27 | len(val_dataset), 28 | len(test_dataset), 29 | ) 30 | 31 | # STEP 3: define model 32 | model = SparcNet( 33 | dataset=TUEV_ds, 34 | feature_keys=["signal"], 35 | label_key="label", 36 | mode="multiclass", 37 | ) 38 | 39 | # STEP 4: define trainer 40 | trainer = Trainer(model=model, device="cuda:4") 41 | trainer.train( 42 | train_dataloader=train_dataloader, 43 | val_dataloader=val_dataloader, 44 | epochs=10, 45 | optimizer_params={"lr": 1e-3}, 46 | ) 47 | 48 | # STEP 5: evaluate 49 | print(trainer.evaluate(test_dataloader)) 50 | -------------------------------------------------------------------------------- /examples/EEG_isAbnormal_SparcNet.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import split_by_visit, get_dataloader 2 | from pyhealth.trainer import Trainer 3 | from pyhealth.datasets import TUABDataset 4 | from pyhealth.tasks import EEG_isAbnormal_fn 5 | from pyhealth.models import SparcNet 6 | 7 | # step 1: load signal data 8 | dataset = TUABDataset(root="/srv/local/data/TUH/tuh_eeg_abnormal/v3.0.0/edf/", 9 | dev=True, 10 | refresh_cache=True, 11 | ) 12 | 13 | # step 2: set task 14 | TUAB_ds = dataset.set_task(EEG_isAbnormal_fn) 15 | TUAB_ds.stat() 16 | 17 | # split dataset 18 | train_dataset, val_dataset, test_dataset = split_by_visit( 19 | TUAB_ds, [0.6, 0.2, 0.2] 20 | ) 21 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 22 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 23 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 24 | print( 25 | "loader size: train/val/test", 26 | len(train_dataset), 27 | len(val_dataset), 28 | len(test_dataset), 29 | ) 30 | 31 | # STEP 3: define model 32 | model = SparcNet( 33 | dataset=TUAB_ds, 34 | feature_keys=["signal"], 35 | label_key="label", 36 | mode="binary", 37 | ) 38 | 39 | # STEP 4: define trainer 40 | trainer = Trainer(model=model, device="cuda:4") 41 | trainer.train( 42 | train_dataloader=train_dataloader, 43 | val_dataloader=val_dataloader, 44 | epochs=10, 45 | monitor="pr_auc", 46 | optimizer_params={"lr": 1e-3}, 47 | ) 48 | 49 | # STEP 5: evaluate 50 | print(trainer.evaluate(test_dataloader)) 51 | -------------------------------------------------------------------------------- /examples/cardiology_detection_isAR_SparcNet.py: 
-------------------------------------------------------------------------------- 1 | from pyhealth.datasets import split_by_visit, get_dataloader 2 | from pyhealth.trainer import Trainer 3 | from pyhealth.datasets import CardiologyDataset 4 | from pyhealth.tasks import cardiology_isAR_fn 5 | from pyhealth.models import ContraWR, SparcNet 6 | 7 | # step 1: load signal data 8 | dataset = CardiologyDataset(root="/srv/local/data/physionet.org/files/challenge-2020/1.0.2/training", 9 | chosen_dataset=[1,1,1,1,1,1], 10 | refresh_cache=False, 11 | dev=True) 12 | 13 | # step 2: set task 14 | cardiology_ds = dataset.set_task(cardiology_isAR_fn) 15 | cardiology_ds.stat() 16 | 17 | # split dataset 18 | train_dataset, val_dataset, test_dataset = split_by_visit( 19 | cardiology_ds, [0.6, 0.2, 0.2] 20 | ) 21 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 22 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 23 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 24 | print( 25 | "loader size: train/val/test", 26 | len(train_dataset), 27 | len(val_dataset), 28 | len(test_dataset), 29 | ) 30 | 31 | # STEP 3: define model 32 | model = SparcNet( 33 | dataset=cardiology_ds, 34 | feature_keys=["signal"], 35 | label_key="label", 36 | mode="binary", 37 | ) 38 | 39 | # STEP 4: define trainer 40 | trainer = Trainer(model=model, device="cuda:4") 41 | trainer.train( 42 | train_dataloader=train_dataloader, 43 | val_dataloader=val_dataloader, 44 | epochs=5, 45 | monitor="pr_auc", 46 | ) 47 | 48 | # STEP 5: evaluate 49 | print(trainer.evaluate(test_dataloader)) 50 | -------------------------------------------------------------------------------- /examples/chextXray_image_generation_VAE.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import split_by_visit, get_dataloader 2 | from pyhealth.trainer import Trainer 3 | from pyhealth.datasets import COVID19CXRDataset 4 | from pyhealth.models import VAE 5 | from torchvision import transforms 6 | 7 | import torch 8 | import numpy as np 9 | 10 | # step 1: load image data 11 | root = "/srv/local/data/COVID-19_Radiography_Dataset" 12 | base_dataset = COVID19CXRDataset(root) 13 | 14 | # step 2: set task 15 | sample_dataset = base_dataset.set_task() 16 | 17 | # the transformation automatically normalizes the pixel intensity into [0, 1] 18 | transform = transforms.Compose([ 19 | transforms.Lambda(lambda x: x if x.shape[0] == 3 else x.repeat(3, 1, 1)), # repeat single-channel images to 3 channels 20 | transforms.Resize((128, 128)), 21 | ]) 22 | 23 | def encode(sample): 24 | sample["path"] = transform(sample["path"]) 25 | return sample 26 | 27 | sample_dataset.set_transform(encode) 28 | 29 | 30 | # split dataset 31 | train_dataset, val_dataset, test_dataset = split_by_visit( 32 | sample_dataset, [0.6, 0.2, 0.2] 33 | ) 34 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True) 35 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False) 36 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False) 37 | 38 | data = next(iter(train_dataloader)) 39 | print (data) 40 | 41 | print (data["path"][0].shape) 42 | 43 | print( 44 | "loader size: train/val/test", 45 | len(train_dataset), 46 | len(val_dataset), 47 | len(test_dataset), 48 | ) 49 | 50 | # STEP 3: define model 51 | model = VAE( 52 | dataset=sample_dataset, 53 | input_channel=3, 54 | input_size=128, 55 | feature_keys=["path"], 56 |
label_key="path", 57 | mode="regression", 58 | hidden_dim = 128, 59 | ) 60 | 61 | # STEP 4: define trainer 62 | trainer = Trainer(model=model, device="cuda:4", metrics=["kl_divergence", "mse", "mae"]) 63 | trainer.train( 64 | train_dataloader=train_dataloader, 65 | val_dataloader=val_dataloader, 66 | epochs=10, 67 | monitor="kl_divergence", 68 | monitor_criterion="min", 69 | optimizer_params={"lr": 1e-3}, 70 | ) 71 | 72 | # # STEP 5: evaluate 73 | # print(trainer.evaluate(test_dataloader)) 74 | 75 | 76 | import matplotlib.pyplot as plt 77 | 78 | # EXP 1: check the real chestxray image and the reconstructed image 79 | X, X_rec, _ = trainer.inference(test_dataloader) 80 | 81 | plt.figure() 82 | plt.subplot(1, 2, 1) 83 | plt.imshow(X[0].reshape(128, 128), cmap="gray") 84 | plt.subplot(1, 2, 2) 85 | plt.imshow(X_rec[0].reshape(128, 128), cmap="gray") 86 | plt.savefig("chestxray_vae_comparison.png") 87 | 88 | # EXP 2: random images 89 | model = trainer.model 90 | 91 | model.eval() 92 | with torch.no_grad(): 93 | x = np.random.normal(0, 1, 128) 94 | x = x.astype(np.float32) 95 | x = torch.from_numpy(x).to(trainer.device) 96 | rec = model.decoder(x).detach().cpu().numpy() 97 | rec = rec.reshape((128, 128)) 98 | plt.figure() 99 | plt.imshow(rec, cmap="gray") 100 | plt.savefig("chestxray_vae_synthetic.png") -------------------------------------------------------------------------------- /examples/drug_recommendation_mimic3_gamenet.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import GAMENet 4 | from pyhealth.tasks import drug_recommendation_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | # STEP 1: load data 8 | base_dataset = MIMIC3Dataset( 9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})}, 12 | dev=True, 13 | refresh_cache=False, 14 | ) 15 | base_dataset.stat() 16 | 17 | # STEP 2: set task 18 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn) 19 | sample_dataset.stat() 20 | 21 | train_dataset, val_dataset, test_dataset = split_by_patient( 22 | sample_dataset, [0.8, 0.1, 0.1] 23 | ) 24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 27 | 28 | # STEP 3: define model 29 | model = GAMENet( 30 | sample_dataset, 31 | ) 32 | 33 | # STEP 4: define trainer 34 | trainer = Trainer( 35 | model=model, 36 | metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"], 37 | ) 38 | 39 | trainer.train( 40 | train_dataloader=train_dataloader, 41 | val_dataloader=val_dataloader, 42 | epochs=20, 43 | monitor="pr_auc_samples", 44 | ) 45 | 46 | # STEP 5: evaluate 47 | print (trainer.evaluate(test_dataloader)) 48 | -------------------------------------------------------------------------------- /examples/drug_recommendation_mimic3_micron.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import MICRON 4 | from pyhealth.tasks import drug_recommendation_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 
| 7 | # STEP 1: load data 8 | base_dataset = MIMIC3Dataset( 9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})}, 12 | dev=True, 13 | refresh_cache=False, 14 | ) 15 | base_dataset.stat() 16 | 17 | # STEP 2: set task 18 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn) 19 | sample_dataset.stat() 20 | 21 | train_dataset, val_dataset, test_dataset = split_by_patient( 22 | sample_dataset, [0.8, 0.1, 0.1] 23 | ) 24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 27 | 28 | # STEP 3: define model 29 | model = MICRON( 30 | sample_dataset, 31 | ) 32 | 33 | # STEP 4: define trainer 34 | trainer = Trainer(model=model, metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"]) 35 | trainer.train( 36 | train_dataloader=train_dataloader, 37 | val_dataloader=val_dataloader, 38 | epochs=5, 39 | monitor="pr_auc_samples", 40 | ) 41 | 42 | # STEP 5: evaluate 43 | print (trainer.evaluate(test_dataloader)) 44 | -------------------------------------------------------------------------------- /examples/drug_recommendation_mimic3_molerec.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import MoleRec 4 | from pyhealth.tasks import drug_recommendation_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | # STEP 1: load data 8 | base_dataset = MIMIC3Dataset( 9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})}, 12 | dev=True, 13 | refresh_cache=False, 14 | ) 15 | base_dataset.stat() 16 | 17 | # STEP 2: set task 18 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn) 19 | sample_dataset.stat() 20 | 21 | train_dataset, val_dataset, test_dataset = split_by_patient( 22 | sample_dataset, [0.8, 0.1, 0.1] 23 | ) 24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 27 | 28 | # STEP 3: define model 29 | model = MoleRec( 30 | sample_dataset, 31 | feature_keys=["conditions", "procedures"], 32 | label_key="drugs", 33 | mode="multilabel", 34 | ) 35 | 36 | # STEP 4: define trainer 37 | trainer = Trainer( 38 | model=model, 39 | metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"], 40 | ) 41 | 42 | trainer.train( 43 | train_dataloader=train_dataloader, 44 | val_dataloader=val_dataloader, 45 | epochs=3, 46 | monitor="pr_auc_samples", 47 | ) 48 | 49 | # STEP 5: evaluate 50 | print (trainer.evaluate(test_dataloader)) 51 | -------------------------------------------------------------------------------- /examples/drug_recommendation_mimic3_safedrug.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import SafeDrug 4 | from pyhealth.tasks import 
drug_recommendation_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | # STEP 1: load data 8 | base_dataset = MIMIC3Dataset( 9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})}, 12 | dev=True, 13 | refresh_cache=False, 14 | ) 15 | base_dataset.stat() 16 | 17 | # STEP 2: set task 18 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn) 19 | sample_dataset.stat() 20 | 21 | train_dataset, val_dataset, test_dataset = split_by_patient( 22 | sample_dataset, [0.8, 0.1, 0.1] 23 | ) 24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 27 | 28 | # STEP 3: define model 29 | model = SafeDrug( 30 | sample_dataset, 31 | ) 32 | 33 | # STEP 4: define trainer 34 | trainer = Trainer( 35 | model=model, 36 | metrics=["jaccard_samples", "f1_samples", "pr_auc_samples", "ddi"], 37 | ) 38 | 39 | trainer.train( 40 | train_dataloader=train_dataloader, 41 | val_dataloader=val_dataloader, 42 | epochs=25, 43 | monitor="pr_auc_samples", 44 | ) 45 | 46 | # STEP 5: evaluate 47 | print (trainer.evaluate(test_dataloader)) 48 | -------------------------------------------------------------------------------- /examples/drug_recommendation_mimic3_transformer.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import Transformer 4 | from pyhealth.tasks import drug_recommendation_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | # STEP 1: load data 8 | base_dataset = MIMIC3Dataset( 9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 11 | code_mapping={"NDC": ("ATC", {"target_kwargs": {"level": 3}})}, 12 | dev=True, 13 | refresh_cache=True, 14 | ) 15 | 16 | base_dataset.stat() 17 | 18 | # STEP 2: set task 19 | sample_dataset = base_dataset.set_task(drug_recommendation_mimic3_fn) 20 | sample_dataset.stat() 21 | 22 | train_dataset, val_dataset, test_dataset = split_by_patient( 23 | sample_dataset, [0.8, 0.1, 0.1] 24 | ) 25 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 26 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 27 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 28 | 29 | # STEP 3: define model 30 | model = Transformer( 31 | dataset=sample_dataset, 32 | feature_keys=["conditions", "procedures"], 33 | label_key="drugs", 34 | mode="multilabel", 35 | ) 36 | 37 | # STEP 4: define trainer 38 | trainer = Trainer( 39 | model=model, 40 | metrics=["jaccard_samples", "f1_samples", "pr_auc_samples"], 41 | ) 42 | 43 | trainer.train( 44 | train_dataloader=train_dataloader, 45 | val_dataloader=val_dataloader, 46 | epochs=20, 47 | monitor="pr_auc_samples", 48 | ) 49 | 50 | # STEP 5: evaluate 51 | print (trainer.evaluate(test_dataloader)) 52 | -------------------------------------------------------------------------------- /examples/length_of_stay_mimic3_rnn.py: -------------------------------------------------------------------------------- 1 | from pyhealth.calib import calibration, predictionset 2 | from pyhealth.datasets import 
MIMIC3Dataset, get_dataloader, split_by_patient 3 | from pyhealth.models import Transformer 4 | from pyhealth.tasks import length_of_stay_prediction_mimic3_fn 5 | from pyhealth.trainer import Trainer, get_metrics_fn 6 | 7 | # STEP 1: load data 8 | base_dataset = MIMIC3Dataset( 9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 11 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, 12 | dev=False, 13 | refresh_cache=True, 14 | ) 15 | base_dataset.stat() 16 | 17 | # STEP 2: set task 18 | sample_dataset = base_dataset.set_task(length_of_stay_prediction_mimic3_fn) 19 | sample_dataset.stat() 20 | 21 | train_dataset, val_dataset, test_dataset = split_by_patient( 22 | sample_dataset, [0.8, 0.1, 0.1] 23 | ) 24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 27 | 28 | # STEP 3: define model 29 | model = Transformer( 30 | dataset=sample_dataset, 31 | feature_keys=["conditions", "procedures", "drugs"], 32 | label_key="label", 33 | mode="multiclass", 34 | ) 35 | 36 | # STEP 4: define trainer 37 | trainer = Trainer(model=model) 38 | trainer.train( 39 | train_dataloader=train_dataloader, 40 | val_dataloader=val_dataloader, 41 | epochs=50, 42 | monitor="accuracy", 43 | ) 44 | 45 | # STEP 5: evaluate 46 | metrics = ['accuracy', 'f1_macro', 'f1_micro'] + ['ECE_adapt', 'cwECEt_adapt'] 47 | y_true_all, y_prob_all = trainer.inference(test_dataloader)[:2] 48 | print(get_metrics_fn(model.mode)(y_true_all, y_prob_all, metrics=metrics)) 49 | 50 | # STEP 6: calibrate the model 51 | cal_model = calibration.HistogramBinning(model, debug=True) 52 | cal_model.calibrate(cal_dataset=val_dataset) 53 | y_true_all, y_prob_all = Trainer(model=cal_model).inference(test_dataloader)[:2] 54 | print(get_metrics_fn(cal_model.mode)(y_true_all, y_prob_all, metrics=metrics)) 55 | 56 | 57 | # STEP 7: Construct prediction set, controlling the overall miscoverage rate (<0.1) 58 | # Note that if you use the calibrated model, the coverage rate cannot be controlled, because 59 | # with respect to the calibrated model (which was trained on the calibration set), the 60 | # test set and calibration set are not i.i.d. 61 | ps_model = predictionset.LABEL(model, 0.1, debug=True) 62 | ps_model.calibrate(cal_dataset=val_dataset) 63 | y_true_all, y_prob_all, _, extra_output = Trainer(model=ps_model).inference(test_dataloader, additional_outputs=['y_predset']) 64 | print(get_metrics_fn(ps_model.mode)(y_true_all, y_prob_all, 65 | metrics=metrics + ['miscoverage_overall_ps', 'rejection_rate'], 66 | y_predset=extra_output['y_predset'])) 67 | -------------------------------------------------------------------------------- /examples/medcode.py: -------------------------------------------------------------------------------- 1 | from pyhealth.medcode import CrossMap, InnerMap 2 | 3 | ndc = InnerMap.load("NDC") 4 | print("Looking up NDC code 00597005801") 5 | print(ndc.lookup("00597005801")) 6 | 7 | codemap = CrossMap.load("NDC", "ATC") 8 | print("Mapping NDC code 00597005801 to ATC") 9 | print(codemap.map("00597005801")) 10 | 11 | atc = InnerMap.load("ATC") 12 | print("Looking up ATC code G04CA02") 13 | print(atc.lookup("G04CA02")) 14 | -------------------------------------------------------------------------------- /examples/mortality_mimic3_adacare.py:
-------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import AdaCare 4 | from pyhealth.tasks import mortality_prediction_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | if __name__ == "__main__": 8 | # STEP 1: load data 9 | base_dataset = MIMIC3Dataset( 10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, 13 | dev=False, 14 | refresh_cache=False, 15 | ) 16 | base_dataset.stat() 17 | 18 | # STEP 2: set task 19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn) 20 | sample_dataset.stat() 21 | 22 | train_dataset, val_dataset, test_dataset = split_by_patient( 23 | sample_dataset, [0.8, 0.1, 0.1] 24 | ) 25 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 26 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 27 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 28 | 29 | # STEP 3: define model 30 | model = AdaCare( 31 | dataset=sample_dataset, 32 | feature_keys=["conditions", "procedures"], 33 | label_key="label", 34 | mode="binary", 35 | use_embedding=[True, True, True], 36 | embedding_dim=32, 37 | hidden_dim=32, 38 | dropout=0.5, 39 | ) 40 | 41 | # STEP 4: define trainer 42 | trainer = Trainer(model=model) 43 | trainer.train( 44 | train_dataloader=train_dataloader, 45 | val_dataloader=val_dataloader, 46 | epochs=50, 47 | monitor="roc_auc", 48 | ) 49 | 50 | # STEP 5: evaluate 51 | print(trainer.evaluate(test_dataloader)) 52 | -------------------------------------------------------------------------------- /examples/mortality_mimic3_agent.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import Agent 4 | from pyhealth.tasks import mortality_prediction_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | if __name__ == "__main__": 8 | # STEP 1: load data 9 | base_dataset = MIMIC3Dataset( 10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, 13 | dev=False, 14 | refresh_cache=False, 15 | ) 16 | base_dataset.stat() 17 | 18 | # STEP 2: set task 19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn) 20 | sample_dataset.stat() 21 | 22 | train_dataset, val_dataset, test_dataset = split_by_patient( 23 | sample_dataset, [0.8, 0.1, 0.1] 24 | ) 25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True) 26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False) 27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False) 28 | 29 | # STEP 3: define model 30 | model = Agent( 31 | dataset=sample_dataset, 32 | feature_keys=["conditions", "procedures"], 33 | label_key="label", 34 | mode="binary", 35 | embedding_dim=32, 36 | hidden_dim=32, 37 | ) 38 | 39 | # STEP 4: define trainer 40 | trainer = Trainer(model=model) 41 | trainer.train( 42 | train_dataloader=train_dataloader, 43 | val_dataloader=val_dataloader, 44 | epochs=50, 45 | monitor="roc_auc", 46 | ) 47 | 48 | # STEP 
5: evaluate 49 | print(trainer.evaluate(test_dataloader)) 50 | -------------------------------------------------------------------------------- /examples/mortality_mimic3_concare.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import ConCare 4 | from pyhealth.tasks import mortality_prediction_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | if __name__ == "__main__": 8 | # STEP 1: load data 9 | base_dataset = MIMIC3Dataset( 10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, 13 | dev=False, 14 | refresh_cache=False, 15 | ) 16 | base_dataset.stat() 17 | 18 | # STEP 2: set task 19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn) 20 | sample_dataset.stat() 21 | 22 | train_dataset, val_dataset, test_dataset = split_by_patient( 23 | sample_dataset, [0.8, 0.1, 0.1] 24 | ) 25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True) 26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False) 27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False) 28 | 29 | # STEP 3: define model 30 | model = ConCare( 31 | dataset=sample_dataset, 32 | feature_keys=["conditions", "procedures"], 33 | label_key="label", 34 | mode="binary", 35 | use_embedding=[True, True, True], 36 | hidden_dim=32, 37 | ) 38 | 39 | # STEP 4: define trainer 40 | trainer = Trainer(model=model) 41 | trainer.train( 42 | train_dataloader=train_dataloader, 43 | val_dataloader=val_dataloader, 44 | epochs=3, 45 | monitor="roc_auc", 46 | ) 47 | 48 | # STEP 5: evaluate 49 | print(trainer.evaluate(test_dataloader)) 50 | -------------------------------------------------------------------------------- /examples/mortality_mimic3_grasp.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import GRASP 4 | from pyhealth.tasks import mortality_prediction_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | if __name__ == "__main__": 8 | # STEP 1: load data 9 | base_dataset = MIMIC3Dataset( 10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, 13 | dev=False, 14 | refresh_cache=False, 15 | ) 16 | base_dataset.stat() 17 | 18 | # STEP 2: set task 19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn) 20 | sample_dataset.stat() 21 | 22 | train_dataset, val_dataset, test_dataset = split_by_patient( 23 | sample_dataset, [0.8, 0.1, 0.1] 24 | ) 25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True) 26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False) 27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False) 28 | 29 | # STEP 3: define model 30 | model = GRASP( 31 | dataset=sample_dataset, 32 | feature_keys=["conditions", "procedures"], 33 | label_key="label", 34 | mode="binary", 35 | use_embedding=[True, True, True], 36 | embedding_dim=32, 37 | hidden_dim=32, 38 | ) 39 | 40 | # STEP 4: define trainer 41 | trainer 
= Trainer(model=model) 42 | trainer.train( 43 | train_dataloader=train_dataloader, 44 | val_dataloader=val_dataloader, 45 | epochs=5, 46 | monitor="roc_auc", 47 | ) 48 | 49 | # STEP 5: evaluate 50 | print(trainer.evaluate(test_dataloader)) 51 | -------------------------------------------------------------------------------- /examples/mortality_mimic3_rnn.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import RNN 4 | from pyhealth.tasks import mortality_prediction_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | # STEP 1: load data 8 | base_dataset = MIMIC3Dataset( 9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 11 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, 12 | dev=False, 13 | refresh_cache=False, 14 | ) 15 | base_dataset.stat() 16 | 17 | # STEP 2: set task 18 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn) 19 | sample_dataset.stat() 20 | 21 | train_dataset, val_dataset, test_dataset = split_by_patient( 22 | sample_dataset, [0.8, 0.1, 0.1] 23 | ) 24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 27 | 28 | # STEP 3: define model 29 | model = RNN( 30 | dataset=sample_dataset, 31 | feature_keys=["conditions", "procedures", "drugs"], 32 | label_key="label", 33 | mode="binary", 34 | ) 35 | 36 | # STEP 4: define trainer 37 | trainer = Trainer(model=model) 38 | trainer.train( 39 | train_dataloader=train_dataloader, 40 | val_dataloader=val_dataloader, 41 | epochs=50, 42 | monitor="roc_auc", 43 | ) 44 | 45 | # STEP 5: evaluate 46 | print(trainer.evaluate(test_dataloader)) 47 | -------------------------------------------------------------------------------- /examples/mortality_mimic3_stagenet.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import StageNet 4 | from pyhealth.tasks import mortality_prediction_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | if __name__ == "__main__": 8 | # STEP 1: load data 9 | base_dataset = MIMIC3Dataset( 10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, 13 | dev=False, 14 | refresh_cache=False, 15 | ) 16 | base_dataset.stat() 17 | 18 | # STEP 2: set task 19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn) 20 | sample_dataset.stat() 21 | 22 | train_dataset, val_dataset, test_dataset = split_by_patient( 23 | sample_dataset, [0.8, 0.1, 0.1] 24 | ) 25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True) 26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False) 27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False) 28 | 29 | # STEP 3: define model 30 | model = StageNet( 31 | dataset=sample_dataset, 32 | feature_keys=["conditions", "procedures"], 33 | label_key="label", 34 | mode="binary", 35 | embedding_dim=32, 36 | ) 37 | 38 | # 
STEP 4: define trainer 39 | trainer = Trainer(model=model) 40 | trainer.train( 41 | train_dataloader=train_dataloader, 42 | val_dataloader=val_dataloader, 43 | epochs=50, 44 | monitor="roc_auc", 45 | ) 46 | 47 | # STEP 5: evaluate 48 | print(trainer.evaluate(test_dataloader)) 49 | -------------------------------------------------------------------------------- /examples/mortality_mimic3_tcn.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import TCN 4 | from pyhealth.tasks import mortality_prediction_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | if __name__ == "__main__": 8 | # STEP 1: load data 9 | base_dataset = MIMIC3Dataset( 10 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 11 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 12 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, 13 | dev=False, 14 | refresh_cache=False, 15 | ) 16 | base_dataset.stat() 17 | 18 | # STEP 2: set task 19 | sample_dataset = base_dataset.set_task(mortality_prediction_mimic3_fn) 20 | sample_dataset.stat() 21 | 22 | train_dataset, val_dataset, test_dataset = split_by_patient( 23 | sample_dataset, [0.8, 0.1, 0.1] 24 | ) 25 | train_dataloader = get_dataloader(train_dataset, batch_size=256, shuffle=True) 26 | val_dataloader = get_dataloader(val_dataset, batch_size=256, shuffle=False) 27 | test_dataloader = get_dataloader(test_dataset, batch_size=256, shuffle=False) 28 | 29 | # STEP 3: define model 30 | model = TCN( 31 | dataset=sample_dataset, 32 | feature_keys=["conditions", "procedures"], 33 | label_key="label", 34 | mode="binary", 35 | embedding_dim=32, 36 | ) 37 | 38 | # STEP 4: define trainer 39 | trainer = Trainer(model=model) 40 | trainer.train( 41 | train_dataloader=train_dataloader, 42 | val_dataloader=val_dataloader, 43 | epochs=50, 44 | monitor="roc_auc", 45 | ) 46 | 47 | # STEP 5: evaluate 48 | print(trainer.evaluate(test_dataloader)) 49 | -------------------------------------------------------------------------------- /examples/readmission_mimic3_fairness.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.tasks import readmission_prediction_mimic3_fn 3 | from pyhealth.datasets import split_by_patient, get_dataloader 4 | from pyhealth.metrics import fairness_metrics_fn 5 | from pyhealth.models import Transformer 6 | from pyhealth.trainer import Trainer 7 | from pyhealth.metrics.fairness_utils.utils import sensitive_attributes_from_patient_ids 8 | 9 | # STEP 1: load data 10 | base_dataset = MIMIC3Dataset( 11 | root="https://storage.googleapis.com/pyhealth/Synthetic_MIMIC-III/", 12 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 13 | ) 14 | base_dataset.stat() 15 | 16 | # STEP 2: set task 17 | sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn) 18 | sample_dataset.stat() 19 | 20 | train_dataset, val_dataset, test_dataset = split_by_patient(sample_dataset, [0.8, 0.1, 0.1]) 21 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 22 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 23 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 24 | 25 | # STEP 3: define model 26 | model = Transformer( 27 | dataset=sample_dataset, 28 | # look up what are available for 
"feature_keys" and "label_keys" in dataset.samples[0] 29 | feature_keys=["conditions", "procedures"], 30 | label_key="label", 31 | mode="binary", 32 | ) 33 | 34 | # STEP 4: define trainer 35 | trainer = Trainer(model=model) 36 | trainer.train( 37 | train_dataloader=train_dataloader, 38 | val_dataloader=val_dataloader, 39 | epochs=3, 40 | monitor="pr_auc", 41 | ) 42 | 43 | # STEP 5: inference, return patient_ids 44 | y_true, y_prob, loss, patient_ids = trainer.inference(test_dataloader, return_patient_ids=True) 45 | 46 | # STEP 6: get sensitive attribute array from patient_ids 47 | sensitive_attribute_array = sensitive_attributes_from_patient_ids(base_dataset, patient_ids, 48 | 'gender', 'F') 49 | 50 | # STEP 7: use pyhealth.metrics to evaluate fairness 51 | fairness_metrics = fairness_metrics_fn(y_true, y_prob, sensitive_attribute_array, 52 | favorable_outcome=0) 53 | print(fairness_metrics) -------------------------------------------------------------------------------- /examples/readmission_mimic3_rnn.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import MIMIC3Dataset 2 | from pyhealth.datasets import split_by_patient, get_dataloader 3 | from pyhealth.models import RNN 4 | from pyhealth.tasks import readmission_prediction_mimic3_fn 5 | from pyhealth.trainer import Trainer 6 | 7 | # STEP 1: load data 8 | base_dataset = MIMIC3Dataset( 9 | root="/srv/local/data/physionet.org/files/mimiciii/1.4", 10 | tables=["DIAGNOSES_ICD", "PROCEDURES_ICD", "PRESCRIPTIONS"], 11 | code_mapping={"ICD9CM": "CCSCM", "ICD9PROC": "CCSPROC", "NDC": "ATC"}, 12 | dev=False, 13 | refresh_cache=True, 14 | ) 15 | base_dataset.stat() 16 | 17 | # STEP 2: set task 18 | sample_dataset = base_dataset.set_task(readmission_prediction_mimic3_fn) 19 | sample_dataset.stat() 20 | 21 | train_dataset, val_dataset, test_dataset = split_by_patient( 22 | sample_dataset, [0.8, 0.1, 0.1] 23 | ) 24 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 25 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 26 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 27 | 28 | # STEP 3: define model 29 | model = RNN( 30 | dataset=sample_dataset, 31 | feature_keys=["conditions", "procedures", "drugs"], 32 | label_key="label", 33 | mode="binary", 34 | ) 35 | 36 | # STEP 4: define trainer 37 | trainer = Trainer(model=model) 38 | trainer.train( 39 | train_dataloader=train_dataloader, 40 | val_dataloader=val_dataloader, 41 | epochs=50, 42 | monitor="roc_auc", 43 | ) 44 | 45 | # STEP 5: evaluate 46 | trainer.evaluate(test_dataloader) 47 | -------------------------------------------------------------------------------- /examples/sleep_staging_ISRUC_SparcNet.py: -------------------------------------------------------------------------------- 1 | from pyhealth.calib import calibration 2 | from pyhealth.datasets import ISRUCDataset, get_dataloader, split_by_patient 3 | from pyhealth.models import ContraWR, SparcNet 4 | from pyhealth.tasks import sleep_staging_isruc_fn 5 | from pyhealth.trainer import Trainer, get_metrics_fn 6 | 7 | # step 1: load signal data 8 | dataset = ISRUCDataset( 9 | root="/srv/local/data/trash/", 10 | dev=True, 11 | refresh_cache=False, 12 | # download=True, 13 | ) 14 | 15 | print(dataset.stat()) 16 | 17 | # step 2: set task 18 | sleep_staging_ds = dataset.set_task(sleep_staging_isruc_fn) 19 | sleep_staging_ds.stat() 20 | print(sleep_staging_ds.samples[0]) 21 | 22 | # split dataset 23 | 
train_dataset, val_dataset, test_dataset = split_by_patient( 24 | sleep_staging_ds, [0.34, 0.33, 0.33] 25 | ) 26 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 27 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 28 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 29 | print( 30 | "loader size: train/val/test", 31 | len(train_dataset), 32 | len(val_dataset), 33 | len(test_dataset), 34 | ) 35 | 36 | # STEP 3: define model 37 | model = SparcNet( 38 | dataset=sleep_staging_ds, 39 | feature_keys=["signal"], 40 | label_key="label", 41 | mode="multiclass", 42 | ) 43 | 44 | # STEP 4: define trainer 45 | trainer = Trainer(model=model) 46 | trainer.train( 47 | train_dataloader=train_dataloader, 48 | val_dataloader=val_dataloader, 49 | epochs=5, 50 | monitor="accuracy", 51 | ) 52 | 53 | # STEP 5: evaluate 54 | metrics = ['accuracy', 'f1_macro', 'f1_micro'] + ['cwECEt_adapt'] 55 | y_true_all, y_prob_all = trainer.inference(test_dataloader)[:2] 56 | print(get_metrics_fn(model.mode)(y_true_all, y_prob_all, metrics=metrics)) 57 | 58 | # STEP 6: calibrate the model 59 | cal_model = calibration.KCal(model, debug=True, dim=32) 60 | cal_model.calibrate( 61 | cal_dataset=val_dataset, 62 | # Uncomment the following line if you want to re-train the embeddings 63 | # train_dataset=train_dataset, 64 | ) 65 | y_true_all, y_prob_all = Trainer(model=cal_model).inference(test_dataloader)[:2] 66 | print(get_metrics_fn(cal_model.mode)(y_true_all, y_prob_all, metrics=metrics)) -------------------------------------------------------------------------------- /examples/sleep_staging_shhs_contrawr.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import split_by_patient, get_dataloader 2 | from pyhealth.trainer import Trainer 3 | from pyhealth.datasets import SHHSDataset 4 | from pyhealth.tasks import sleep_staging_shhs_fn 5 | from pyhealth.models import ContraWR, SparcNet 6 | 7 | # step 1: load signal data 8 | dataset = SHHSDataset( 9 | root="/srv/local/data/SHHS/polysomnography", 10 | dev=True, 11 | refresh_cache=False, 12 | ) 13 | 14 | # step 2: set task 15 | sleep_staging_ds = dataset.set_task(sleep_staging_shhs_fn) 16 | sleep_staging_ds.stat() 17 | 18 | # split dataset 19 | train_dataset, val_dataset, test_dataset = split_by_patient( 20 | sleep_staging_ds, [0.6, 0.2, 0.2] 21 | ) 22 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 23 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 24 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 25 | print( 26 | "loader size: train/val/test", 27 | len(train_dataset), 28 | len(val_dataset), 29 | len(test_dataset), 30 | ) 31 | 32 | # STEP 3: define model 33 | model = SparcNet( 34 | dataset=sleep_staging_ds, 35 | feature_keys=["signal"], 36 | label_key="label", 37 | mode="multiclass", 38 | ) 39 | 40 | # STEP 4: define trainer 41 | trainer = Trainer(model=model) 42 | trainer.train( 43 | train_dataloader=train_dataloader, 44 | val_dataloader=val_dataloader, 45 | epochs=5, 46 | monitor="accuracy", 47 | ) 48 | 49 | # STEP 5: evaluate 50 | print(trainer.evaluate(test_dataloader)) 51 | -------------------------------------------------------------------------------- /examples/sleep_staging_sleepEDF_contrawr.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import split_by_patient, 
get_dataloader 2 | from pyhealth.trainer import Trainer 3 | from pyhealth.datasets import SleepEDFDataset 4 | from pyhealth.tasks import sleep_staging_sleepedf_fn 5 | from pyhealth.models import ContraWR, SparcNet 6 | 7 | # step 1: load signal data 8 | dataset = SleepEDFDataset( 9 | root="/srv/local/data/SLEEPEDF/sleep-edf-database-expanded-1.0.0/sleep-cassette", 10 | dev=True, 11 | refresh_cache=False, 12 | ) 13 | 14 | # step 2: set task 15 | sleep_staging_ds = dataset.set_task(sleep_staging_sleepedf_fn) 16 | sleep_staging_ds.stat() 17 | 18 | # split dataset 19 | train_dataset, val_dataset, test_dataset = split_by_patient( 20 | sleep_staging_ds, [0.6, 0.2, 0.2] 21 | ) 22 | train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True) 23 | val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False) 24 | test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False) 25 | print( 26 | "loader size: train/val/test", 27 | len(train_dataset), 28 | len(val_dataset), 29 | len(test_dataset), 30 | ) 31 | 32 | # STEP 3: define model 33 | model = SparcNet( 34 | dataset=sleep_staging_ds, 35 | feature_keys=["signal"], 36 | label_key="label", 37 | mode="multiclass", 38 | ) 39 | 40 | # STEP 4: define trainer 41 | trainer = Trainer(model=model) 42 | trainer.train( 43 | train_dataloader=train_dataloader, 44 | val_dataloader=val_dataloader, 45 | epochs=5, 46 | monitor="accuracy", 47 | ) 48 | 49 | # STEP 5: evaluate 50 | print(trainer.evaluate(test_dataloader)) 51 | -------------------------------------------------------------------------------- /examples/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | for i in range(100): 4 | a = torch.randn((5, 5, 5)) > 0 5 | print(((torch.sum(a, dim=2) != 0) != (torch.any(a != 0, dim=2))).any()) 6 | -------------------------------------------------------------------------------- /examples/test_dataloader.sh: -------------------------------------------------------------------------------- 1 | python pyhealth/datasets/mimic3.py 2 | python pyhealth/datasets/eicu.py 3 | python pyhealth/datasets/omop.py 4 | python pyhealth/datasets/mimic4.py -------------------------------------------------------------------------------- /examples/test_eICU_addition.py: -------------------------------------------------------------------------------- 1 | from pyhealth.datasets import eICUDataset 2 | from pyhealth.tasks import mortality_prediction_eicu_fn, mortality_prediction_eicu_fn2 3 | 4 | base_dataset = eICUDataset( 5 | root="/srv/local/data/physionet.org/files/eicu-crd/2.0", 6 | tables=["diagnosis", "admissionDx", "treatment"], 7 | dev=False, 8 | refresh_cache=False, 9 | ) 10 | sample_dataset = base_dataset.set_task(task_fn=mortality_prediction_eicu_fn2) 11 | sample_dataset.stat() 12 | print(sample_dataset.available_keys) 13 | 14 | # base_dataset = eICUDataset( 15 | # root="/srv/local/data/physionet.org/files/eicu-crd/2.0", 16 | # tables=["diagnosis", "admissionDx", "treatment"], 17 | # dev=True, 18 | # refresh_cache=False, 19 | # ) 20 | # sample_dataset = base_dataset.set_task(task_fn=mortality_prediction_eicu_fn2) 21 | # sample_dataset.stat() 22 | # print(sample_dataset.available_keys) 23 | -------------------------------------------------------------------------------- /figure/five-stage-pipeline.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/figure/five-stage-pipeline.png
--------------------------------------------------------------------------------
/figure/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/figure/overview.png
--------------------------------------------------------------------------------
/figure/poster.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/figure/poster.png
--------------------------------------------------------------------------------
/figure/structured-dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/figure/structured-dataset.png
--------------------------------------------------------------------------------
/leaderboard/__init__.py:
--------------------------------------------------------------------------------
1 | from .leaderboard_gen import leaderboard_generation, plots_generation, construct_args
2 | from .utils import *
3 | 
--------------------------------------------------------------------------------
/leaderboard/ext_plot.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | import argparse
3 | import os
4 | from bokeh.layouts import column
5 | from bokeh.plotting import output_file, save
6 | 
7 | # environment variable values must be strings, not booleans
8 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
9 | 
10 | credential_file = './credentials.json'
11 | 
12 | def plot_html(args):
13 |     dfs = read_dataframes_by_time_from_gcp(args.credentials)
14 |     output_file(filename='leaderboard_plotting.html', title='PyHealth Leaderboard')
15 | 
16 |     bokeh_figures = []
17 | 
18 |     for task in args.tasks:
19 |         df = get_typed_df_with_time(dfs, task)
20 |         bokeh_figure = generate_bokeh_figure(df)
21 |         bokeh_figures.append(bokeh_figure)
22 | 
23 |     save(column(bokeh_figures))
24 | 
25 | 
26 | def construct_args():
27 |     parser = argparse.ArgumentParser()
28 |     parser.add_argument("--credentials", type=str, default=credential_file)
29 |     # nargs="+" accepts space-separated task names; type=list would split a string into characters
30 |     parser.add_argument("--tasks", nargs="+",
31 |                         default=[
32 |                             "drugrec",
33 |                             "lenOfStay",
34 |                             "mortality",
35 |                             "readmission"
36 |                         ])
37 |     return parser.parse_args()
38 | 
39 | 
40 | def main():
41 |     args = construct_args()
42 |     plot_html(args)
43 | 
44 | 
45 | if __name__ == '__main__':
46 |     main()
47 | 
--------------------------------------------------------------------------------
/leaderboard/rtd_build_trigger.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import requests
4 | 
5 | def build(args):
6 |     version = args.version
7 |     # version = 'leaderboard-and-unitest-update'  # old hardcoded override; it ignored --version
8 |     URL = f'https://readthedocs.org/api/v3/projects/pyhealth/versions/{version}/builds/'
9 |     with open(args.rtd_token) as f:
10 |         TOKEN = f.readlines()[0]
11 |     HEADERS = {'Authorization': f'token {TOKEN}'}
12 |     response = requests.post(URL, headers=HEADERS)
13 |     print(response.json())
14 | 
15 | 
16 | def construct_args():
17 |     parser = argparse.ArgumentParser()
18 |     parser.add_argument("--rtd_token", type=str, default='./rtd_token.txt')
19 |     parser.add_argument("--version", type=str, default='latest')
20 |     return parser.parse_args()
21 | 
22 | 
23 | def main():
24 |     args = construct_args()
25 |     build(args)
26 | 
27 | 
28 | if __name__ == '__main__':
29 |     main()
30 | 
--------------------------------------------------------------------------------
/makefile:
-------------------------------------------------------------------------------- 1 | #@meta {author: "Paul Landes"} 2 | #@meta {desc: "PyHealth build automation", date: "2025-05-22"} 3 | 4 | 5 | ## Build 6 | # 7 | # directory with the unit tests 8 | PY_TEST_DIR ?= tests 9 | # test file glob pattern 10 | PY_TEST_GLOB ?= test_metrics.py 11 | 12 | 13 | ## Targets 14 | # 15 | # install dependencies 16 | .PHONY: deps 17 | deps: 18 | pip install -r requirements-nlp.txt 19 | 20 | # run the unit test cases 21 | .PHONY: test 22 | test: 23 | @echo "Running tests in $(PY_TEST_DIR)/$(PY_TEST_GLOB)" 24 | python -m unittest discover \ 25 | -s $(PY_TEST_DIR) -p '$(PY_TEST_GLOB)' -v 26 | 27 | # clean derived objects 28 | .PHONY: clean 29 | clean: 30 | @echo "removing __pycache__" 31 | @find . -type d -name __pycache__ -prune -exec rm -r {} \; 32 | -------------------------------------------------------------------------------- /pyhealth/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | import sys 5 | 6 | __version__ = "1.1.4" 7 | 8 | # package-level cache path 9 | BASE_CACHE_PATH = os.path.join(str(Path.home()), ".cache/pyhealth/") 10 | # BASE_CACHE_PATH = "/srv/local/data/pyhealth-cache" 11 | if not os.path.exists(BASE_CACHE_PATH): 12 | os.makedirs(BASE_CACHE_PATH) 13 | 14 | # logging 15 | logger = logging.getLogger(__name__) 16 | logger.setLevel(logging.INFO) 17 | handler = logging.StreamHandler(sys.stdout) 18 | formatter = logging.Formatter("%(message)s") 19 | handler.setFormatter(formatter) 20 | logger.addHandler(handler) 21 | -------------------------------------------------------------------------------- /pyhealth/calib/README.md: -------------------------------------------------------------------------------- 1 | This module contains several post-hoc uncertainty quantification methods. 2 | 3 | # Model Calibration: 4 | 5 | ### KCal 6 | 7 | Used in: `multiclass`. 8 | 9 | The model needs to be able to take `embed=True` in `forward` and output the penultimate embedding in the output. 10 | 11 | Lin, Zhen, Shubhendu Trivedi, and Jimeng Sun. 12 | "Taking a Step Back with KCal: Multi-Class Kernel-Based Calibration for Deep Neural Networks." 13 | ICLR 2023. 14 | 15 | ### Temperature Scaling 16 | 17 | Used in: `multiclass`, `multilabel` and `binary`. 18 | 19 | Guo, Chuan, Geoff Pleiss, Yu Sun, and Kilian Q. Weinberger. 20 | "On calibration of modern neural networks." 21 | ICML 2017. 22 | 23 | ### Histogram Binning 24 | 25 | Used in: `multiclass`, `multilabel` and `binary`. 26 | 27 | 28 | Zadrozny, Bianca, and Charles Elkan. 29 | "Learning and making decisions when costs and probabilities are both unknown." 30 | In Proceedings of the seventh ACM SIGKDD international conference on Knowledge discovery and data mining, pp. 204-213. 2001. 31 | 32 | Gupta, Chirag, and Aaditya Ramdas. 33 | "Top-label calibration and multiclass-to-binary reductions." 34 | ICLR 2022. 35 | 36 | # Prediction Set: 37 | 38 | ### SCRIB 39 | 40 | Used in: `multiclass`. 41 | 42 | Lin, Zhen, Lucas Glass, M. Brandon Westover, Cao Xiao, and Jimeng Sun. 43 | "SCRIB: Set-classifier with Class-specific Risk Bounds for Blackbox Models." 44 | AAAI 2022. 45 | 46 | ### LABEL 47 | 48 | Used in: `multiclass`. 49 | 50 | Sadinle, Mauricio, Jing Lei, and Larry Wasserman. 51 | "Least ambiguous set-valued classifiers with bounded error levels." 52 | Journal of the American Statistical Association 114, no. 525 (2019): 223-234. 
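53 | 
54 | # Example usage
55 | 
56 | A minimal sketch of the shared post-hoc workflow (the `model`, `val_dataset`, and `test_dataloader` names are placeholders for objects built as in the scripts under `examples/`; see `examples/sleep_staging_ISRUC_SparcNet.py` for a full KCal run):
57 | 
58 | ```python
59 | from pyhealth.calib.calibration import TemperatureScaling
60 | from pyhealth.trainer import Trainer, get_metrics_fn
61 | 
62 | # wrap an already-trained PyHealth model and fit the calibrator on held-out data
63 | cal_model = TemperatureScaling(model)
64 | cal_model.calibrate(cal_dataset=val_dataset)
65 | 
66 | # the calibrated model is evaluated exactly like the base model
67 | y_true, y_prob = Trainer(model=cal_model).inference(test_dataloader)[:2]
68 | print(get_metrics_fn(cal_model.mode)(y_true, y_prob, metrics=["accuracy", "cwECEt_adapt"]))
69 | ```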
-------------------------------------------------------------------------------- /pyhealth/calib/__init__.py: -------------------------------------------------------------------------------- 1 | from pyhealth.calib import calibration, predictionset 2 | -------------------------------------------------------------------------------- /pyhealth/calib/base_classes.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Dict 3 | 4 | import torch 5 | 6 | 7 | class PostHocCalibrator(ABC, torch.nn.Module): 8 | def __init__(self, model, **kwargs) -> None: 9 | super().__init__() 10 | self.model = model 11 | 12 | def calibrate(self, cal_dataset): 13 | ... 14 | 15 | def forward(self, **kwargs) -> Dict[str, torch.Tensor]: 16 | ... 17 | 18 | 19 | def to(self, device): 20 | super().to(device) 21 | self.device = device 22 | return self 23 | 24 | 25 | class SetPredictor(ABC, torch.nn.Module): 26 | def __init__(self, model, **kwargs) -> None: 27 | super().__init__() 28 | self.model = model 29 | 30 | def calibrate(self, cal_dataset): 31 | ... 32 | 33 | def forward(self, **kwargs) -> Dict[str, torch.Tensor]: 34 | ... 35 | 36 | def to(self, device): 37 | super().to(device) 38 | self.device = device 39 | return self -------------------------------------------------------------------------------- /pyhealth/calib/calibration/__init__.py: -------------------------------------------------------------------------------- 1 | """Model calibration methods""" 2 | from pyhealth.calib.calibration.dircal import DirichletCalibration 3 | from pyhealth.calib.calibration.hb import HistogramBinning 4 | from pyhealth.calib.calibration.kcal import KCal 5 | from pyhealth.calib.calibration.temperature_scale import TemperatureScaling 6 | 7 | __all__ = ['DirichletCalibration', 'HistogramBinning', 'KCal', 'TemperatureScaling'] 8 | -------------------------------------------------------------------------------- /pyhealth/calib/predictionset/__init__.py: -------------------------------------------------------------------------------- 1 | """Prediction set construction methods""" 2 | from pyhealth.calib.predictionset.favmac import FavMac 3 | from pyhealth.calib.predictionset.label import LABEL 4 | from pyhealth.calib.predictionset.scrib import SCRIB 5 | 6 | __all__ = ['LABEL', 'SCRIB', 'FavMac'] 7 | -------------------------------------------------------------------------------- /pyhealth/calib/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import torch 6 | import tqdm 7 | from torch import Tensor 8 | 9 | from pyhealth.datasets import utils as datautils 10 | 11 | 12 | def agg_loss(loss:torch.Tensor, reduction: str): 13 | if reduction == 'mean': 14 | return loss.mean() 15 | if reduction == 'sum': 16 | return loss.sum() 17 | return loss 18 | 19 | 20 | def one_hot_np(labels, K): 21 | new_labels = np.zeros((len(labels), K)) 22 | new_labels[np.arange(len(labels)), labels] = 1 23 | return new_labels 24 | 25 | class LogLoss(torch.nn.Module): 26 | """Cross entropy, but takes in the probability instead of the logits""" 27 | reduction: str 28 | def __init__(self, weight: Optional[Tensor] = None, ignore_index: int = -100, reduction: str = 'mean', clip=1e-10) -> None: 29 | super(LogLoss, self).__init__() 30 | self.register_buffer('weight', weight) 31 | self.ignore_index = ignore_index 32 | self.reduction = reduction 33 | self.clip 
= clip 34 | 35 | def forward(self, input: Tensor, target: Tensor) -> Tensor: 36 | assert self.weight is None or isinstance(self.weight, Tensor) 37 | dim = input.dim() 38 | assert dim == 2, f"Expected 2 dimensions (got {dim})" 39 | input = input.clip(self.clip)#this weight should be trivial, so I won't normalize 40 | input = -torch.log(input) 41 | if self.weight is not None: 42 | input = input * self.weight.unsqueeze(0) 43 | loss = torch.gather(input, -1, target.unsqueeze(-1)).squeeze(-1) 44 | return agg_loss(loss, self.reduction) 45 | 46 | 47 | def prepare_numpy_dataset(model, dataset, keys, forward_kwargs=None, 48 | incl_data_keys=None, debug=False, batch_size=32): 49 | if forward_kwargs is None: 50 | forward_kwargs = {} 51 | if incl_data_keys is None: 52 | incl_data_keys = [] 53 | loader = datautils.get_dataloader(dataset, batch_size, shuffle=False) 54 | 55 | ret = defaultdict(list) 56 | with torch.no_grad(): 57 | for _i, data in tqdm.tqdm(enumerate(loader), desc=f"retrieving {keys}", total=len(loader)): 58 | if debug and _i % 10 != 0: 59 | continue 60 | data.update(forward_kwargs) 61 | res = model(**data) 62 | for key in keys: 63 | ret[key].append(res[key].detach().cpu().numpy()) 64 | for key in incl_data_keys: 65 | ret[key].extend(data[key]) 66 | for key in incl_data_keys: 67 | ret[key] = np.asarray(ret[key]) 68 | for key in keys: 69 | ret[key] = np.concatenate(ret[key]) 70 | return ret 71 | -------------------------------------------------------------------------------- /pyhealth/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import Event, Patient 2 | 3 | 4 | class Visit: 5 | """This class is deprecated and should not be used.""" 6 | def __init__(self, *args, **kwargs): 7 | import warnings 8 | warnings.warn("The Visit class is deprecated and will be removed in a future version.", DeprecationWarning) -------------------------------------------------------------------------------- /pyhealth/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | class BaseEHRDataset: 2 | """This class is deprecated and should not be used.""" 3 | def __init__(self, *args, **kwargs): 4 | import warnings 5 | warnings.warn("The BaseEHRDataset class is deprecated and will be removed in a future version.", DeprecationWarning) 6 | 7 | class BaseSignalDataset: 8 | """This class is deprecated and should not be used.""" 9 | def __init__(self, *args, **kwargs): 10 | import warnings 11 | warnings.warn("The BaseSignalDataset class is deprecated and will be removed in a future version.", DeprecationWarning) 12 | 13 | 14 | class SampleEHRDataset: 15 | """This class is deprecated and should not be used.""" 16 | def __init__(self, *args, **kwargs): 17 | import warnings 18 | warnings.warn("The SampleEHRDataset class is deprecated and will be removed in a future version.", DeprecationWarning) 19 | 20 | 21 | class SampleSignalDataset: 22 | """This class is deprecated and should not be used.""" 23 | def __init__(self, *args, **kwargs): 24 | import warnings 25 | warnings.warn("The SampleSignalDataset class is deprecated and will be removed in a future version.", DeprecationWarning) 26 | 27 | 28 | from .base_dataset import BaseDataset 29 | from .cardiology import CardiologyDataset 30 | from .covid19_cxr import COVID19CXRDataset 31 | from .ehrshot import EHRShotDataset 32 | from .eicu import eICUDataset 33 | from .isruc import ISRUCDataset 34 | from .medical_transcriptions import MedicalTranscriptionsDataset 35 
| from .mimic3 import MIMIC3Dataset 36 | from .mimic4 import MIMIC4CXRDataset, MIMIC4Dataset, MIMIC4EHRDataset, MIMIC4NoteDataset 37 | from .mimicextract import MIMICExtractDataset 38 | from .omop import OMOPDataset 39 | from .sample_dataset import SampleDataset 40 | from .shhs import SHHSDataset 41 | from .sleepedf import SleepEDFDataset 42 | from .splitter import split_by_patient, split_by_sample, split_by_visit 43 | from .tuab import TUABDataset 44 | from .tuev import TUEVDataset 45 | from .utils import collate_fn_dict, collate_fn_dict_with_padding, get_dataloader 46 | -------------------------------------------------------------------------------- /pyhealth/datasets/configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import DatasetConfig, load_yaml_config 2 | -------------------------------------------------------------------------------- /pyhealth/datasets/configs/covid19_cxr.yaml: -------------------------------------------------------------------------------- 1 | version: "5.0" 2 | tables: 3 | covid19_cxr: 4 | file_path: "covid19_cxr-metadata-pyhealth.csv" 5 | patient_id: null 6 | timestamp: null 7 | attributes: 8 | - "path" 9 | - "url" 10 | - "label" 11 | -------------------------------------------------------------------------------- /pyhealth/datasets/configs/medical_transcriptions.yaml: -------------------------------------------------------------------------------- 1 | version: "1.0" 2 | tables: 3 | mtsamples: 4 | file_path: "mtsamples.csv" 5 | patient_id: null 6 | timestamp: null 7 | attributes: 8 | - "description" 9 | - "medical_specialty" 10 | - "sample_name" 11 | - "transcription" 12 | - "keywords" -------------------------------------------------------------------------------- /pyhealth/datasets/configs/mimic4_ehr.yaml: -------------------------------------------------------------------------------- 1 | version: "2.2" 2 | tables: 3 | patients: 4 | file_path: "hosp/patients.csv.gz" 5 | patient_id: "subject_id" 6 | timestamp: null 7 | attributes: 8 | - "gender" 9 | - "anchor_age" 10 | - "anchor_year" 11 | - "anchor_year_group" 12 | - "dod" 13 | 14 | admissions: 15 | file_path: "hosp/admissions.csv.gz" 16 | patient_id: "subject_id" 17 | timestamp: "admittime" 18 | attributes: 19 | - "hadm_id" 20 | - "admission_type" 21 | - "admission_location" 22 | - "insurance" 23 | - "language" 24 | - "marital_status" 25 | - "race" 26 | - "discharge_location" 27 | - "dischtime" 28 | - "hospital_expire_flag" 29 | 30 | icustays: 31 | file_path: "icu/icustays.csv.gz" 32 | patient_id: "subject_id" 33 | timestamp: "intime" 34 | attributes: 35 | - "stay_id" 36 | - "first_careunit" 37 | - "last_careunit" 38 | - "outtime" 39 | 40 | diagnoses_icd: 41 | file_path: "hosp/diagnoses_icd.csv.gz" 42 | patient_id: "subject_id" 43 | join: 44 | - file_path: "hosp/admissions.csv.gz" 45 | "on": "hadm_id" 46 | how: "inner" 47 | columns: 48 | - "dischtime" 49 | timestamp: "dischtime" 50 | attributes: 51 | - "icd_code" 52 | - "icd_version" 53 | - "seq_num" 54 | 55 | procedures_icd: 56 | file_path: "hosp/procedures_icd.csv.gz" 57 | patient_id: "subject_id" 58 | join: 59 | - file_path: "hosp/admissions.csv.gz" 60 | "on": "hadm_id" 61 | how: "inner" 62 | columns: 63 | - "dischtime" 64 | timestamp: "dischtime" 65 | attributes: 66 | - "icd_code" 67 | - "icd_version" 68 | - "seq_num" 69 | 70 | prescriptions: 71 | file_path: "hosp/prescriptions.csv.gz" 72 | patient_id: "subject_id" 73 | timestamp: "starttime" 74 | attributes: 75 | - "drug" 76 | - "ndc" 
77 | - "prod_strength" 78 | - "dose_val_rx" 79 | - "dose_unit_rx" 80 | - "route" 81 | - "stoptime" 82 | 83 | labevents: 84 | file_path: "hosp/labevents.csv.gz" 85 | patient_id: "subject_id" 86 | join: 87 | - file_path: "hosp/d_labitems.csv.gz" 88 | "on": "itemid" 89 | how: "inner" 90 | columns: 91 | - "label" 92 | - "fluid" 93 | - "category" 94 | timestamp: "charttime" 95 | attributes: 96 | - "itemid" 97 | - "label" 98 | - "fluid" 99 | - "category" 100 | - "value" 101 | - "valuenum" 102 | - "valueuom" 103 | - "flag" 104 | - "storetime" 105 | 106 | hcpcsevents: 107 | file_path: "hosp/hcpcsevents.csv.gz" 108 | patient_id: "subject_id" 109 | timestamp: "chartdate" 110 | attributes: 111 | - "hcpcs_cd" 112 | - "seq_num" 113 | - "short_description" 114 | -------------------------------------------------------------------------------- /pyhealth/datasets/configs/mimic4_note.yaml: -------------------------------------------------------------------------------- 1 | version: "2.2" 2 | tables: 3 | discharge: 4 | file_path: "note/discharge.csv.gz" 5 | patient_id: "subject_id" 6 | timestamp: "charttime" 7 | attributes: 8 | - "note_id" 9 | - "hadm_id" 10 | - "note_type" 11 | - "note_seq" 12 | - "storetime" 13 | - "text" 14 | 15 | discharge_detail: 16 | file_path: "note/discharge_detail.csv.gz" 17 | patient_id: "subject_id" 18 | join: 19 | - file_path: "note/discharge.csv.gz" 20 | "on": "note_id" 21 | how: "inner" 22 | columns: 23 | - "charttime" 24 | timestamp: "charttime" 25 | attributes: 26 | - "note_id" 27 | - "field_name" 28 | - "field_value" 29 | - "field_ordinal" 30 | 31 | radiology: 32 | file_path: "note/radiology.csv.gz" 33 | patient_id: "subject_id" 34 | timestamp: "charttime" 35 | attributes: 36 | - "note_id" 37 | - "hadm_id" 38 | - "note_type" 39 | - "note_seq" 40 | - "storetime" 41 | - "text" 42 | 43 | radiology_detail: 44 | file_path: "note/radiology_detail.csv.gz" 45 | patient_id: "subject_id" 46 | join: 47 | - file_path: "note/radiology.csv.gz" 48 | "on": "note_id" 49 | how: "inner" 50 | columns: 51 | - "charttime" 52 | timestamp: "charttime" 53 | attributes: 54 | - "note_id" 55 | - "field_name" 56 | - "field_value" 57 | - "field_ordinal" 58 | -------------------------------------------------------------------------------- /pyhealth/datasets/ehrshot.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import List, Optional 4 | 5 | from .base_dataset import BaseDataset 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class EHRShotDataset(BaseDataset): 11 | """ 12 | A dataset class for handling EHRShot data. 13 | 14 | This class is responsible for loading and managing the EHRShot dataset. 15 | 16 | Website: https://som-shahlab.github.io/ehrshot-website/ 17 | 18 | Attributes: 19 | root (str): The root directory where the dataset is stored. 20 | tables (List[str]): A list of tables to be included in the dataset. 21 | dataset_name (Optional[str]): The name of the dataset. 22 | config_path (Optional[str]): The path to the configuration file. 
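23 | 
24 |     Examples:
25 |         An illustrative sketch: the root path is a placeholder, and the table
26 |         name is assumed from configs/ehrshot.yaml rather than guaranteed here.
27 | 
28 |         >>> dataset = EHRShotDataset(
29 |         ...     root="path/to/ehrshot",
30 |         ...     tables=["ehrshot"],
31 |         ... )
32 |         >>> dataset.stats()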
33 |     """
34 | 
35 |     def __init__(
36 |         self,
37 |         root: str,
38 |         tables: List[str],
39 |         dataset_name: Optional[str] = None,
40 |         config_path: Optional[str] = None,
41 |         **kwargs
42 |     ) -> None:
43 |         if config_path is None:
44 |             logger.info("No config path provided, using default config")
45 |             config_path = Path(__file__).parent / "configs" / "ehrshot.yaml"
46 |         super().__init__(
47 |             root=root,
48 |             tables=tables,
49 |             dataset_name=dataset_name or "ehrshot",
50 |             config_path=config_path,
51 |             **kwargs
52 |         )
53 |         return
54 | 
--------------------------------------------------------------------------------
/pyhealth/datasets/featurizers/signal.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.signal import resample, butter, filtfilt
3 | 
4 | class BiosignalFeaturizer:
5 | 
6 |     def __init__(self, target_length=256, lowcut=0.5, highcut=40.0, sampling_rate=100):
7 |         self.target_length = target_length
8 |         self.lowcut = lowcut
9 |         self.highcut = highcut
10 |         self.sampling_rate = sampling_rate
11 | 
12 |     def _bandpass_filter(self, signal):
13 |         """Apply a bandpass filter to the signal."""
14 |         nyquist = 0.5 * self.sampling_rate
15 |         low = self.lowcut / nyquist
16 |         high = self.highcut / nyquist
17 |         b, a = butter(1, [low, high], btype="band")
18 |         return filtfilt(b, a, signal)
19 | 
20 |     def encode(self, signal):
21 |         """
22 |         Encodes the biosignal into a fixed-length feature representation.
23 | 
24 |         Args:
25 |             signal (np.array): 1D array of raw biosignal data.
26 | 
27 |         Returns:
28 |             np.array: Processed and fixed-length representation of the biosignal.
29 |         """
30 |         # Apply bandpass filtering
31 |         filtered_signal = self._bandpass_filter(signal)
32 | 
33 |         # Resample to target length for uniformity
34 |         resampled_signal = resample(filtered_signal, self.target_length)
35 | 
36 |         # Normalize the signal
37 |         normalized_signal = (resampled_signal - np.mean(resampled_signal)) / np.std(resampled_signal)
38 | 
39 |         return normalized_signal
40 | 
41 | if __name__ == "__main__":
42 |     # Example biosignal (ECG or other) with 1000 sample points
43 |     sample_signal = np.sin(np.linspace(0, 10, 1000)) + 0.1 * np.random.randn(1000)
44 |     featurizer = BiosignalFeaturizer()
45 |     print(featurizer)
46 |     print(type(featurizer))
47 |     print(featurizer.encode(sample_signal))
48 | 
--------------------------------------------------------------------------------
/pyhealth/datasets/featurizers/text.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModel
2 | import torch
3 | 
4 | 
5 | class TextFeaturizer:
6 | 
7 |     def __init__(self, model_name="bert-base-uncased"):
8 |         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
9 |         self.model = AutoModel.from_pretrained(model_name)
10 | 
11 |     def encode(self, value):
12 |         # Tokenize and process the input text
13 |         inputs = self.tokenizer(value, return_tensors="pt", truncation=True,
14 |                                 padding=True)
15 | 
16 |         # Generate embeddings
17 |         with torch.no_grad():
18 |             outputs = self.model(**inputs)
19 | 
20 |         # Use the CLS token embedding as the representation (can also use pooling for sentence embeddings)
21 |         embedding = outputs.last_hidden_state[:, 0, :]  # CLS token embedding
22 | 
23 |         return embedding
24 | 
25 | 
26 | if __name__ == "__main__":
27 |     sample_text = "This is a sample text input for the TextFeaturizer."
28 | featurizer = TextFeaturizer() 29 | print(featurizer) 30 | print(type(featurizer)) 31 | print(featurizer.encode(sample_text)) 32 | -------------------------------------------------------------------------------- /pyhealth/datasets/medical_transcriptions.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | from ..tasks import MedicalTranscriptionsClassification 6 | from .base_dataset import BaseDataset 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class MedicalTranscriptionsDataset(BaseDataset): 12 | """Medical transcription data scraped from mtsamples.com. 13 | 14 | Dataset is available at: 15 | https://www.kaggle.com/datasets/tboyle10/medicaltranscriptions 16 | 17 | Args: 18 | root: Root directory of the raw data. 19 | dataset_name: Name of the dataset. Defaults to "medical_transcriptions". 20 | config_path: Path to the configuration file. If None, uses default config. 21 | 22 | Attributes: 23 | root: Root directory of the raw data (should contain many csv files). 24 | dataset_name: Name of the dataset. 25 | config_path: Path to the configuration file. 26 | 27 | Examples: 28 | >>> from pyhealth.datasets import MedicalTranscriptionsDataset 29 | >>> dataset = MedicalTranscriptionsDataset( 30 | ... root="path/to/medical_transcriptions", 31 | ... ) 32 | >>> dataset.stats() 33 | >>> samples = dataset.set_task() 34 | >>> print(samples[0]) 35 | """ 36 | 37 | def __init__( 38 | self, 39 | root: str, 40 | dataset_name: Optional[str] = None, 41 | config_path: Optional[str] = None, 42 | ) -> None: 43 | if config_path is None: 44 | logger.info("No config path provided, using default config") 45 | config_path = ( 46 | Path(__file__).parent / "configs" / "medical_transcriptions.yaml" 47 | ) 48 | default_tables = ["mtsamples"] 49 | super().__init__( 50 | root=root, 51 | tables=default_tables, 52 | dataset_name=dataset_name or "medical_transcriptions", 53 | config_path=config_path, 54 | ) 55 | return 56 | 57 | @property 58 | def default_task(self) -> MedicalTranscriptionsClassification: 59 | """Returns the default task for this dataset.""" 60 | return MedicalTranscriptionsClassification() 61 | -------------------------------------------------------------------------------- /pyhealth/interpret/__init__.py: -------------------------------------------------------------------------------- 1 | from pyhealth.interpret import methods -------------------------------------------------------------------------------- /pyhealth/interpret/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from pyhealth.interpret.methods.chefer import CheferRelevance -------------------------------------------------------------------------------- /pyhealth/medcode/__init__.py: -------------------------------------------------------------------------------- 1 | from .codes.atc import ATC 2 | from .codes.ccscm import CCSCM 3 | from .codes.ccsproc import CCSPROC 4 | from .codes.icd10cm import ICD10CM 5 | from .codes.icd10proc import ICD10PROC 6 | from .codes.icd9cm import ICD9CM 7 | from .codes.icd9proc import ICD9PROC 8 | from .codes.ndc import NDC 9 | from .codes.rxnorm import RxNorm 10 | from .codes.umls import UMLS 11 | from .cross_map import CrossMap 12 | from .inner_map import InnerMap -------------------------------------------------------------------------------- /pyhealth/medcode/codes/__init__.py: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/pyhealth/medcode/codes/__init__.py
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/atc.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | from pyhealth.medcode.inner_map import InnerMap
4 | from pyhealth.medcode.utils import download_and_read_csv
5 | 
6 | 
7 | class ATC(InnerMap):
8 |     """Anatomical Therapeutic Chemical."""
9 | 
10 |     def __init__(self, **kwargs):
11 |         super(ATC, self).__init__(vocabulary="ATC", **kwargs)
12 |         self.ddi = dict()
13 | 
14 |     @staticmethod
15 |     def convert(code: str, level=5):
16 |         """Convert ATC code to a specific level."""
17 |         if type(level) is str:
18 |             level = int(level)
19 |         assert level in [1, 2, 3, 4, 5]
20 |         if level == 1:
21 |             return code[:1]
22 |         elif level == 2:
23 |             return code[:3]
24 |         elif level == 3:
25 |             return code[:4]
26 |         elif level == 4:
27 |             return code[:5]
28 |         else:
29 |             return code
30 | 
31 |     def get_ddi(
32 |         self, gamenet_ddi: bool = False, refresh_cache: bool = False
33 |     ) -> List[str]:
34 |         """Gets the drug-drug interactions (DDI).
35 | 
36 |         Args:
37 |             gamenet_ddi: Whether to use the DDI from the GAMENet paper,
38 |                 which is a subset of the DDI from the ATC.
39 |             refresh_cache: Whether to refresh the cache. Default is False.
40 |         """
41 |         filename = "DDI_GAMENet.csv" if gamenet_ddi else "DDI.csv"
42 |         if filename not in self.ddi or refresh_cache:
43 |             df = download_and_read_csv(filename, refresh_cache)
44 |             ddi = []
45 |             for idx, row in df.iterrows():
46 |                 ddi.append([row["ATC i"], row["ATC j"]])
47 |             self.ddi[filename] = ddi
48 |         return self.ddi[filename]
49 | 
50 | 
51 | if __name__ == "__main__":
52 |     code_sys = ATC(refresh_cache=True)
53 |     code_sys.stat()
54 |     print(code_sys.lookup("N01AB07"))
55 |     print(code_sys.lookup("N01AB07", attribute="level"))
56 |     print(code_sys.lookup("N01AB07", attribute="description"))
57 |     print(code_sys.lookup("N01AB07", attribute="indication"))
58 |     print(code_sys.lookup("N01AB07", attribute="smiles"))
59 |     print(code_sys.lookup("N01AB07", attribute="drugbank_id"))
60 |     print(code_sys.get_ancestors("N01AB07"))
61 |     print(code_sys.get_descendants("N01AB"))
62 |     print(len(code_sys.get_ddi(gamenet_ddi=True)))
63 |     print(len(code_sys.get_ddi(gamenet_ddi=False)))
64 | 
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/ccscm.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 | 
3 | 
4 | class CCSCM(InnerMap):
5 |     """Clinical Classifications Software (CCS) for ICD-9-CM diagnosis codes."""
6 | 
7 |     def __init__(self, **kwargs):
8 |         super(CCSCM, self).__init__(vocabulary="CCSCM", **kwargs)
9 | 
10 | 
11 | if __name__ == "__main__":
12 |     code_sys = CCSCM(refresh_cache=True)
13 |     code_sys.stat()
14 |     print("20" in code_sys)
15 |     print(code_sys.lookup("10"))
16 |     print(code_sys.get_ancestors("10"))
17 |     print(code_sys.get_descendants("10"))
18 | 
--------------------------------------------------------------------------------
/pyhealth/medcode/codes/ccsproc.py:
--------------------------------------------------------------------------------
1 | from pyhealth.medcode.inner_map import InnerMap
2 | 
3 | 
4 | class CCSPROC(InnerMap):
5 |     """Clinical Classifications Software (CCS) for ICD-9-CM procedure codes."""
6 | 
7 |     def __init__(self,
**kwargs): 8 | super(CCSPROC, self).__init__(vocabulary="CCSPROC", **kwargs) 9 | 10 | 11 | if __name__ == "__main__": 12 | code_sys = CCSPROC(refresh_cache=True) 13 | code_sys.stat() 14 | print("1" in code_sys) 15 | print(code_sys.lookup("20")) 16 | print(code_sys.get_ancestors("20")) 17 | print(code_sys.get_descendants("20")) 18 | -------------------------------------------------------------------------------- /pyhealth/medcode/codes/icd10cm.py: -------------------------------------------------------------------------------- 1 | from pyhealth.medcode.inner_map import InnerMap 2 | 3 | 4 | # TODO: add convert 5 | 6 | 7 | class ICD10CM(InnerMap): 8 | """10-th International Classification of Diseases, Clinical Modification.""" 9 | 10 | def __init__(self, **kwargs): 11 | super(ICD10CM, self).__init__(vocabulary="ICD10CM", **kwargs) 12 | 13 | @staticmethod 14 | def standardize(code: str): 15 | """Standardizes ICD10CM code.""" 16 | if "." in code: 17 | return code 18 | if len(code) <= 3: 19 | return code 20 | return code[:3] + "." + code[3:] 21 | 22 | 23 | if __name__ == "__main__": 24 | code_sys = ICD10CM(refresh_cache=True) 25 | code_sys.stat() 26 | print("A00.0" in code_sys) 27 | print(code_sys.lookup("D50.0")) 28 | print(code_sys.get_ancestors("D50.0")) 29 | print(code_sys.get_descendants("D50")) 30 | -------------------------------------------------------------------------------- /pyhealth/medcode/codes/icd10proc.py: -------------------------------------------------------------------------------- 1 | from pyhealth.medcode.inner_map import InnerMap 2 | 3 | 4 | # TODO: add convert 5 | 6 | 7 | class ICD10PROC(InnerMap): 8 | """10-th International Classification of Diseases, Procedure.""" 9 | 10 | def __init__(self, **kwargs): 11 | super(ICD10PROC, self).__init__(vocabulary="ICD10PROC", **kwargs) 12 | 13 | 14 | if __name__ == "__main__": 15 | code_sys = ICD10PROC(refresh_cache=True) 16 | code_sys.stat() 17 | print("0LBG0ZZ" in code_sys) 18 | print(code_sys.lookup("0LBG0ZZ")) 19 | print(code_sys.get_ancestors("0LBG0ZZ")) 20 | print(code_sys.get_descendants("0LBG0")) 21 | -------------------------------------------------------------------------------- /pyhealth/medcode/codes/icd9cm.py: -------------------------------------------------------------------------------- 1 | from pyhealth.medcode.inner_map import InnerMap 2 | 3 | 4 | # TODO: add convert 5 | 6 | 7 | class ICD9CM(InnerMap): 8 | """9-th International Classification of Diseases, Clinical Modification.""" 9 | 10 | def __init__(self, **kwargs): 11 | super(ICD9CM, self).__init__(vocabulary="ICD9CM", **kwargs) 12 | 13 | @staticmethod 14 | def standardize(code: str): 15 | """Standardizes ICD9CM code.""" 16 | if "." in code: 17 | return code 18 | if code.startswith("E"): 19 | if len(code) <= 4: 20 | return code 21 | return code[:4] + "." + code[4:] 22 | else: 23 | if len(code) <= 3: 24 | return code 25 | return code[:3] + "." 
+ code[3:] 26 | 27 | 28 | if __name__ == "__main__": 29 | code_sys = ICD9CM(refresh_cache=True) 30 | code_sys.stat() 31 | print("821.01" in code_sys) 32 | print(code_sys.lookup("82101")) 33 | print(code_sys.get_ancestors("821.01")) 34 | print(code_sys.get_descendants("821")) 35 | -------------------------------------------------------------------------------- /pyhealth/medcode/codes/icd9proc.py: -------------------------------------------------------------------------------- 1 | from pyhealth.medcode.inner_map import InnerMap 2 | 3 | 4 | # TODO: add convert 5 | 6 | 7 | class ICD9PROC(InnerMap): 8 | """9-th International Classification of Diseases, Procedure.""" 9 | 10 | def __init__(self, **kwargs): 11 | super(ICD9PROC, self).__init__(vocabulary="ICD9PROC", **kwargs) 12 | 13 | @staticmethod 14 | def standardize(code: str): 15 | """Standardizes ICD9PROC code.""" 16 | if "." in code: 17 | return code 18 | if len(code) <= 2: 19 | return code 20 | return code[:2] + "." + code[2:] 21 | 22 | 23 | if __name__ == "__main__": 24 | code_sys = ICD9PROC(refresh_cache=True) 25 | code_sys.stat() 26 | print("81.01" in code_sys) 27 | print(code_sys.lookup("01.31")) 28 | print(code_sys.get_ancestors("01.31")) 29 | print(code_sys.get_descendants("01")) 30 | -------------------------------------------------------------------------------- /pyhealth/medcode/codes/ndc.py: -------------------------------------------------------------------------------- 1 | from pyhealth.medcode.inner_map import InnerMap 2 | 3 | 4 | # TODO: add standardize for different formats of NDC codes 5 | 6 | 7 | class NDC(InnerMap): 8 | """National Drug Code.""" 9 | 10 | def __init__(self, **kwargs): 11 | super(NDC, self).__init__(vocabulary="NDC", **kwargs) 12 | 13 | 14 | if __name__ == "__main__": 15 | code_sys = NDC(refresh_cache=True) 16 | code_sys.stat() 17 | print("00527051210" in code_sys) 18 | print(code_sys.lookup("00527051210")) 19 | print(code_sys.get_ancestors("00527051210")) 20 | print(code_sys.get_descendants("00527051210")) 21 | -------------------------------------------------------------------------------- /pyhealth/medcode/codes/rxnorm.py: -------------------------------------------------------------------------------- 1 | from pyhealth.medcode.inner_map import InnerMap 2 | 3 | 4 | class RxNorm(InnerMap): 5 | """RxNorm.""" 6 | 7 | def __init__(self, **kwargs): 8 | super(RxNorm, self).__init__(vocabulary="RxNorm", **kwargs) 9 | 10 | 11 | if __name__ == "__main__": 12 | code_sys = RxNorm(refresh_cache=True) 13 | code_sys.stat() 14 | print("21914" in code_sys) 15 | print(code_sys.graph.nodes["21914"]) 16 | print(code_sys.get_ancestors("21914")) 17 | print(code_sys.get_descendants("21914")) 18 | -------------------------------------------------------------------------------- /pyhealth/medcode/codes/umls.py: -------------------------------------------------------------------------------- 1 | from pyhealth.medcode.inner_map import InnerMap 2 | 3 | 4 | class UMLS(InnerMap): 5 | """UMLS.""" 6 | 7 | def __init__(self, **kwargs): 8 | super(UMLS, self).__init__(vocabulary="UMLS", **kwargs) 9 | 10 | 11 | if __name__ == "__main__": 12 | code_sys = UMLS(refresh_cache=True) 13 | code_sys.stat() 14 | print(code_sys.lookup("C0000768")) 15 | -------------------------------------------------------------------------------- /pyhealth/medcode/pretrained_embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from .kg_emb import * 2 | from .lm_emb import * 
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import TransE
2 | from .models import ComplEx
3 | from .models import RotatE
4 | from .models import DistMult
5 | from .models import kg_base
6 | from .datasets import *
7 | from .tasks import *
8 | 
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .sample_kg_dataset import SampleKGDataset
2 | from .base_kg_dataset import BaseKGDataset
3 | from .umls import UMLSDataset
4 | from .splitter import split
--------------------------------------------------------------------------------
/pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/splitter.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import Optional, Tuple, Union, List
3 | 
4 | import numpy as np
5 | import torch
6 | 
7 | from pyhealth.datasets import SampleBaseDataset
8 | 
9 | 
10 | def split(
11 |     dataset: SampleBaseDataset,
12 |     ratios: Union[Tuple[float, float, float], List[float]],
13 |     seed: Optional[int] = None,
14 | ):
15 |     """Splits the dataset by its outermost indexed items.
16 | 
17 |     Args:
18 |         dataset: a `SampleBaseDataset` object
19 |         ratios: a list/tuple of ratios for train / val / test
20 |         seed: random seed for shuffling the dataset
21 | 
22 |     Returns:
23 |         train_dataset, val_dataset, test_dataset: three lists of task sample
24 |             dicts. Training samples carry `'train': True` plus the task-specific
25 |             hyperparameters; validation and test samples carry `'train': False`.
26 | 
27 |     Note:
28 |         The split is by index over the task samples; `torch.utils.data.Subset` is used internally, but plain lists are returned.
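29 | 
30 |     Example:
31 |         A sketch, assuming `dataset` is a task-processed KG sample dataset (e.g., `SampleKGDataset`):
32 | 
33 |         >>> train_samples, val_samples, test_samples = split(dataset, [0.8, 0.1, 0.1], seed=42)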
29 | """ 30 | 31 | if seed is not None: 32 | np.random.seed(seed) 33 | assert sum(ratios) == 1.0, "ratios must sum to 1.0" 34 | index = np.arange(len(dataset)) 35 | np.random.shuffle(index) 36 | train_index = index[: int(len(dataset) * ratios[0])] 37 | val_index = index[ 38 | int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1])) 39 | ] 40 | test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :] 41 | train_dataset = torch.utils.data.Subset(dataset, train_index) 42 | train_dataset = [{**train_dataset[i], **{'train': True, 'hyperparameters': dataset.task_spec_param}} for i in range(len(train_dataset))] 43 | 44 | val_dataset = torch.utils.data.Subset(dataset, val_index) 45 | val_dataset = [{**val_dataset[i], 'train': False} for i in range(len(val_dataset))] 46 | test_dataset = torch.utils.data.Subset(dataset, test_index) 47 | test_dataset = [{**test_dataset[i], 'train': False} for i in range(len(test_dataset))] 48 | return train_dataset, val_dataset, test_dataset -------------------------------------------------------------------------------- /pyhealth/medcode/pretrained_embeddings/kg_emb/datasets/umls.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from tqdm import tqdm 4 | import numpy as np 5 | import pandas as pd 6 | from pandarallel import pandarallel 7 | from pyhealth.medcode.pretrained_embeddings.kg_emb.datasets import BaseKGDataset 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | class UMLSDataset(BaseKGDataset): 12 | """Base UMLS knowleddge graph dataset 13 | 14 | Dataset is available at https://www.nlm.nih.gov/research/umls/index.html 15 | 16 | Args: 17 | dataset_name: name of the dataset. 18 | root: root directory of the raw data (should contain many csv files). 19 | dev: whether to enable dev mode (only use a small subset of the data). 20 | Default is False. 21 | refresh_cache: whether to refresh the cache; if true, the dataset will 22 | be processed from scratch and the cache will be updated. Default is False. 
23 | 24 | """ 25 | 26 | def raw_graph_process(self): 27 | pandarallel.initialize(progress_bar=False) 28 | if not self.dev: 29 | self.graph_path = os.path.join(self.root, "graph.txt") 30 | else: 31 | self.graph_path = os.path.join(self.root, "graph_filtered.txt") 32 | 33 | if os.path.exists(self.graph_path): 34 | logger.debug("UMLS knowledge graph found; loading it") 35 | else: 36 | logger.debug("UMLS knowledge graph not found at %s", self.graph_path) 37 | 38 | print("Loading UMLS knowledge graph...") 39 | graph_df = pd.read_csv( 40 | self.graph_path, 41 | sep='\t', 42 | names=['e1', 'r', 'e2'] 43 | ) 44 | 45 | print("Processing UMLS knowledge graph...") 46 | entity_list = pd.unique(graph_df[['e1', 'e2']].values.ravel('K')) 47 | relation_list = pd.unique(graph_df['r'].values) 48 | 49 | self.entity2id = {val: i for i, val in enumerate(entity_list)} 50 | self.relation2id = {val: i for i, val in enumerate(relation_list)} 51 | self.entity_num = len(self.entity2id) 52 | self.relation_num = len(self.relation2id) 53 | 54 | print("Building UMLS knowledge graph...") 55 | self.triples = [(self.entity2id[e1], self.relation2id[r], self.entity2id[e2]) 56 | for e1, r, e2 in tqdm(zip(graph_df['e1'], graph_df['r'], graph_df['e2']), total=graph_df.shape[0])] 57 | 58 | 59 | return 60 | 61 | 62 | if __name__ == "__main__": 63 | dataset = UMLSDataset( 64 | root="https://storage.googleapis.com/pyhealth/umls/", 65 | dev=True, 66 | refresh_cache=True, 67 | ) 68 | dataset.stat() 69 | dataset.info() -------------------------------------------------------------------------------- /pyhealth/medcode/pretrained_embeddings/kg_emb/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .kg_base import KGEBaseModel 2 | from .transe import TransE 3 | from .rotate import RotatE 4 | from .distmult import DistMult 5 | from .complex import ComplEx -------------------------------------------------------------------------------- /pyhealth/medcode/pretrained_embeddings/kg_emb/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .link_prediction import link_prediction_fn 2 | -------------------------------------------------------------------------------- /pyhealth/medcode/pretrained_embeddings/kg_emb/tasks/link_prediction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm import tqdm 4 | from typing import Tuple, List 5 | from collections import defaultdict 6 | 7 | def link_prediction_fn( 8 | triples: List[Tuple] 9 | ): 10 | 11 | """Processes a triple list for the link prediction task. 12 | 13 | Link prediction covers two query types: 14 | Tail Prediction: predict the tail entity t given a triple query (h, r, ?); 15 | Head Prediction: predict the head entity h given a triple query (?, r, t). 16 | 17 | Args: 18 | triples: a list of triples (indexed) from the knowledge graph 19 | 20 | Returns: 21 | samples: a list of samples 22 | """ 23 | samples = [] 24 | count = count_frequency(triples) 25 | triple_set = set(triples) 26 | gt_head, gt_tail = ground_truth_for_query(triple_set) 27 | 28 | for positive_sample in tqdm(triples): 29 | head, relation, tail = positive_sample 30 | subsampling_weight = torch.sqrt(1 / torch.Tensor([count[(head, relation)] + count[(tail, -relation-1)]])) 31 | gt_h = gt_head[(relation, tail)] 32 | gt_t = gt_tail[(head, relation)] 33 | 34 | tmp = { 35 | "triple": positive_sample, 36 | "ground_truth_head": gt_h, 37 | "ground_truth_tail": gt_t, 38 |
"subsampling_weight": subsampling_weight 39 | } 40 | 41 | samples.append(tmp) 42 | 43 | return samples 44 | 45 | 46 | def ground_truth_for_query(triple_set): 47 | """ 48 | Search ground truth of either query (h, r, ?) or (?, r, t) in the dataset 49 | """ 50 | gt_head = defaultdict(list) 51 | gt_tail = defaultdict(list) 52 | 53 | for triple in triple_set: 54 | head, relation, tail = triple 55 | gt_head[(relation, tail)].append(head) 56 | gt_tail[(head, relation)].append(tail) 57 | 58 | return gt_head, gt_tail 59 | 60 | 61 | def count_frequency(triples, start=4): 62 | ''' 63 | Get frequency of a partial triple like (head, relation) or (relation, tail) 64 | The frequency will be used for subsampling like word2vec 65 | ''' 66 | count = {} 67 | for head, relation, tail in triples: 68 | if (head, relation) not in count: 69 | count[(head, relation)] = start 70 | else: 71 | count[(head, relation)] += 1 72 | 73 | if (tail, -relation-1) not in count: 74 | count[(tail, -relation-1)] = start 75 | else: 76 | count[(tail, -relation-1)] += 1 77 | return count 78 | 79 | -------------------------------------------------------------------------------- /pyhealth/medcode/pretrained_embeddings/lm_emb/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/pyhealth/medcode/pretrained_embeddings/lm_emb/__init__.py -------------------------------------------------------------------------------- /pyhealth/medcode/pretrained_embeddings/lm_emb/huggingface_retriever.py: -------------------------------------------------------------------------------- 1 | def embedding_retrieve(model, tokenizer, phrase): 2 | # Encode the sentence 3 | inputs = tokenizer(phrase, return_tensors='pt') 4 | 5 | # Get the model's output 6 | outputs = model(**inputs) 7 | 8 | # Extract the embeddings 9 | embedding = outputs.last_hidden_state.mean(dim=1) 10 | 11 | # Now, `embedding` is a tensor that contains the embedding for your sentence. 
12 | # You can convert it to a numpy array if needed: 13 | embedding = embedding.detach().numpy().tolist()[0] 14 | 15 | return embedding 16 | -------------------------------------------------------------------------------- /pyhealth/medcode/pretrained_embeddings/lm_emb/openai_retriever.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import os 4 | 5 | key = os.environ['OPENAI_API_KEY'] 6 | 7 | def embedding_retrieve(term): 8 | # Set up the API endpoint URL and request headers 9 | url = "https://api.openai.com/v1/embeddings" 10 | headers = { 11 | "Content-Type": "application/json", 12 | "Authorization": f"Bearer {key}" 13 | } 14 | 15 | # Set up the request payload with the text string to embed and the model to use 16 | payload = { 17 | "input": term, 18 | "model": "text-embedding-ada-002" 19 | } 20 | 21 | # Send the request and fail fast on HTTP errors 22 | response = requests.post(url, headers=headers, data=json.dumps(payload)) 23 | response.raise_for_status() 24 | # Extract the text embeddings from the response JSON 25 | embedding = response.json()["data"][0]['embedding'] 26 | 27 | return embedding -------------------------------------------------------------------------------- /pyhealth/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .binary import binary_metrics_fn 2 | from .drug_recommendation import ddi_rate_score 3 | from .multiclass import multiclass_metrics_fn 4 | from .multilabel import multilabel_metrics_fn 5 | 6 | # from .fairness import fairness_metrics_fn 7 | from .ranking import ranking_metrics_fn 8 | from .regression import regression_metrics_fn 9 | -------------------------------------------------------------------------------- /pyhealth/metrics/drug_recommendation.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | 5 | 6 | # TODO: this metric is very ad-hoc and needs to be improved 7 | 8 | 9 | def ddi_rate_score(medications: List[np.ndarray], ddi_matrix: np.ndarray) -> float: 10 | """DDI rate score. 11 | 12 | Args: 13 | medications: list of medications for each patient, where each medication 14 | is represented by the corresponding index in the ddi matrix. 15 | ddi_matrix: array-like of shape (n_classes, n_classes). 16 | 17 | Returns: 18 | result: DDI rate score. 19 | """ 20 | all_cnt = 0 21 | ddi_cnt = 0 22 | for sample in medications: 23 | for i, med_i in enumerate(sample): 24 | for j, med_j in enumerate(sample): 25 | if j <= i: continue 26 | all_cnt += 1 27 | if ddi_matrix[med_i, med_j] == 1 or ddi_matrix[med_j, med_i] == 1: 28 | ddi_cnt += 1 29 | if all_cnt == 0: 30 | return 0 31 | return ddi_cnt / all_cnt 32 | -------------------------------------------------------------------------------- /pyhealth/metrics/fairness.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import numpy as np 4 | 5 | from pyhealth.metrics.fairness_utils import disparate_impact, statistical_parity_difference 6 | 7 | def fairness_metrics_fn( 8 | y_true: np.ndarray, 9 | y_prob: np.ndarray, 10 | sensitive_attributes: np.ndarray, 11 | favorable_outcome: int = 1, 12 | metrics: Optional[List[str]] = None, 13 | threshold: float = 0.5, 14 | ) -> Dict[str, float]: 15 | """Computes fairness metrics for binary classification. 16 | 17 | User can specify which metrics to compute by passing a list of metric names.
18 | The accepted metric names are: 19 | - disparate_impact: ratio of favorable outcome rates, protected vs. unprotected 20 | - statistical_parity_difference: difference of favorable outcome rates, protected minus unprotected 21 | 22 | If no metrics are specified, disparate_impact and statistical_parity_difference are computed by default. 23 | 24 | Args: 25 | y_true: True target values of shape (n_samples,). 26 | y_prob: Predicted probabilities of shape (n_samples,). 27 | sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group. 28 | favorable_outcome: Label value which is considered favorable (i.e. "positive"). 29 | metrics: List of metrics to compute. Default is ["disparate_impact", "statistical_parity_difference"]. 30 | threshold: Threshold for binary classification. Default is 0.5. 31 | 32 | Returns: 33 | Dictionary of metrics whose keys are the metric names and values are 34 | the metric values. 35 | """ 36 | if metrics is None: 37 | metrics = ["disparate_impact", "statistical_parity_difference"] 38 | 39 | y_pred = y_prob.copy() 40 | y_pred[y_pred >= threshold] = 1 41 | y_pred[y_pred < threshold] = 0 42 | 43 | output = {} 44 | for metric in metrics: 45 | if metric == "disparate_impact": 46 | output[metric] = disparate_impact(sensitive_attributes, y_pred, favorable_outcome) 47 | elif metric == "statistical_parity_difference": 48 | output[metric] = statistical_parity_difference(sensitive_attributes, y_pred, favorable_outcome) 49 | else: 50 | raise ValueError(f"Unknown metric for fairness: {metric}") 51 | return output 52 | 53 | -------------------------------------------------------------------------------- /pyhealth/metrics/fairness_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .group import disparate_impact, statistical_parity_difference 2 | from .utils import sensitive_attributes_from_patient_ids -------------------------------------------------------------------------------- /pyhealth/metrics/fairness_utils/group.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | """ 4 | Notation: 5 | - Protected group: P 6 | - Unprotected group: U 7 | """ 8 | 9 | def disparate_impact(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1, allow_zero_division = False, epsilon: float = 1e-8) -> float: 10 | """ 11 | Computes the disparate impact between the protected and unprotected group. 12 | 13 | disparate_impact = P(y_pred = favorable_outcome | P) / P(y_pred = favorable_outcome | U) 14 | 15 | Args: 16 | sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group. 17 | y_pred: Predicted target values of shape (n_samples,). 18 | favorable_outcome: Label value which is considered favorable (i.e. "positive"). 19 | allow_zero_division: If True, use epsilon instead of 0 in the denominator if the denominator is 0. Otherwise, raise a ValueError. 20 | 21 | Returns: 22 | The disparate impact between the protected and unprotected group. 23 | """ 24 | 25 | p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0]) 26 | p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1]) 27 | 28 | if p_fav_unpr == 0: 29 | if allow_zero_division: 30 | p_fav_unpr = epsilon 31 | else: 32 | raise ValueError("Unprotected group has no instances with a favorable outcome.
Disparate impact is undefined.") 33 | 34 | disparate_impact_value = p_fav_prot / p_fav_unpr 35 | 36 | return disparate_impact_value 37 | 38 | def statistical_parity_difference(sensitive_attributes: np.ndarray, y_pred: np.ndarray, favorable_outcome: int = 1) -> float: 39 | """ 40 | Computes the statistical parity difference between the protected and unprotected group. 41 | 42 | statistical_parity_difference = P(y_pred = favorable_outcome | P) - P(y_pred = favorable_outcome | U) 43 | Args: 44 | sensitive_attributes: Sensitive attributes of shape (n_samples,) where 1 is the protected group and 0 is the unprotected group. 45 | y_pred: Predicted target values of shape (n_samples,). 46 | favorable_outcome: Label value which is considered favorable (i.e. "positive"). 47 | Returns: 48 | The statistical parity difference between the protected and unprotected group. 49 | """ 50 | 51 | p_fav_unpr = np.sum(y_pred[sensitive_attributes == 0] == favorable_outcome) / len(y_pred[sensitive_attributes == 0]) 52 | p_fav_prot = np.sum(y_pred[sensitive_attributes == 1] == favorable_outcome) / len(y_pred[sensitive_attributes == 1]) 53 | 54 | statistical_parity_difference_value = p_fav_prot - p_fav_unpr 55 | 56 | return statistical_parity_difference_value 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /pyhealth/metrics/fairness_utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import List 3 | import numpy as np 4 | 5 | from pyhealth.datasets import BaseEHRDataset 6 | 7 | def sensitive_attributes_from_patient_ids(dataset: BaseEHRDataset, 8 | patient_ids: List[str], 9 | sensitive_attribute: str, 10 | protected_group: str) -> np.ndarray: 11 | """ 12 | Returns the desired sensitive attribute array from patient_ids. 13 | 14 | Args: 15 | dataset: Dataset object. 16 | patient_ids: List of patient IDs. 17 | sensitive_attribute: Sensitive attribute to extract. 18 | protected_group: Value of the protected group. 19 | 20 | Returns: 21 | Sensitive attribute array of shape (n_samples,). 22 | """ 23 | 24 | sensitive_attribute_array = np.zeros(len(patient_ids)) 25 | for idx, patient_id in enumerate(patient_ids): 26 | sensitive_attribute_value = getattr(dataset.patients[patient_id], sensitive_attribute) 27 | if sensitive_attribute_value == protected_group: 28 | sensitive_attribute_array[idx] = 1 29 | return sensitive_attribute_array 30 | 31 | -------------------------------------------------------------------------------- /pyhealth/metrics/regression.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import numpy as np 4 | import sklearn.metrics as sklearn_metrics 5 | 6 | 7 | def regression_metrics_fn( 8 | x: np.ndarray, 9 | x_rec: np.ndarray, 10 | metrics: Optional[List[str]] = None, 11 | ) -> Dict[str, float]: 12 | """Computes metrics for regression. 13 | 14 | User can specify which metrics to compute by passing a list of metric names. 15 | The accepted metric names are: 16 | - kl_divergence: KL divergence 17 | - mse: mean squared error 18 | - mae: mean absolute error 19 | If no metrics are specified, kl_divergence, mse, and mae are computed by default. 20 | 21 | mse and mae are computed via sklearn.metrics functions; kl_divergence is 22 | computed directly. For more information, please refer to the documentation 23 | of the corresponding sklearn.metrics functions. 24 | 25 | Args: 26 | x: True target data sample.
27 | x_rec: reconstructed data sample. 28 | metrics: List of metrics to compute. Default is ["kl_divergence", "mse", "mae"]. 29 | 30 | Returns: 31 | Dictionary of metrics whose keys are the metric names and values are 32 | the metric values. 33 | 34 | Examples: 35 | >>> from pyhealth.metrics import regression_metrics_fn 36 | >>> x = np.array([0.1, 0.4, 0.35, 0.8]) 37 | >>> x_rec = np.array([0.1, 0.4, 0.35, 0.8]) 38 | >>> regression_metrics_fn(x, x_rec, metrics=["mse"]) 39 | {'mse': 0.0} 40 | """ 41 | if metrics is None: 42 | metrics = ["kl_divergence", "mse", "mae"] 43 | 44 | x = x.flatten() 45 | x_rec = x_rec.flatten() 46 | 47 | if x.shape != x_rec.shape: 48 | raise ValueError("x and x_rec must have the same shape.") 49 | 50 | output = {} 51 | for metric in metrics: 52 | if metric == "kl_divergence":  # work on copies so later metrics see the raw inputs 53 | p = np.clip(x, 1e-6, None) 54 | q = np.clip(x_rec, 1e-6, None) 55 | p = p / np.sum(p) 56 | q = q / np.sum(q) 57 | kl_divergence = np.sum(q * np.log(q / p)) 58 | output["kl_divergence"] = kl_divergence 59 | elif metric == "mse": 60 | mse = sklearn_metrics.mean_squared_error(x, x_rec) 61 | output["mse"] = mse 62 | elif metric == "mae": 63 | mae = sklearn_metrics.mean_absolute_error(x, x_rec) 64 | output["mae"] = mae 65 | else: 66 | raise ValueError(f"Unknown metric for regression task: {metric}") 67 | return output 68 | 69 | 70 | if __name__ == "__main__": 71 | x = np.random.random(size=10000) 72 | x_rec = np.random.random(size=10000) 73 | print(regression_metrics_fn(x, x_rec)) 74 | -------------------------------------------------------------------------------- /pyhealth/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .adacare import AdaCare, AdaCareLayer 2 | from .agent import Agent, AgentLayer 3 | from .base_model import BaseModel 4 | from .cnn import CNN, CNNLayer 5 | from .concare import ConCare, ConCareLayer 6 | from .contrawr import ContraWR, ResBlock2D 7 | from .deepr import Deepr, DeeprLayer 8 | from .embedding import EmbeddingModel 9 | from .gamenet import GAMENet, GAMENetLayer 10 | from .gan import GAN 11 | from .gnn import GAT, GCN 12 | from .graph_torchvision_model import Graph_TorchvisionModel 13 | from .grasp import GRASP, GRASPLayer 14 | from .medlink import MedLink 15 | from .micron import MICRON, MICRONLayer 16 | from .mlp import MLP 17 | from .molerec import MoleRec, MoleRecLayer 18 | from .retain import RETAIN, RETAINLayer 19 | from .rnn import RNN, RNNLayer 20 | from .safedrug import SafeDrug, SafeDrugLayer 21 | from .sparcnet import DenseBlock, DenseLayer, SparcNet, TransitionLayer 22 | from .stagenet import StageNet, StageNetLayer 23 | from .tcn import TCN, TCNLayer 24 | from .torchvision_model import TorchvisionModel 25 | from .transformer import Transformer, TransformerLayer 26 | from .transformers_model import TransformersModel 27 | from .vae import VAE 28 | -------------------------------------------------------------------------------- /pyhealth/models/embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..datasets import SampleDataset 7 | from ..processors import SequenceProcessor, TimeseriesProcessor 8 | from .base_model import BaseModel 9 | 10 | 11 | class EmbeddingModel(BaseModel): 12 | """ 13 | EmbeddingModel is responsible for creating embedding layers for different types of input data.
14 | 15 | Attributes: 16 | dataset (SampleDataset): The dataset containing input processors. 17 | embedding_layers (nn.ModuleDict): A dictionary of embedding layers for each input field. 18 | """ 19 | 20 | def __init__(self, dataset: SampleDataset, embedding_dim: int = 128): 21 | """ 22 | Initializes the EmbeddingModel with the given dataset and embedding dimension. 23 | 24 | Args: 25 | dataset (SampleDataset): The dataset containing input processors. 26 | embedding_dim (int): The dimension of the embedding space. Default is 128. 27 | """ 28 | super().__init__(dataset) 29 | self.embedding_layers = nn.ModuleDict() 30 | for field_name, processor in self.dataset.input_processors.items(): 31 | if isinstance(processor, SequenceProcessor): 32 | vocab_size = len(processor.code_vocab) 33 | self.embedding_layers[field_name] = nn.Embedding( 34 | num_embeddings=vocab_size, 35 | embedding_dim=embedding_dim, 36 | padding_idx=0 37 | ) 38 | elif isinstance(processor, TimeseriesProcessor): 39 | self.embedding_layers[field_name] = nn.Linear( 40 | in_features=processor.size, 41 | out_features=embedding_dim 42 | ) 43 | 44 | def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 45 | """ 46 | Forward pass to compute embeddings for the input data. 47 | 48 | Args: 49 | inputs (Dict[str, torch.Tensor]): A dictionary of input tensors. 50 | 51 | Returns: 52 | Dict[str, torch.Tensor]: A dictionary of embedded tensors. 53 | """ 54 | embedded = {} 55 | for field_name, tensor in inputs.items(): 56 | if field_name in self.embedding_layers: 57 | tensor = tensor.to(self.device) 58 | embedded[field_name] = self.embedding_layers[field_name](tensor) 59 | else: 60 | embedded[field_name] = tensor # passthrough for continuous features 61 | return embedded 62 | 63 | def __repr__(self) -> str: 64 | """ 65 | Returns a string representation of the EmbeddingModel. 66 | 67 | Returns: 68 | str: A string representation of the model. 69 | """ 70 | return f"EmbeddingModel(embedding_layers={self.embedding_layers})" 71 | -------------------------------------------------------------------------------- /pyhealth/models/medlink/__init__.py: -------------------------------------------------------------------------------- 1 | from .bm25 import BM25Okapi 2 | from .model import MedLink 3 | from .utils import convert_to_ir_format 4 | from .utils import filter_by_candidates 5 | from .utils import generate_candidates 6 | from .utils import get_bm25_hard_negatives 7 | from .utils import get_eval_dataloader 8 | from .utils import get_train_dataloader 9 | from .utils import tvt_split 10 | -------------------------------------------------------------------------------- /pyhealth/models/transformers_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | from transformers import AutoModel, AutoTokenizer 6 | 7 | from ..datasets import SampleDataset 8 | from .base_model import BaseModel 9 | 10 | 11 | class TransformersModel(BaseModel): 12 | """Transformers class for Huggingface models. 
13 | """ 14 | 15 | def __init__( 16 | self, 17 | dataset: SampleDataset, 18 | model_name: str, 19 | ): 20 | super(TransformersModel, self).__init__( 21 | dataset=dataset, 22 | ) 23 | self.model_name = model_name 24 | self.model = AutoModel.from_pretrained(model_name) 25 | assert len(self.feature_keys) == 1, "Only one feature key is supported if Transformers is initialized" 26 | self.feature_key = self.feature_keys[0] 27 | assert len(self.label_keys) == 1, "Only one label key is supported if RNN is initialized" 28 | self.label_key = self.label_keys[0] 29 | self.mode = self.dataset.output_schema[self.label_key] 30 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 31 | output_size = self.get_output_size() 32 | hidden_dim = self.model.config.hidden_size 33 | self.fc = nn.Linear(hidden_dim, output_size) 34 | 35 | def forward(self, **kwargs) -> Dict[str, torch.Tensor]: 36 | """Forward propagation.""" 37 | # concat the info within one batch (batch, channel, length) 38 | x = kwargs[self.feature_key] 39 | # TODO: max_length should be a parameter 40 | x = self.tokenizer( 41 | x, return_tensors="pt", padding=True, truncation=True, max_length=256 42 | ) 43 | x = x.to(self.device) 44 | # TODO: should not use pooler_output, but use the last hidden state 45 | embeddings = self.model(**x).pooler_output 46 | logits = self.fc(embeddings) 47 | y_true = kwargs[self.label_key].to(self.device) 48 | loss = self.get_loss_function()(logits, y_true) 49 | y_prob = self.prepare_y_prob(logits) 50 | return { 51 | "loss": loss, 52 | "y_prob": y_prob, 53 | "y_true": y_true, 54 | } 55 | 56 | 57 | if __name__ == "__main__": 58 | from pyhealth.datasets import MedicalTranscriptionsDataset, get_dataloader 59 | 60 | base_dataset = MedicalTranscriptionsDataset( 61 | root="/srv/local/data/zw12/raw_data/MedicalTranscriptions" 62 | ) 63 | 64 | sample_dataset = base_dataset.set_task() 65 | 66 | train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True) 67 | 68 | model = TransformersModel( 69 | dataset=sample_dataset, 70 | feature_keys=["transcription"], 71 | label_key="label", 72 | mode="multiclass", 73 | model_name="emilyalsentzer/Bio_ClinicalBERT", 74 | ) 75 | 76 | # data batch 77 | data_batch = next(iter(train_loader)) 78 | 79 | # try the model 80 | ret = model(**data_batch) 81 | print(ret) 82 | 83 | # try loss backward 84 | ret["loss"].backward() -------------------------------------------------------------------------------- /pyhealth/models/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | 6 | def batch_to_multihot(label: List[List[int]], num_labels: int) -> torch.tensor: 7 | """Converts label to multihot format. 8 | 9 | Args: 10 | label: [batch size, *] 11 | num_labels: total number of labels 12 | 13 | Returns: 14 | multihot: [batch size, num_labels] 15 | """ 16 | multihot = torch.zeros((len(label), num_labels)) 17 | for i, l in enumerate(label): 18 | multihot[i, l] = 1 19 | return multihot 20 | 21 | 22 | def get_last_visit(hidden_states, mask): 23 | """Gets the last visit from the sequence model. 
24 | 25 | Args: 26 | hidden_states: [batch size, seq len, hidden_size] 27 | mask: [batch size, seq len] 28 | 29 | Returns: 30 | last_visit: [batch size, hidden_size] 31 | """ 32 | if mask is None: 33 | return hidden_states[:, -1, :] 34 | else: 35 | mask = mask.long() 36 | last_visit = torch.sum(mask, 1) - 1 37 | last_visit = last_visit.unsqueeze(-1) 38 | last_visit = last_visit.expand(-1, hidden_states.shape[1] * hidden_states.shape[2]) 39 | last_visit = torch.reshape(last_visit, hidden_states.shape) 40 | last_hidden_states = torch.gather(hidden_states, 1, last_visit) 41 | last_hidden_state = last_hidden_states[:, 0, :] 42 | return last_hidden_state 43 | -------------------------------------------------------------------------------- /pyhealth/processors/__init__.py: -------------------------------------------------------------------------------- 1 | PROCESSOR_REGISTRY = {} 2 | 3 | 4 | def register_processor(name: str): 5 | def decorator(cls): 6 | if name in PROCESSOR_REGISTRY: 7 | raise ValueError(f"Processor '{name}' already registered.") 8 | PROCESSOR_REGISTRY[name] = cls 9 | return cls 10 | return decorator 11 | 12 | 13 | def get_processor(name: str): 14 | if name not in PROCESSOR_REGISTRY: 15 | raise ValueError(f"Unknown processor: {name}") 16 | return PROCESSOR_REGISTRY[name] 17 | 18 | 19 | # Import all processors so they register themselves 20 | from .image_processor import ImageProcessor 21 | from .label_processor import ( 22 | BinaryLabelProcessor, 23 | MultiClassLabelProcessor, 24 | MultiLabelProcessor, 25 | RegressionLabelProcessor, 26 | ) 27 | from .raw_processor import RawProcessor 28 | from .sequence_processor import SequenceProcessor 29 | from .signal_processor import SignalProcessor 30 | from .text_processor import TextProcessor 31 | from .timeseries_processor import TimeseriesProcessor 32 | 33 | # Expose public API 34 | __all__ = [ 35 | "get_processor", 36 | "ImageProcessor", 37 | "SequenceProcessor", 38 | "TimeseriesProcessor", 39 | "SignalProcessor", 40 | "BinaryLabelProcessor", 41 | "MultiClassLabelProcessor", 42 | "MultiLabelProcessor", 43 | "RegressionLabelProcessor", 44 | "RawProcessor", 45 | "TextProcessor", 46 | ] 47 | -------------------------------------------------------------------------------- /pyhealth/processors/base_processor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, List, Optional 3 | 4 | 5 | class Processor(ABC): 6 | """ 7 | Abstract base processor class. 8 | 9 | Defines optional hooks for saving/loading state to/from disk. 10 | """ 11 | 12 | def save(self, path: str) -> None: 13 | """Optional: Save processor state to disk. 14 | 15 | Args: 16 | path: File path to save processor state. 17 | """ 18 | pass 19 | 20 | def load(self, path: str) -> None: 21 | """Optional: Load processor state from disk. 22 | 23 | Args: 24 | path: File path to load processor state from. 25 | """ 26 | pass 27 | 28 | 29 | class FeatureProcessor(Processor): 30 | """ 31 | Processor for individual fields (features). 32 | 33 | Example: Tokenization, image loading, normalization. 34 | """ 35 | 36 | def fit(self, samples: List[Dict[str, Any]], field: str) -> None: 37 | """Fit the processor to the samples. 38 | 39 | Args: 40 | samples: List of sample dictionaries. 41 | """ 42 | pass 43 | 44 | @abstractmethod 45 | def process(self, value: Any) -> Any: 46 | """Process an individual field value. 47 | 48 | Args: 49 | value: Raw field value. 
50 | 51 | Returns: 52 | Processed value. 53 | """ 54 | pass 55 | 56 | 57 | class SampleProcessor(Processor): 58 | """ 59 | Processor for individual samples (dict of fields). 60 | 61 | Example: Imputation, sample-level augmentation, label smoothing. 62 | """ 63 | 64 | @abstractmethod 65 | def process(self, sample: Dict[str, Any]) -> Dict[str, Any]: 66 | """Process a single sample dictionary. 67 | 68 | Args: 69 | sample: Sample dictionary. 70 | 71 | Returns: 72 | Processed sample dictionary. 73 | """ 74 | pass 75 | 76 | 77 | class DatasetProcessor(Processor): 78 | """ 79 | Processor for the full dataset. 80 | 81 | Example: Global normalization, train/val/test splitting, dataset caching. 82 | """ 83 | 84 | @abstractmethod 85 | def process(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 86 | """Process the entire dataset. 87 | 88 | Args: 89 | samples: List of sample dictionaries. 90 | 91 | Returns: 92 | List of processed sample dictionaries. 93 | """ 94 | pass 95 | -------------------------------------------------------------------------------- /pyhealth/processors/raw_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from . import register_processor 4 | from .base_processor import FeatureProcessor 5 | 6 | 7 | @register_processor("raw") 8 | class RawProcessor(FeatureProcessor): 9 | """ 10 | Processor that returns the raw value. 11 | """ 12 | 13 | def process(self, value: Any) -> Any: 14 | return value 15 | 16 | def size(self): 17 | return None 18 | 19 | def __repr__(self): 20 | return "RawProcessor()" 21 | -------------------------------------------------------------------------------- /pyhealth/processors/sequence_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import torch 4 | 5 | from . import register_processor 6 | from .base_processor import FeatureProcessor 7 | 8 | 9 | @register_processor("sequence") 10 | class SequenceProcessor(FeatureProcessor): 11 | """ 12 | Feature processor for encoding categorical sequences (e.g., medical codes) into numerical indices. 13 | 14 | Supports single or multiple tokens (e.g., single diagnosis or list of procedures). 15 | Can build vocabulary on the fly if not provided. 16 | """ 17 | 18 | def __init__(self): 19 | # -1 for <unk> for ease of boolean arithmetic > 0, > -1, etc. 20 | # TODO: this can be a problem if we pass -1 into nn.Embedding 21 | self.code_vocab: Dict[Any, int] = {"<unk>": -1, "<pad>": 0} 22 | self._next_index = 1 23 | 24 | def process(self, value: Any) -> torch.Tensor: 25 | """Process token value(s) into tensor of indices. 26 | 27 | Args: 28 | value: Raw token string or list of token strings. 29 | 30 | Returns: 31 | Tensor of indices.
32 | """ 33 | indices = [] 34 | for token in value: 35 | if token is None: # missing values 36 | indices.append(self.code_vocab[""]) 37 | else: 38 | if token not in self.code_vocab: 39 | self.code_vocab[token] = self._next_index 40 | self._next_index += 1 41 | indices.append(self.code_vocab[token]) 42 | return torch.tensor(indices, dtype=torch.long) 43 | 44 | def size(self): 45 | return len(self.code_vocab) 46 | 47 | def __repr__(self): 48 | return ( 49 | f"SequenceProcessor(code_vocab_size={len(self.code_vocab)})" 50 | ) 51 | -------------------------------------------------------------------------------- /pyhealth/processors/signal_processor.py: -------------------------------------------------------------------------------- 1 | from .base_processor import FeatureProcessor 2 | 3 | 4 | class SignalProcessor(FeatureProcessor): 5 | pass -------------------------------------------------------------------------------- /pyhealth/processors/text_processor.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from . import register_processor 4 | from .base_processor import FeatureProcessor 5 | 6 | 7 | @register_processor("text") 8 | class TextProcessor(FeatureProcessor): 9 | """ 10 | Processor for text data. 11 | """ 12 | 13 | def process(self, value: Any) -> str: 14 | return value 15 | 16 | def size(self): 17 | return None 18 | 19 | def __repr__(self): 20 | return "TextProcessor()" 21 | -------------------------------------------------------------------------------- /pyhealth/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | from .sage_sampler import NeighborSampler -------------------------------------------------------------------------------- /pyhealth/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_task import BaseTask 2 | from .benchmark_ehrshot import BenchmarkEHRShot 3 | from .cardiology_detect import ( 4 | cardiology_isAD_fn, 5 | cardiology_isAR_fn, 6 | cardiology_isBBBFB_fn, 7 | cardiology_isCD_fn, 8 | cardiology_isWA_fn, 9 | ) 10 | from .covid19_cxr_classification import COVID19CXRClassification 11 | from .drug_recommendation import ( 12 | drug_recommendation_eicu_fn, 13 | drug_recommendation_mimic3_fn, 14 | drug_recommendation_mimic4_fn, 15 | drug_recommendation_omop_fn, 16 | ) 17 | from .EEG_abnormal import EEG_isAbnormal_fn 18 | from .EEG_events import EEG_events_fn 19 | from .in_hospital_mortality_mimic4 import InHospitalMortalityMIMIC4 20 | from .length_of_stay_prediction import ( 21 | length_of_stay_prediction_eicu_fn, 22 | length_of_stay_prediction_mimic3_fn, 23 | length_of_stay_prediction_mimic4_fn, 24 | length_of_stay_prediction_omop_fn, 25 | ) 26 | from .medical_coding import MIMIC3ICD9Coding 27 | from .medical_transcriptions_classification import MedicalTranscriptionsClassification 28 | from .mortality_prediction import ( 29 | MortalityPredictionEICU, 30 | MortalityPredictionEICU2, 31 | MortalityPredictionMIMIC3, 32 | MortalityPredictionMIMIC4, 33 | MortalityPredictionOMOP, 34 | ) 35 | from .patient_linkage import patient_linkage_mimic3_fn 36 | from .readmission_30days_mimic4 import Readmission30DaysMIMIC4 37 | from .readmission_prediction import ( 38 | readmission_prediction_eicu_fn, 39 | readmission_prediction_eicu_fn2, 40 | readmission_prediction_mimic3_fn, 41 | readmission_prediction_mimic4_fn, 42 | readmission_prediction_omop_fn, 43 | ) 44 | from .sleep_staging import ( 45 | 
sleep_staging_isruc_fn, 46 | sleep_staging_shhs_fn, 47 | sleep_staging_sleepedf_fn, 48 | ) 49 | from .sleep_staging_v2 import SleepStagingSleepEDF 50 | from .temple_university_EEG_tasks import EEG_events_fn, EEG_isAbnormal_fn 51 | -------------------------------------------------------------------------------- /pyhealth/tasks/base_task.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List 3 | 4 | import polars as pl 5 | 6 | 7 | class BaseTask(ABC): 8 | task_name: str 9 | input_schema: Dict[str, str] 10 | output_schema: Dict[str, str] 11 | 12 | def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame: 13 | return df 14 | 15 | @abstractmethod 16 | def __call__(self, patient) -> List[Dict]: 17 | raise NotImplementedError -------------------------------------------------------------------------------- /pyhealth/tasks/covid19_cxr_classification.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from .base_task import BaseTask 4 | 5 | 6 | class COVID19CXRClassification(BaseTask): 7 | """A task for classifying chest disease from chest X-ray images. 8 | 9 | This task classifies chest X-ray images into different disease categories. 10 | It expects a single chest X-ray image per patient and returns the 11 | corresponding disease label. 12 | 13 | Attributes: 14 | task_name (str): The name of the task, set to 15 | "COVID19CXRClassification". 16 | input_schema (Dict[str, str]): The input schema specifying the required 17 | input format. Contains a single key "image" with value "image". 18 | output_schema (Dict[str, str]): The output schema specifying the output 19 | format. Contains a single key "disease" with value "multiclass". 20 | """ 21 | 22 | task_name: str = "COVID19CXRClassification" 23 | input_schema: Dict[str, str] = {"image": "image"} 24 | output_schema: Dict[str, str] = {"disease": "multiclass"} 25 | 26 | def __call__(self, patient: Any) -> List[Dict[str, Any]]: 27 | """Process a patient's chest X-ray data to classify COVID-19 status. 28 | 29 | Args: 30 | patient: A patient object containing chest X-ray data. 31 | 32 | Returns: 33 | List[Dict[str, Any]]: A list containing a single dictionary with: 34 | - "image": Path to the chest X-ray image 35 | - "disease": The disease classification label 36 | 37 | Raises: 38 | AssertionError: If the patient has more than one chest X-ray event. 39 | """ 40 | event = patient.get_events(event_type="covid19_cxr") 41 | # There should be only one event 42 | assert len(event) == 1 43 | event = event[0] 44 | image = event.path 45 | disease = event.label 46 | return [{"image": image, "disease": disease}] 47 | -------------------------------------------------------------------------------- /pyhealth/tasks/medical_transcriptions_classification.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | from ..data import Patient 4 | from .base_task import BaseTask 5 | 6 | 7 | class MedicalTranscriptionsClassification(BaseTask): 8 | """Task for classifying medical transcriptions into medical specialties. 9 | 10 | This task takes medical transcription text as input and predicts the 11 | corresponding medical specialty. It processes patient records containing 12 | mtsamples events and extracts transcription and medical specialty 13 | information. 
14 | 15 | Attributes: 16 | task_name (str): Name of the task 17 | input_schema (Dict[str, str]): Schema defining input features 18 | output_schema (Dict[str, str]): Schema defining output features 19 | """ 20 | task_name: str = "MedicalTranscriptionsClassification" 21 | input_schema: Dict[str, str] = {"transcription": "text"} 22 | output_schema: Dict[str, str] = {"medical_specialty": "multiclass"} 23 | 24 | def __call__(self, patient: Patient) -> List[Dict[str, Any]]: 25 | """Process a patient record to extract medical transcription samples. 26 | 27 | Args: 28 | patient (Patient): Patient record containing medical 29 | transcription events 30 | 31 | Returns: 32 | List[Dict[str, Any]]: List of samples containing transcription 33 | and medical specialty 34 | """ 35 | event = patient.get_events(event_type="mtsamples") 36 | # There should be only one event 37 | assert len(event) == 1 38 | event = event[0] 39 | 40 | transcription_valid = isinstance(event.transcription, str) 41 | specialty_valid = isinstance(event.medical_specialty, str) 42 | 43 | if transcription_valid and specialty_valid: 44 | sample = { 45 | "id": patient.patient_id, 46 | "transcription": event.transcription, 47 | "medical_specialty": event.medical_specialty, 48 | } 49 | return [sample] 50 | else: 51 | return [] 52 | -------------------------------------------------------------------------------- /pyhealth/tasks/patient_linkage.py: -------------------------------------------------------------------------------- 1 | from pyhealth.data import Patient 2 | 3 | 4 | def patient_linkage_mimic3_fn(patient: Patient): 5 | """ Patient linkage task for the mimic3 dataset. """ 6 | 7 | # exclude patients with less than two visits 8 | if len(patient) < 2: 9 | return [] 10 | 11 | samples = [] 12 | q_visit = patient.get_visit_by_index(len(patient) - 1) 13 | d_visit = patient.get_visit_by_index(len(patient) - 2) 14 | 15 | q_age = (q_visit.encounter_time - patient.birth_datetime).days // 365.25 16 | d_age = (d_visit.encounter_time - patient.birth_datetime).days // 365.25 17 | 18 | # exclude patients under 18 19 | if (q_age < 18) or (d_age < 18): 20 | return [] 21 | 22 | q_conditions = q_visit.get_code_list(table="DIAGNOSES_ICD") 23 | d_conditions = d_visit.get_code_list(table="DIAGNOSES_ICD") 24 | 25 | # exclude patients without conditions 26 | if len(q_conditions) * len(d_conditions) == 0: 27 | return [] 28 | 29 | # identifiers 30 | gender = patient.gender 31 | insurance = q_visit.attr_dict["insurance"] 32 | language = q_visit.attr_dict["language"] 33 | religion = q_visit.attr_dict["religion"] 34 | marital_status = q_visit.attr_dict["marital_status"] 35 | ethnicity = q_visit.attr_dict["ethnicity"] 36 | insurance = "<unk>" if insurance != insurance else insurance  # x != x is a NaN check 37 | language = "<unk>" if language != language else language 38 | religion = "<unk>" if religion != religion else religion 39 | marital_status = "<unk>" if marital_status != marital_status else marital_status 40 | ethnicity = "<unk>" if ethnicity != ethnicity else ethnicity 41 | q_identifiers = "+".join( 42 | [gender, insurance, language, religion, marital_status, ethnicity] 43 | ) 44 | 45 | insurance = d_visit.attr_dict["insurance"] 46 | language = d_visit.attr_dict["language"] 47 | religion = d_visit.attr_dict["religion"] 48 | marital_status = d_visit.attr_dict["marital_status"] 49 | ethnicity = d_visit.attr_dict["ethnicity"] 50 | insurance = "<unk>" if insurance != insurance else insurance 51 | language = "<unk>" if language != language else language 52 | religion = "<unk>" if religion != religion else religion 53 |
marital_status = "" if marital_status != marital_status else marital_status 54 | ethnicity = "" if ethnicity != ethnicity else ethnicity 55 | d_identifiers = "+".join( 56 | [gender, insurance, language, religion, marital_status, ethnicity] 57 | ) 58 | 59 | samples.append({ 60 | "patient_id": patient.patient_id, 61 | "visit_id": q_visit.visit_id, 62 | "conditions": [""] + q_conditions, 63 | "age": q_age, 64 | "identifiers": q_identifiers, 65 | "d_visit_id": d_visit.visit_id, 66 | "d_conditions": [""] + d_conditions, 67 | "d_age": d_age, 68 | "d_identifiers": d_identifiers, 69 | }) 70 | 71 | return samples 72 | -------------------------------------------------------------------------------- /pyhealth/unittests/test_data/test_data.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import unittest 3 | 4 | from pyhealth.data import Event, Patient 5 | 6 | 7 | class TestEvent(unittest.TestCase): 8 | def setUp(self): 9 | self.event = Event( 10 | type="DIAGNOSES_ICD", 11 | timestamp=datetime(2012, 1, 1, 0, 0), 12 | attr_dict={"code": "428.0", "vocabulary": "ICD9CM"} 13 | ) 14 | 15 | def test_type(self): 16 | self.assertIsInstance(self.event.type, str) 17 | self.assertIsInstance(self.event.timestamp, datetime) 18 | 19 | def test_attr(self): 20 | self.assertEqual(self.event.type, "DIAGNOSES_ICD") 21 | self.assertEqual(self.event.timestamp, datetime(2012, 1, 1, 0, 0)) 22 | attr_dict = self.event.attr_dict 23 | self.assertEqual(attr_dict["code"], "428.0") 24 | self.assertEqual(attr_dict["vocabulary"], "ICD9CM") 25 | 26 | def test_repr_and_str(self): 27 | print(repr(self.event)) 28 | print(str(self.event)) 29 | 30 | 31 | class TestPatient(unittest.TestCase): 32 | def setUp(self): 33 | self.event1 = Event(type="diagnosis", timestamp=datetime(2023, 5, 17)) 34 | self.event2 = Event(type="prescription", timestamp=datetime(2023, 5, 18)) 35 | self.event3 = Event(type="lab_test", timestamp=None) 36 | self.patient = Patient(patient_id="12345", 37 | attr_dict={"name": "John Doe", "age": 45}) 38 | 39 | def test_attr(self): 40 | self.assertEqual(self.patient.patient_id, "12345") 41 | self.assertEqual(self.patient.attr_dict["name"], "John Doe") 42 | self.assertEqual(self.patient.attr_dict["age"], 45) 43 | 44 | def test_methods(self): 45 | self.patient.add_event(self.event1) 46 | self.patient.add_event(self.event2) 47 | self.patient.add_event(self.event3) 48 | self.assertEqual(len(self.patient.events), 3) 49 | self.assertEqual(self.patient.events[0], self.event1) 50 | self.assertEqual(self.patient.events[1], self.event2) 51 | self.assertEqual(self.patient.events[2], self.event3) 52 | 53 | def test_repr_and_str(self): 54 | self.patient.add_event(self.event1) 55 | self.patient.add_event(self.event2) 56 | self.patient.add_event(self.event3) 57 | print(repr(self.patient)) 58 | print(str(self.patient)) 59 | 60 | 61 | if __name__ == "__main__": 62 | unittest.main() 63 | -------------------------------------------------------------------------------- /pyhealth/unittests/test_datasets/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pyhealth.datasets import BaseEHRDataset 3 | 4 | class EHRDatasetStatAssertion: 5 | 6 | def __init__(self, dataset: BaseEHRDataset, eps: float): 7 | self.dataset = dataset 8 | self.eps = eps 9 | # return self # builder 10 | 11 | def assertEHRStats( 12 | self, 13 | expected_num_patients: int, 14 | expected_num_visits: int, 15 | 
expected_num_visits_per_patient: float, 16 | expected_events_per_visit_per_table: List[float], 17 | ): 18 | self.assertNumPatients(expected_num_patients) 19 | self.assertNumVisits(expected_num_visits) 20 | self.assertMeanVisitsPerPatient(expected_num_visits_per_patient) 21 | self.assertTableMeans(expected_events_per_visit_per_table) 22 | 23 | def assertNumPatients(self, expected: int): 24 | actual = len(self.dataset.patients) 25 | if expected != actual: 26 | raise AssertionError(f"Expected {expected} patients, got {actual}") 27 | 28 | def assertNumVisits(self, expected: int): 29 | actual = sum([len(patient) for patient in self.dataset.patients.values()]) 30 | if expected != actual: 31 | raise AssertionError(f"Expected {expected} visits, got {actual}") 32 | 33 | def assertMeanVisitsPerPatient(self, expected: float): 34 | actual_visits = [len(patient) for patient in self.dataset.patients.values()] 35 | actual = sum(actual_visits) / len(actual_visits) 36 | if abs(expected - actual) > self.eps: 37 | raise AssertionError(f"Expected {expected} mean visits per patient, got {actual}") 38 | 39 | 40 | # expected list must be ordered by tables 41 | def assertTableMeans(self, expected_per_table: List[float]): 42 | for expected_value, table in zip(expected_per_table, self.dataset.tables): 43 | actual_num_events = [ 44 | len(v.get_event_list(table)) 45 | for p in self.dataset.patients.values() 46 | for v in p 47 | ] 48 | 49 | actual_value = sum(actual_num_events) / len(actual_num_events) 50 | 51 | if abs(expected_value - actual_value) > self.eps: 52 | raise AssertionError(f"Expected {expected_value} mean for events in {table}, got {actual_value}") -------------------------------------------------------------------------------- /pyhealth/unittests/test_medcode.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import sys 3 | import os 4 | current = os.path.dirname(os.path.realpath(__file__)) 5 | sys.path.append(os.path.dirname(os.path.dirname(current))) 6 | 7 | from pyhealth.medcode import InnerMap, CrossMap 8 | 9 | class TestInnerMap(unittest.TestCase): 10 | 11 | def setUp(self): 12 | map_name = "ICD9CM" 13 | self.inner_map = InnerMap.load(map_name) 14 | 15 | def test_contain(self): 16 | self.assertTrue( 17 | "428.0" in self.inner_map, 18 | msg="contain function of InnerMap failed" 19 | ) 20 | 21 | def test_lookup(self): 22 | self.assertEqual( 23 | self.inner_map.lookup("428.0"), 24 | 'Congestive heart failure, unspecified', 25 | msg="lookup function of InnerMap failed" 26 | ) 27 | 28 | def test_get_ancestors(self): 29 | self.assertEqual( 30 | self.inner_map.get_ancestors("428.0"), 31 | ['428', '420-429.99', '390-459.99', '001-999.99'], 32 | msg="get_ancestors function of InnerMap failed" 33 | ) 34 | 35 | def test_get_descendants(self): 36 | self.assertEqual( 37 | self.inner_map.get_descendants("428"), 38 | ['428.0', '428.1', '428.2', '428.3', '428.4', '428.9', '428.20', '428.21', '428.22', '428.23', '428.30', '428.31', '428.32', '428.33', '428.40', '428.41', '428.42', '428.43'], 39 | msg="get_descendants function of InnerMap failed" 40 | ) 41 | 42 | 43 | class TestInnerMapATC(unittest.TestCase): 44 | def setUp(self): 45 | self.inner_map = InnerMap.load("ATC") 46 | 47 | def test_lookup(self): 48 | self.assertEqual( 49 | self.inner_map.lookup("M01AE51"), 50 | 'ibuprofen, combinations', 51 | msg="lookup function of InnerMap (ATC) failed" 52 | ) 53 | self.assertEqual( 54 | self.inner_map.lookup("M01AE51", "drugbank_id"), 55 | 'DB01050', 56 | msg="lookup
function of InnerMap (ATC) failed" 57 | ) 58 | self.assertEqual( 59 | self.inner_map.lookup("M01AE51", "smiles"), 60 | 'CC(C)CC1=CC=C(C=C1)C(C)C(O)=O', 61 | msg="lookup function of InnerMap (ATC) failed" 62 | ) 63 | 64 | 65 | def test_convert(self): 66 | self.assertEqual( 67 | self.inner_map.convert("A12CE02", level=3), 68 | "A12C", 69 | msg="convert function of InnerMap (ATC) failed" 70 | ) 71 | 72 | 73 | class TestCrossMap(unittest.TestCase): 74 | def setUp(self): 75 | self.cross_map = CrossMap.load(source_vocabulary="ICD9CM", target_vocabulary="CCSCM") 76 | 77 | def test_map(self): 78 | self.assertEqual( 79 | self.cross_map.map("428.0"), 80 | ["108"], 81 | msg="map function of CrossMap failed" 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | unittest.main() 87 | -------------------------------------------------------------------------------- /pyhealth/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def set_seed(seed): 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | if torch.cuda.is_available(): 15 | torch.cuda.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | os.environ["PYTHONHASHSEED"] = str(seed) 20 | 21 | 22 | def create_directory(directory): 23 | if not os.path.exists(directory): 24 | os.makedirs(directory) 25 | 26 | 27 | def load_pickle(filename): 28 | with open(filename, "rb") as f: 29 | return pickle.load(f) 30 | 31 | 32 | def save_pickle(data, filename): 33 | with open(filename, "wb") as f: 34 | pickle.dump(data, f) 35 | 36 | 37 | def load_json(filename): 38 | with open(filename, "r") as f: 39 | return json.load(f) 40 | 41 | 42 | def save_json(data, filename): 43 | with open(filename, "w") as f: 44 | json.dump(data, f) 45 | 46 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.8" 7 | 8 | sphinx: 9 | configuration: docs/conf.py 10 | 11 | formats: 12 | - epub 13 | - htmlzip 14 | 15 | python: 16 | install: 17 | - method: pip 18 | path: . 
19 | - requirements: docs/requirements.txt 20 | 21 | submodules: 22 | include: all -------------------------------------------------------------------------------- /requirements-nlp.txt: -------------------------------------------------------------------------------- 1 | numpy~=1.26.4 2 | pandas~=2.1.4 3 | scikit-learn~=1.3.2 4 | torch~=2.2.2 5 | transformers~=4.48.3 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.8.0 2 | torchvision>=0.9.0 3 | rdkit>=2022.03.4 4 | scikit-learn>=0.24.2 5 | networkx>=2.6.3 6 | pandas>=1.3.2,<2 7 | pandarallel>=1.5.3 8 | mne>=1.0.3 9 | urllib3<=1.26.15 10 | numpy 11 | tqdm 12 | polars 13 | transformers -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.rst 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # read the contents of README file 4 | from os import path 5 | from io import open # for Python 2 and 3 compatibility 6 | 7 | # # get __version__ from _version.py 8 | # ver_file = path.join('pyhealth', 'version.py') 9 | # with open(ver_file) as f: 10 | # exec(f.read()) 11 | 12 | this_directory = path.abspath(path.dirname(__file__)) 13 | 14 | 15 | # read the contents of README.rst 16 | def readme(): 17 | with open(path.join(this_directory, "README.rst"), encoding="utf-8") as f: 18 | return f.read() 19 | 20 | 21 | # read the contents of requirements.txt 22 | with open(path.join(this_directory, "requirements.txt"), encoding="utf-8") as f: 23 | requirements = f.read().splitlines() 24 | 25 | VERSION = "1.1.4" 26 | 27 | setup( 28 | name="pyhealth", 29 | version=VERSION, 30 | description="A Python library for healthcare AI", 31 | long_description=readme(), 32 | long_description_content_type="text/x-rst", 33 | author="Chaoqi Yang, Zhenbang Wu, Patrick Jiang, Zhen Lin, Benjamin Danek, Junyi Gao, Jimeng Sun", 34 | author_email="chaoqiy2@illinois.edu", 35 | url="https://github.com/sunlabuiuc/pyhealth", 36 | keywords=[ 37 | "healthcare AI", 38 | "healthcare", 39 | "electronic health records", 40 | "EHRs", 41 | "machine learning", 42 | "data mining", 43 | "neural networks", 44 | "deep learning", 45 | ], 46 | packages=find_packages(exclude=["test"]), 47 | include_package_data=True, 48 | install_requires=requirements, 49 | setup_requires=["setuptools>=38.6.0"], 50 | classifiers=[ 51 | "Development Status :: 3 - Alpha", 52 | "Intended Audience :: Education", 53 | "Intended Audience :: Financial and Insurance Industry", 54 | "Intended Audience :: Science/Research", 55 | "Intended Audience :: Developers", 56 | "Intended Audience :: Information Technology", 57 | "License :: OSI Approved :: BSD License", 58 | "Programming Language :: Python :: 3.8", 59 | ], 60 | ) 61 | -------------------------------------------------------------------------------- /test-resources/nlp/metrics.csv: -------------------------------------------------------------------------------- 1 |
--------------------------------------------------------------------------------
/test-resources/nlp/metrics.csv:
--------------------------------------------------------------------------------
bleu,levenshtein,rouge1_f_score,rouge1_precision,rouge1_recall,rouge2_f_score,rouge2_precision,rouge2_recall,rouge3_f_score,rouge3_precision,rouge3_recall,rouge4_f_score,rouge4_precision,rouge4_recall,rouge5_f_score,rouge5_precision,rouge5_recall,rouge6_f_score,rouge6_precision,rouge6_recall,rouge7_f_score,rouge7_precision,rouge7_recall,rouge8_f_score,rouge8_precision,rouge8_recall,rouge9_f_score,rouge9_precision,rouge9_recall,rougeL_f_score,rougeL_precision,rougeL_recall
1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
0.658,0.9592,0.9474,0.9,1.0,0.8235,0.7778,0.875,0.6667,0.625,0.7143,0.4615,0.4286,0.5,0.3636,0.3333,0.4,0.2222,0.2,0.25,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.9474,0.9,1.0
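The 0.9592 in the levenshtein column is consistent with an edit-distance similarity normalized by the longer string, computed on the sentence pair defined in tests/nlp/test_metrics.py below (the identical pair in the first row gives 1.0). A short sketch of that arithmetic; the 1 - distance / max(len) normalization is an inference from the data, not taken from pyhealth.nlp.metrics itself:

# Reconstructing the 0.9592 value; the normalization is an assumption.
def levenshtein(a: str, b: str) -> int:
    # classic dynamic-programming edit distance
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        curr = [i]
        for j, cb in enumerate(b, 1):
            curr.append(min(prev[j] + 1,                 # deletion
                            curr[j - 1] + 1,             # insertion
                            prev[j - 1] + (ca != cb)))   # substitution
        prev = curr
    return prev[-1]

s1 = 'The boy threw the ball. He practiced every day.'
s2 = 'The boy threw X the ball. He practiced every day.'
d = levenshtein(s1, s2)                          # 2: insert 'X' and a space
print(round(1 - d / max(len(s1), len(s2)), 4))   # 0.9592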
--------------------------------------------------------------------------------
/tests/base.py:
--------------------------------------------------------------------------------
from typing import Type
import unittest
import logging


class BaseTestCase(unittest.TestCase):
    @staticmethod
    def _setup_logging(level: int = logging.INFO):
        logging.basicConfig(level=level)

    @classmethod
    def _set_debug(cls: Type):
        cls._setup_logging(logging.DEBUG)
        print()
        print('_' * 80)
--------------------------------------------------------------------------------
/tests/nlp/__init__.py:
https://raw.githubusercontent.com/sunlabuiuc/PyHealth/7a0a86c9d62b1347b95a00b7398c7e7146b2e37b/tests/nlp/__init__.py
--------------------------------------------------------------------------------
/tests/nlp/test_metrics.py:
--------------------------------------------------------------------------------
from typing import List
import logging
from base import BaseTestCase
from pathlib import Path
import pandas as pd
from pyhealth.nlp.metrics import (
    LevenshteinDistanceScoreMethod,
    ScoreContext, ScoreSet, ScoreResult, Scorer
)


class TestMetrics(BaseTestCase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # silence info messages from the rouge package
        self._setup_logging(logging.WARNING)

    def setUp(self):
        s1: str = 'The boy threw the ball. He practiced every day.'
        s2: str = 'The boy threw X the ball. He practiced every day.'
        self.pairs: List[List[str]] = [[s1, s1], [s1, s2]]

    def test_object_graph(self):
        # configure only the edit distance method
        scorer = Scorer(
            methods={'editdistance': LevenshteinDistanceScoreMethod()},
        )
        ss: ScoreSet = scorer.score(ScoreContext(self.pairs))
        self.assertEqual(ScoreSet, type(ss))
        self.assertEqual(2, len(ss.results))
        res1: ScoreResult = ss.results[0]
        self.assertEqual(ScoreResult, type(res1))
        self.assertEqual(1, len(res1.scores))
        self.assertIn('editdistance', res1.scores)
        self.assertEqual(1., res1.scores['editdistance'].value)

    def test_pandas(self):
        WRITE: bool = False
        should_file: Path = Path('test-resources/nlp/metrics.csv')
        self._set_debug()
        scorer = Scorer()
        ss: ScoreSet = scorer.score(ScoreContext(self.pairs))
        df: pd.DataFrame = ss.as_dataframe()
        # round to tolerate architectures whose high significant digits
        # may be off by epsilon
        df = df.round(4)
        if WRITE:
            should_file.parent.mkdir(parents=True, exist_ok=True)
            df.to_csv(should_file, index=False)
        should: pd.DataFrame = pd.read_csv(should_file)
        self.assertEqual(should.to_string(), df.to_string())
--------------------------------------------------------------------------------
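A hypothetical launcher for the tests above, since test_metrics.py imports base from the tests/ directory while reading test-resources/ relative to the repository root; the paths and discovery arguments are assumptions about how the suite is meant to be run, not taken from the repository:

# Run from the repository root so 'test-resources/nlp/metrics.csv' resolves.
import sys
import unittest

sys.path.insert(0, "tests")  # makes 'from base import BaseTestCase' importable
suite = unittest.defaultTestLoader.discover("tests/nlp", top_level_dir="tests")
unittest.TextTestRunner(verbosity=2).run(suite)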