├── .codecov.yml ├── .github └── workflows │ ├── ci.yml │ ├── codecov.yml │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── docs ├── EventTimesSampler.ipynb ├── ModelsComparison.ipynb ├── PaperCodeFigure.ipynb ├── PerformanceMeasures.ipynb ├── Regularization.ipynb ├── Simple Simulation.ipynb ├── SimulatedDataset.ipynb ├── UsageExample-DataPreparation.ipynb ├── UsageExample-FittingDataExpansionFitter-FULL.ipynb ├── UsageExample-FittingDataExpansionFitter.ipynb ├── UsageExample-FittingTwoStagesFitter-FULL.ipynb ├── UsageExample-FittingTwoStagesFitter.ipynb ├── UsageExample-FittingTwoStagesFitterExact-FULL.ipynb ├── UsageExample-Intro.ipynb ├── UsageExample-RegroupingData.ipynb ├── UsageExample-SIS-SIS-L.ipynb ├── User Story.ipynb ├── api │ ├── cross_validation.md │ ├── data_expansion_fitter.md │ ├── evaluation.md │ ├── event_times_sampler.md │ ├── model_selection.md │ ├── screening.md │ ├── two_stages_fitter.md │ ├── two_stages_fitter_exact.md │ └── utils.md ├── dtsicon.svg ├── icon.png ├── index.md ├── intro.md ├── jss_replication.py ├── methods.md ├── methodsevaluation.md ├── methodsintro.md ├── models_params.png └── requirements.txt ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml ├── src └── pydts │ ├── __init__.py │ ├── base_fitters.py │ ├── config.py │ ├── cross_validation.py │ ├── data_generation.py │ ├── datasets │ └── LOS_simulated_data.csv │ ├── evaluation.py │ ├── examples_utils │ ├── __init__.py │ ├── datasets.py │ ├── generate_simulations_data.py │ ├── mimic_consts.py │ ├── plots.py │ └── simulations_data_config.py │ ├── fitters.py │ ├── model_selection.py │ ├── screening.py │ └── utils.py └── tests ├── __init__.py ├── test_DataExpansionFitter.py ├── test_EventTimesSampler.py ├── test_TwoStagesFitter.py ├── test_TwoStagesFitterExact.py ├── test_basefitter.py ├── test_cross_validation.py ├── test_evaluation.py ├── test_model_selection.py ├── test_repetative_fitter.py └── test_screening.py /.codecov.yml: -------------------------------------------------------------------------------- 1 | # 2 | # This codecov.yml is the default configuration for 3 | # all repositories on Codecov. You may adjust the settings 4 | # below in your own codecov.yml in your repository. 
5 | # 6 | coverage: 7 | precision: 2 8 | round: down 9 | range: 70...100 10 | 11 | status: 12 | # Learn more at https://docs.codecov.io/docs/commit-status 13 | project: 14 | default: 15 | threshold: 1% # allow this much decrease on project 16 | app: 17 | target: 70% 18 | flags: 19 | - app 20 | modules: 21 | target: 70% 22 | flags: 23 | - modules 24 | client: 25 | flags: 26 | - client 27 | changes: false 28 | 29 | comment: 30 | layout: "reach, diff, files" 31 | behavior: default # update if exists else create new 32 | require_changes: true 33 | 34 | flags: 35 | app: 36 | paths: 37 | - "app/" 38 | - "baseapp/" 39 | modules: 40 | paths: 41 | - "x/" 42 | - "!x/**/client/" # ignore client package 43 | client: 44 | paths: 45 | - "client/" 46 | - "x/**/client/" 47 | 48 | ignore: 49 | - "docs" 50 | - "*.md" 51 | - "*.rst" 52 | - "**/*.pb.go" 53 | - "types/*.pb.go" 54 | - "tests/*" 55 | - "tests/**/*" 56 | - "x/**/*.pb.go" 57 | - "x/**/test_common.go" 58 | - "scripts/" 59 | - "contrib" 60 | - "src/pydts/examples_utils" 61 | - "src/pydts/utils.py" 62 | - "src/pydts/config.py" -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - uses: actions/setup-node@v4 13 | with: 14 | node-version: 16 15 | - uses: actions/setup-python@v4 16 | with: 17 | python-version: '3.10' 18 | - run: sudo apt-get update && sudo apt-get install -y gettext 19 | - run: pip install poetry 20 | - run: poetry install 21 | - run: poetry run mkdocs gh-deploy --force 22 | 23 | 24 | -------------------------------------------------------------------------------- /.github/workflows/codecov.yml: -------------------------------------------------------------------------------- 1 | name: Code cov 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | name: Code cov 9 | steps: 10 | - uses: actions/checkout@v4 11 | - uses: actions/setup-python@v2 12 | with: 13 | python-version: '3.10' 14 | - name: Install requirements 15 | run: | 16 | sudo apt-get update && sudo apt-get install -y gettext 17 | pip install poetry 18 | poetry install 19 | - name: Run tests and collect coverage 20 | run: poetry run pytest tests/ --cov=./ --cov-report=xml 21 | - name: Upload coverage reports to Codecov with GitHub Action 22 | uses: codecov/codecov-action@v5 23 | env: 24 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 25 | 26 | #name: Codecov 27 | #on: [push, pull_request] 28 | #jobs: 29 | # run: 30 | # runs-on: ${{ matrix.os }} 31 | # strategy: 32 | # matrix: 33 | # os: [ubuntu-latest] 34 | # env: 35 | # OS: ${{ matrix.os }} 36 | # PYTHON: '3.9' 37 | # steps: 38 | # - uses: actions/checkout@master 39 | # - name: Setup Python 40 | # uses: actions/setup-python@master 41 | # with: 42 | # python-version: 3.9 43 | # - name: Generate coverage report 44 | # run: | 45 | # pip install poetry==1.5.1 46 | # poetry install 47 | # poetry run pytest --cov=./ --cov-report=xml 48 | # - name: Upload coverage to Codecov 49 | # uses: codecov/codecov-action@v2 50 | # with: 51 | # version: "v0.1.15" 52 | # token: ${{ secrets.CODECOV_TOKEN }} 53 | # directory: ./coverage/reports/ 54 | # env_vars: OS,PYTHON 55 | # fail_ci_if_error: true 56 | # # files: ./coverage1.xml,./coverage2.xml 57 | # files: ./coverage.xml 58 | # flags: unittests 59 | # name: 
codecov-umbrella 60 | # path_to_write_report: ./coverage/codecov_report.txt 61 | # verbose: true 62 | 63 | 64 | #name: Codecov 65 | #on: [push, pull_request] 66 | #jobs: 67 | # test-coverage: 68 | # runs-on: ubuntu-latest 69 | # timeout-minutes: 120 70 | # strategy: 71 | # matrix: 72 | # python-version: ['3.10'] 73 | # steps: 74 | # - uses: actions/checkout@v4 75 | # - name: Set up Python ${{ matrix.python-version }} 76 | # uses: actions/setup-python@v4 77 | # with: 78 | # python-version: ${{ matrix.python-version }} 79 | # - name: Install dependencies 80 | # run: | 81 | # sudo apt-get update && sudo apt-get install -y gettext 82 | # pip install poetry 83 | # poetry install 84 | # - name: Run tests with coverage 85 | # run: | 86 | # # Start a background loop to print a message every 60 seconds 87 | # while true; do echo "Running tests... still working at $(date)"; sleep 60; done & 88 | # keep_alive=$! 89 | # 90 | # # Run tests 91 | # poetry run pytest tests/ -v --cov=src/pydts --cov-report=xml 92 | # test_status=$? 93 | # 94 | # # Stop keep-alive loop 95 | # kill $keep_alive 96 | # 97 | # # Return pytest exit code 98 | # exit $test_status 99 | # - name: Upload coverage to Codecov 100 | # uses: codecov/codecov-action@v2 101 | # with: 102 | # version: "v0.1.15" 103 | # token: ${{ secrets.CODECOV_TOKEN }} 104 | # directory: ./coverage/reports/ 105 | # env_vars: OS,PYTHON 106 | # fail_ci_if_error: true 107 | # # files: ./coverage1.xml,./coverage2.xml 108 | # files: ./coverage.xml 109 | # flags: unittests 110 | # name: codecov-umbrella 111 | # path_to_write_report: ./coverage/codecov_report.txt 112 | # verbose: true 113 | ## - name: Upload coverage reports to Codecov 114 | ## uses: codecov/codecov-action@v2 115 | ## env: 116 | ## CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 117 | ## CODECOV_VERSION: 'v0.1.15' 118 | ## with: 119 | ## files: coverage.xml 120 | ## token: ${{ secrets.CODECOV_TOKEN }} 121 | ## version: ${{ env.CODECOV_VERSION }} 122 | ## fail_ci_if_error: true -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | on: [push, pull_request] 3 | jobs: 4 | test-DataExpansionFitter: 5 | runs-on: ubuntu-latest 6 | timeout-minutes: 60 7 | strategy: 8 | matrix: 9 | python-version: ['3.10'] 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python ${{ matrix.python-version }} 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: ${{ matrix.python-version }} 16 | - name: Setup 17 | run: | 18 | sudo apt-get update && sudo apt-get install -y gettext 19 | pip install poetry 20 | poetry install 21 | - name: Run tests 22 | run: poetry run pytest tests/test_DataExpansionFitter.py -v 23 | 24 | test-EventTimesSampler: 25 | runs-on: ubuntu-latest 26 | timeout-minutes: 60 27 | strategy: 28 | matrix: 29 | python-version: ['3.10'] 30 | steps: 31 | - uses: actions/checkout@v4 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v4 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | - name: Setup 37 | run: | 38 | sudo apt-get update && sudo apt-get install -y gettext 39 | pip install poetry 40 | poetry install 41 | - name: Run tests 42 | run: poetry run pytest tests/test_EventTimesSampler.py -v 43 | 44 | test-TwoStagesFitter: 45 | runs-on: ubuntu-latest 46 | timeout-minutes: 60 47 | strategy: 48 | matrix: 49 | python-version: ['3.10'] 50 | steps: 51 | - uses: actions/checkout@v4 
52 | - name: Set up Python ${{ matrix.python-version }} 53 | uses: actions/setup-python@v4 54 | with: 55 | python-version: ${{ matrix.python-version }} 56 | - name: Setup 57 | run: | 58 | sudo apt-get update && sudo apt-get install -y gettext 59 | pip install poetry 60 | poetry install 61 | - name: Run tests 62 | run: poetry run pytest tests/test_TwoStagesFitter.py -v 63 | 64 | test-TwoStagesFitterExact: 65 | runs-on: ubuntu-latest 66 | timeout-minutes: 120 67 | strategy: 68 | matrix: 69 | python-version: ['3.10'] 70 | steps: 71 | - uses: actions/checkout@v4 72 | - name: Set up Python ${{ matrix.python-version }} 73 | uses: actions/setup-python@v4 74 | with: 75 | python-version: ${{ matrix.python-version }} 76 | - name: Setup 77 | run: | 78 | sudo apt-get update && sudo apt-get install -y gettext 79 | pip install poetry 80 | poetry install 81 | - name: Run tests 82 | run: poetry run pytest tests/test_TwoStagesFitterExact.py -v 83 | 84 | test-screening: 85 | runs-on: ubuntu-latest 86 | timeout-minutes: 60 87 | strategy: 88 | matrix: 89 | python-version: ['3.10'] 90 | steps: 91 | - uses: actions/checkout@v4 92 | - name: Set up Python ${{ matrix.python-version }} 93 | uses: actions/setup-python@v4 94 | with: 95 | python-version: ${{ matrix.python-version }} 96 | - name: Setup 97 | run: | 98 | sudo apt-get update && sudo apt-get install -y gettext 99 | pip install poetry 100 | poetry install 101 | - name: Run tests 102 | run: poetry run pytest tests/test_screening.py -v 103 | 104 | test-model-selection: 105 | runs-on: ubuntu-latest 106 | timeout-minutes: 120 107 | strategy: 108 | matrix: 109 | python-version: ['3.10'] 110 | steps: 111 | - uses: actions/checkout@v4 112 | - name: Set up Python ${{ matrix.python-version }} 113 | uses: actions/setup-python@v4 114 | with: 115 | python-version: ${{ matrix.python-version }} 116 | - name: Setup 117 | run: | 118 | sudo apt-get update && sudo apt-get install -y gettext 119 | pip install poetry 120 | poetry install 121 | - name: Run tests 122 | run: poetry run pytest tests/test_model_selection.py -v 123 | 124 | test-remaining: 125 | runs-on: ubuntu-latest 126 | timeout-minutes: 120 127 | strategy: 128 | matrix: 129 | python-version: ['3.10'] 130 | steps: 131 | - uses: actions/checkout@v4 132 | - name: Set up Python ${{ matrix.python-version }} 133 | uses: actions/setup-python@v4 134 | with: 135 | python-version: ${{ matrix.python-version }} 136 | - name: Setup 137 | run: | 138 | sudo apt-get update && sudo apt-get install -y gettext 139 | pip install poetry 140 | poetry install 141 | - name: Run remaining tests 142 | run: | 143 | poetry run pytest tests/ \ 144 | --ignore=tests/test_DataExpansionFitter.py \ 145 | --ignore=tests/test_EventTimesSampler.py \ 146 | --ignore=tests/test_TwoStagesFitter.py \ 147 | --ignore=tests/test_TwoStagesFitterExact.py \ 148 | --ignore=tests/test_screening.py \ 149 | --ignore=tests/test_model_selection.py \ 150 | -v 151 | 152 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 
29 |
30 | # PyInstaller
31 | #  Usually these files are written by a python script from a template
32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | #   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | #   However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | #   having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | #   install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | output*
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | PyDTS - A Python package for discrete-time survival analysis with competing risks
2 | Copyright (C) 2022 Tomer Meir, Rom Gutman, and Malka Gorfine
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | [![pypi version](https://img.shields.io/pypi/v/pydts)](https://pypi.org/project/pydts/)
2 | [![Tests](https://github.com/tomer1812/pydts/workflows/Tests/badge.svg)](https://github.com/tomer1812/pydts/actions?workflow=Tests)
3 | [![documentation](https://img.shields.io/badge/docs-mkdocs%20material-blue.svg?style=flat)](https://tomer1812.github.io/pydts)
4 | [![codecov](https://codecov.io/gh/tomer1812/pydts/branch/main/graph/badge.svg)](https://codecov.io/gh/tomer1812/pydts)
5 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15296343.svg)](https://doi.org/10.5281/zenodo.15296343)
6 |
7 | # Discrete Time Survival Analysis
8 | A Python package for discrete-time survival data analysis with competing risks.
9 |
10 | ![PyDTS](docs/icon.png)
11 |
12 | [Tomer Meir](https://tomer1812.github.io/), [Rom Gutman](https://github.com/RomGutman), [Malka Gorfine](https://www.tau.ac.il/~gorfinem/) 2022
13 |
14 | [Documentation](https://tomer1812.github.io/pydts/)
15 |
16 | ## Installation
17 | ```console
18 | pip install pydts
19 | ```
20 |
21 | ## Quick Start
22 |
23 | ```python
24 | from pydts.fitters import TwoStagesFitter
25 | from pydts.examples_utils.generate_simulations_data import generate_quick_start_df
26 |
27 | patients_df = generate_quick_start_df(n_patients=10000, n_cov=5, d_times=14, j_events=2, pid_col='pid', seed=0)
28 |
29 | fitter = TwoStagesFitter()
30 | fitter.fit(df=patients_df.drop(['C', 'T'], axis=1))
31 | fitter.print_summary()
32 | ```
33 |
34 | ## Examples
35 | 1. [Usage Example](https://tomer1812.github.io/pydts/UsageExample-Intro/)
36 | 2. [Hospital Length of Stay Simulation Example](https://tomer1812.github.io/pydts/SimulatedDataset/)
37 |
38 | ## Citations
39 | If you find the PyDTS software useful for your research, please cite the following papers:
40 |
41 | ```bibtex
42 | @article{Meir_PyDTS_2022,
43 | author = {Meir, Tomer and Gutman, Rom and Gorfine, Malka},
44 | doi = {10.48550/arXiv.2204.05731},
45 | title = {{PyDTS: A Python Package for Discrete Time Survival Analysis with Competing Risks}},
46 | url = {https://arxiv.org/abs/2204.05731},
47 | year = {2022}
48 | }
49 |
50 | @article{Meir_Gorfine_DTSP_2025,
51 | author = {Meir, Tomer and Gorfine, Malka},
52 | doi = {10.1093/biomtc/ujaf040},
53 | title = {{Discrete-Time Competing-Risks Regression with or without Penalization}},
54 | year = {2025},
55 | journal = {Biometrics},
56 | volume = {81},
57 | number = {2},
58 | url = {https://academic.oup.com/biometrics/article/81/2/ujaf040/8120014},
59 | }
60 | ```
61 |
62 | and please consider starring the project [on GitHub](https://github.com/tomer1812/pydts)
63 |
64 | ## How to Contribute
65 | 1. Open GitHub issues to suggest new features or to report bugs/errors
66 | 2. Contact Tomer or Rom if you want to add a usage example to the documentation
67 | 3. If you want to become a developer (thank you, we appreciate it!)
- please contact Tomer or Rom for developers' on-boarding
68 |
69 | Tomer Meir: tomer1812@gmail.com, Rom Gutman: rom.gutman1@gmail.com
-------------------------------------------------------------------------------- /docs/PaperCodeFigure.ipynb: --------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "id": "ee1b61c5",
7 |    "metadata": {
8 |     "ExecuteTime": {
9 |      "end_time": "2022-05-18T19:07:03.674581Z",
10 |      "start_time": "2022-05-18T19:03:12.578523Z"
11 |     },
12 |     "scrolled": false
13 |    },
14 |    "outputs": [],
15 |    "source": [
16 |     "import pandas as pd\n",
17 |     "import numpy as np\n",
18 |     "from sklearn.model_selection import train_test_split\n",
19 |     "from pydts.examples_utils.generate_simulations_data import generate_quick_start_df\n",
20 |     "\n",
21 |     "# Data Generation\n",
22 |     "real_coef_dict = {\n",
23 |     "    \"alpha\": {\n",
24 |     "        1: lambda t: -1 - 0.3 * np.log(t),\n",
25 |     "        2: lambda t: -1.75 - 0.15 * np.log(t)},\n",
26 |     "    \"beta\": {\n",
27 |     "        1: -np.log([0.8, 3, 3, 2.5, 2]),\n",
28 |     "        2: -np.log([1, 3, 4, 3, 2])}}\n",
29 |     "\n",
30 |     "patients_df = generate_quick_start_df(n_patients=50000, n_cov=5, d_times=30, j_events=2, pid_col='pid', \n",
31 |     "                                      seed=0, censoring_prob=0.8, real_coef_dict=real_coef_dict)\n",
32 |     "\n",
33 |     "train_df, test_df = train_test_split(patients_df, test_size=0.2)\n",
34 |     "\n",
35 |     "# DataExpansionFitter Usage\n",
36 |     "from pydts.fitters import DataExpansionFitter\n",
37 |     "fitter = DataExpansionFitter()\n",
38 |     "fitter.fit(df=train_df.drop(['C', 'T'], axis=1))\n",
39 |     "\n",
40 |     "pred_df = fitter.predict_cumulative_incident_function(test_df.drop(['J', 'T', 'C', 'X'], axis=1))\n",
41 |     "\n",
42 |     "# TwoStagesFitter Usage\n",
43 |     "from pydts.fitters import TwoStagesFitter\n",
44 |     "new_fitter = TwoStagesFitter()\n",
45 |     "new_fitter.fit(df=train_df.drop(['C', 'T'], axis=1))\n",
46 |     "\n",
47 |     "pred_df = new_fitter.predict_cumulative_incident_function(test_df.drop(['J', 'T', 'C', 'X'], axis=1))\n",
48 |     "\n",
49 |     "# Training with Regularization\n",
50 |     "L1_regularized_fitter = TwoStagesFitter()\n",
51 |     "fit_beta_kwargs = {'model_kwargs': {'penalizer': 0.003, 'l1_ratio': 1}}\n",
52 |     "L1_regularized_fitter.fit(df=train_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)\n",
53 |     "\n",
54 |     "L2_regularized_fitter = TwoStagesFitter()\n",
55 |     "fit_beta_kwargs = {'model_kwargs': {'penalizer': 0.003, 'l1_ratio': 0}}\n",
56 |     "L2_regularized_fitter.fit(df=train_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)\n",
57 |     "\n",
58 |     "EN_regularized_fitter = TwoStagesFitter()\n",
59 |     "fit_beta_kwargs = {'model_kwargs': {'penalizer': 0.003, 'l1_ratio': 0.5}}\n",
60 |     "EN_regularized_fitter.fit(df=train_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)"
61 |    ]
62 |   },
63 |   {
64 |    "cell_type": "code",
65 |    "execution_count": null,
66 |    "id": "b05f4b71",
67 |    "metadata": {},
68 |    "outputs": [],
69 |    "source": []
70 |   }
71 |  ],
72 |  "metadata": {
73 |   "kernelspec": {
74 |    "display_name": "Python 3 (ipykernel)",
75 |    "language": "python",
76 |    "name": "python3"
77 |   },
78 |   "language_info": {
79 |    "codemirror_mode": {
80 |     "name": "ipython",
81 |     "version": 3
82 |    },
83 |    "file_extension": ".py",
84 |    "mimetype": "text/x-python",
85 |    "name": "python",
86 |    "nbconvert_exporter": "python",
87 |    "pygments_lexer": "ipython3",
88 |    "version": "3.8.2"
89 |   }
90 |  },
91 |  "nbformat": 4,
92 |  "nbformat_minor": 5
93 | }
94 |
-------------------------------------------------------------------------------- /docs/api/cross_validation.md: -------------------------------------------------------------------------------- 1 | ::: pydts.cross_validation.TwoStagesCV 2 | ::: pydts.cross_validation.TwoStagesCVExact 3 | ::: pydts.cross_validation.PenaltyGridSearchCV 4 | ::: pydts.cross_validation.PenaltyGridSearchCVExact 5 | -------------------------------------------------------------------------------- /docs/api/data_expansion_fitter.md: -------------------------------------------------------------------------------- 1 | ::: pydts.fitters.DataExpansionFitter -------------------------------------------------------------------------------- /docs/api/evaluation.md: -------------------------------------------------------------------------------- 1 | ::: pydts.evaluation -------------------------------------------------------------------------------- /docs/api/event_times_sampler.md: -------------------------------------------------------------------------------- 1 | ::: pydts.data_generation.EventTimesSampler -------------------------------------------------------------------------------- /docs/api/model_selection.md: -------------------------------------------------------------------------------- 1 | ::: pydts.model_selection.PenaltyGridSearch 2 | ::: pydts.model_selection.PenaltyGridSearchExact -------------------------------------------------------------------------------- /docs/api/screening.md: -------------------------------------------------------------------------------- 1 | ::: pydts.screening.SISTwoStagesFitter 2 | ::: pydts.screening.SISTwoStagesFitterExact -------------------------------------------------------------------------------- /docs/api/two_stages_fitter.md: -------------------------------------------------------------------------------- 1 | ::: pydts.fitters.TwoStagesFitter -------------------------------------------------------------------------------- /docs/api/two_stages_fitter_exact.md: -------------------------------------------------------------------------------- 1 | ::: pydts.fitters.TwoStagesFitterExact -------------------------------------------------------------------------------- /docs/api/utils.md: -------------------------------------------------------------------------------- 1 | ::: pydts.utils.get_expanded_df -------------------------------------------------------------------------------- /docs/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomer1812/pydts/578929a457c111efe009d3461aab531b793b33d0/docs/icon.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | [![pypi version](https://img.shields.io/pypi/v/pydts)](https://pypi.org/project/pydts/) 2 | [![Tests](https://github.com/tomer1812/pydts/workflows/Tests/badge.svg)](https://github.com/tomer1812/pydts/actions?workflow=Tests) 3 | [![documentation](https://img.shields.io/badge/docs-mkdocs%20material-blue.svg?style=flat)](https://tomer1812.github.io/pydts) 4 | [![codecov](https://codecov.io/gh/tomer1812/pydts/branch/main/graph/badge.svg)](https://codecov.io/gh/tomer1812/pydts) 5 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.15296343.svg)](https://doi.org/10.5281/zenodo.15296343) 6 | 7 | # Discrete Time Survival Analysis 8 | A Python package for discrete-time survival data analysis with competing risks. 
9 |
10 | ![PyDTS](icon.png)
11 |
12 | [Tomer Meir](https://tomer1812.github.io/), [Rom Gutman](https://github.com/RomGutman), [Malka Gorfine](https://www.tau.ac.il/~gorfinem/) 2022
13 |
14 | ## Installation
15 | ```console
16 | pip install pydts
17 | ```
18 |
19 | ## Quick Start
20 |
21 | ```python
22 | from pydts.fitters import TwoStagesFitter
23 | from pydts.examples_utils.generate_simulations_data import generate_quick_start_df
24 |
25 | patients_df = generate_quick_start_df(n_patients=10000, n_cov=5, d_times=14, j_events=2, pid_col='pid', seed=0)
26 |
27 | fitter = TwoStagesFitter()
28 | fitter.fit(df=patients_df.drop(['C', 'T'], axis=1))
29 | fitter.print_summary()
30 | ```
31 |
32 | ## Examples
33 | 1. [Usage Example](https://tomer1812.github.io/pydts/UsageExample-Intro/)
34 | 2. [Hospital Length of Stay Simulation Example](https://tomer1812.github.io/pydts/SimulatedDataset/)
35 |
36 | ## Citation
37 | If you find PyDTS useful, please cite:
38 |
39 | ```bibtex
40 | @article{Meir_PyDTS_2022,
41 | author = {Meir, Tomer and Gutman, Rom and Gorfine, Malka},
42 | doi = {10.48550/arXiv.2204.05731},
43 | title = {{PyDTS: A Python Package for Discrete Time Survival Analysis with Competing Risks}},
44 | url = {https://arxiv.org/abs/2204.05731},
45 | year = {2022}
46 | }
47 |
48 | @article{Meir_Gorfine_DTSP_2025,
49 | author = {Meir, Tomer and Gorfine, Malka},
50 | doi = {10.1093/biomtc/ujaf040},
51 | title = {{Discrete-Time Competing-Risks Regression with or without Penalization}},
52 | year = {2025},
53 | journal = {Biometrics},
54 | volume = {81},
55 | number = {2},
56 | url = {https://academic.oup.com/biometrics/article/81/2/ujaf040/8120014},
57 | }
58 | ```
59 |
60 | and please consider starring the project [on GitHub](https://github.com/tomer1812/pydts)
61 |
62 | ## How to Contribute
63 | 1. Open GitHub issues to suggest new features or to report bugs/errors
64 | 2. Contact Tomer or Rom if you want to add a usage example to the documentation
65 | 3. If you want to become a developer (thank you, we appreciate it!) - please contact Tomer or Rom for developers' on-boarding
66 |
67 | Tomer Meir: tomer1812@gmail.com, Rom Gutman: rom.gutman1@gmail.com
68 |
-------------------------------------------------------------------------------- /docs/intro.md: --------------------------------------------------------------------------------
1 | # Introduction
2 |
3 | Based on
4 |
5 | "PyDTS: A Python Package for Discrete-Time Survival Analysis with Competing Risks"
6 |
7 | Tomer Meir\*, Rom Gutman\*, and Malka Gorfine (2022) [[1]](#1).
8 |
9 | and
10 |
11 | "Discrete-time Competing-Risks Regression with or without Penalization"
12 |
13 | Tomer Meir and Malka Gorfine (2023) [[2]](#2).
14 |
15 | ## Discrete-data survival analysis
16 | Discrete-data survival analysis refers to the case where data can only take values over a discrete grid. Sometimes, events can only occur at regular, discrete points in time. For example, in the United States a change in party controlling the presidency only occurs quadrennially in the month of January [[3]](#3). In other situations events may occur at any point in time, but available data record only the particular interval of time in which each event occurs. For example, death from cancer measured by months since time of diagnosis [[4]](#4), or length of stay in hospital recorded on a daily basis.
It is well-known that naively using standard continuous-time models (even after correcting for ties) with discrete-time data may result in biased estimators of the discrete-time model parameters.
17 |
18 | ## Competing events
19 |
20 | Competing events arise when individuals are susceptible to several types of events but can experience at most one event. For example, competing risks for hospital length of stay are discharge and in-hospital death. Occurrence of one of these events precludes us from observing the other event on this patient. Another classical example of competing risks is cause-specific mortality, such as death from heart disease, death from cancer and death from other causes [[5]](#5), [[6]](#6).
21 |
22 |
23 | PyDTS is an open source Python package which implements tools for discrete-time survival analysis with competing risks.
24 |
25 |
26 | ## References
27 | [1]
28 | Meir, Tomer\*, Gutman, Rom\*, and Gorfine, Malka,
29 | "PyDTS: A Python Package for Discrete-Time Survival Analysis with Competing Risks"
30 | (2022)
31 |
32 | [2]
33 | Meir, Tomer and Gorfine, Malka,
34 | "Discrete-time Competing-Risks Regression with or without Penalization", Biometrics (2025), doi: 10.1093/biomtc/ujaf040
35 |
36 | [3]
37 | Allison, Paul D.
38 | "Discrete-Time Methods for the Analysis of Event Histories"
39 | Sociological Methodology (1982),
40 | doi: 10.2307/270718
41 |
42 | [4]
43 | Lee, Minjung and Feuer, Eric J. and Fine, Jason P.
44 | "On the analysis of discrete time competing risks data"
45 | Biometrics (2018)
46 | doi: 10.1111/biom.12881
47 |
48 | [5]
49 | Kalbfleisch, John D. and Prentice, Ross L.
50 | "The Statistical Analysis of Failure Time Data" 2nd Ed.,
51 | Wiley (2011)
52 | ISBN: 978-1-118-03123-0
53 |
54 | [6]
55 | Klein, John P. and Moeschberger, Melvin L.
56 | "Survival Analysis",
57 | Springer (2003)
58 | ISBN: 978-0-387-95399-1
59 |
-------------------------------------------------------------------------------- /docs/methods.md: --------------------------------------------------------------------------------
1 | # Methods
2 |
3 |
4 | ## Definitions
5 |
6 | We let $T$ denote a discrete event time that can take on only the values $\{1,2,...,d\}$ and $J$ denote the type of event, $J \in \{1,\ldots,M\}$. Consider a $p \times 1$ vector of baseline covariates $Z$. A general discrete cause-specific hazard function is of the form
7 | $$
8 | \lambda_j(t|Z) = \Pr(T=t,J=j|T\geq t, Z) \hspace{0.3cm} t \in \{1,2,...,d\} \hspace{0.3cm} j=1,\ldots,M \, .
9 | $$
10 | A popular semi-parametric model of the above hazard function based on a transformation regression model is of the form
11 | $$
12 | h(\lambda_{j}(t|Z)) = \alpha_{jt} + Z^T \beta_j \hspace{0.3cm} t \in \{1,2,...,d\} \hspace{0.3cm} j=1,\ldots,M
13 | $$
14 | such that $h$ is a known function [[2, and references therein]](#2). The total number of parameters in the model is $M(d+p)$. The logit function $h(a)=\log \{ a/(1-a) \}$ yields
15 | \begin{equation}\label{eq:logis}
16 | \lambda_j(t|Z)=\frac{\exp(\alpha_{jt}+Z^T\beta_j)}{1+\exp(\alpha_{jt}+Z^T\beta_j)} \, .
17 | \end{equation}
18 | It should be noted that leaving $\alpha_{jt}$ unspecified is analogous to having an unspecified baseline hazard function in the Cox proportional hazard model [[3]](#3), and thus we consider the above as a semi-parametric model.
19 |
20 |
21 | Let $S(t|Z) = \Pr(T>t|Z)$ be the overall survival given $Z$.
Then, the probability of experiencing an event of type $j$ at time $t$ equals
22 | $$
23 | \Pr(T=t,J=j|Z)=\lambda_j(t|Z) \prod_{k=1}^{t-1} \left\lbrace 1-
24 | \sum_{j'=1}^M\lambda_{j'}(k|Z) \right\rbrace
25 | $$
26 | and the cumulative incidence function (CIF) of cause $j$ is given by
27 | $$
28 | F_j(t|Z) = \Pr(T \leq t, J=j|Z) = \sum_{m=1}^{t} \lambda_j(m|Z) S(m-1|Z) = \sum_{m=1}^{t}\lambda_j(m|Z) \prod_{k=1}^{m-1} \left\lbrace 1-\sum_{j'=1}^M\lambda_{j'}(k|Z) \right\rbrace \, .
29 | $$
30 | Finally, the marginal probability of event type $j$ (marginally with respect to the time of event), given $Z$, equals
31 | $$
32 | \Pr(J=j|Z) = \sum_{m=1}^{d} \lambda_j(m|Z) \prod_{k=1}^{m-1} \left\lbrace 1-\sum_{j'=1}^M\lambda_{j'}(k|Z) \right\rbrace \, .
33 | $$
34 | In the next section we provide a fast estimation technique for the parameters $\{\alpha_{j1},\ldots,\alpha_{jd},\beta_j^T \, ; \, j=1,\ldots,M\}$.
35 |
36 |
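As a concrete illustration of the CIF formula above, here is a minimal NumPy sketch (illustrative only, not PyDTS internals; fitted PyDTS models expose CIF predictions through methods such as `predict_cumulative_incident_function`). It assumes the cause-specific hazards $\lambda_j(t|Z)$ of a single observation are arranged in an $M \times d$ array:

```python
import numpy as np

def cif(hazards: np.ndarray, j: int, t: int) -> float:
    """F_j(t|Z) from an (M x d) array where hazards[j-1, k-1] = lambda_j(k|Z)."""
    overall = 1.0 - hazards.sum(axis=0)                  # 1 - sum_j' lambda_j'(k|Z), k = 1,...,d
    surv = np.concatenate(([1.0], np.cumprod(overall)))  # surv[m] = S(m|Z), with S(0|Z) = 1
    m = np.arange(1, t + 1)
    return float(np.sum(hazards[j - 1, m - 1] * surv[m - 1]))  # sum of lambda_j(m|Z) * S(m-1|Z)
```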
37 | ## The Collapsed Log-Likelihood Approach and the Proposed Estimators
38 | For simplicity of presentation, we assume two competing events, i.e., $M=2$, and our goal is to estimate $\{\alpha_{11},\ldots,\alpha_{1d},\beta_1^T,\alpha_{21},\ldots,\alpha_{2d},\beta_2^T\}$ along with the standard errors of the estimators. The data at hand consist of $n$ independent observations, each with $(X_i,\delta_i,J_i,Z_i)$ where $X_i=\min(C_i,T_i)$, $C_i$ is a right-censoring time,
39 | $\delta_i=I(X_i=T_i)$ is the event indicator and $J_i\in\{0,1,2\}$, where $J_i=0$ if and only if $\delta_i=0$. Assume that given the covariates, the censoring and failure time are independent and non-informative. Then, the likelihood function is proportional to
40 | $$
41 | L = \prod_{i=1}^n \left\lbrace\frac{\lambda_1(X_i|Z_i)}{1-\lambda_1(X_i|Z_i)-\lambda_2(X_i|Z_i)}\right\rbrace^{I(\delta_{1i}=1)} \left\lbrace\frac{\lambda_2(X_i|Z_i)}{1-\lambda_1(X_i|Z_i)-\lambda_2(X_i|Z_i)}\right\rbrace^{I(\delta_{2i}=1)} \prod_{k=1}^{X_i}\lbrace 1-\lambda_1(k|Z_i)-\lambda_2(k|Z_i)\rbrace
42 | $$
43 | or, equivalently,
44 | $$
45 | L = \prod_{i=1}^n \left[ \prod_{j=1}^2 \prod_{m=1}^{X_i} \left\lbrace \frac{\lambda_j(m|Z_i)}{1-\lambda_1(m|Z_i)-\lambda_2(m|Z_i)}\right\rbrace^{\delta_{jim}}\right] \prod_{k=1}^{X_i}\lbrace 1-\lambda_1(k|Z_i)-\lambda_2(k|Z_i)\rbrace
46 | $$
47 | where $\delta_{jim}$ equals one if subject $i$ experienced event type $j$ at time $m$, and 0 otherwise. Clearly $L$ cannot be decomposed into separate likelihoods for each cause-specific
48 | hazard function $\lambda_j$.
49 | The log-likelihood becomes
50 | $$
51 | \log L = \sum_{i=1}^n \left[ \sum_{j=1}^2 \sum_{m=1}^{X_i} \left[ \delta_{jim} \log \lambda_j(m|Z_i) - \delta_{jim}\log\lbrace 1-\lambda_1(m|Z_i)-\lambda_2(m|Z_i)\rbrace\right] \right.
52 | +\left.\sum_{k=1}^{X_i}\log \lbrace 1-\lambda_1(k|Z_i)-\lambda_2(k|Z_i)\rbrace \right]
53 | $$
54 | $$
55 | = \sum_{i=1}^n \sum_{m=1}^{X_i} \left[ \delta_{1im} \log \lambda_1(m|Z_i)+\delta_{2im} \log \lambda_2(m|Z_i) \right. +\left. \lbrace 1-\delta_{1im}-\delta_{2im}\rbrace \log\lbrace 1-\lambda_1(m|Z_i)-\lambda_2(m|Z_i)\rbrace \right] \, .
56 | $$
57 |
58 | Instead of maximizing the $M(d+p)$ parameters simultaneously based on the above log-likelihood, the collapsed log-likelihood of Lee et al. [[4]](#4) can be adopted. Specifically, the data are expanded such that for each observation $i$ the expanded dataset includes $X_i$ rows, one row for each time $t$, $t \leq X_i$. At each time point $t$ the expanded data are conditionally multinomial with one of three possible outcomes $\{\delta_{1it},\delta_{2it},1-\delta_{1it}-\delta_{2it}\}$, as in [Table 1](#tbl:expanded).
59 |
60 | Table 1: Original and expanded datasets with $M = 2$ competing events [[Lee et al. (2018)]](#4).
61 |
62 |
63 | | $i$ | $X_i$ | $J_i$ | $Z_i$ | $i$ | $\tilde{X}_i$ | $\delta_{1it}$ | $\delta_{2it}$ | $1 - \delta_{1it} - \delta_{2it}$ | $Z_i$ |
64 | |-----|--------|--------------|-------|-----|----------------|----------------|----------------|-------------------------------|-------|
65 | | 1   | 2      | 1            | $Z_1$ | 1   | 1              | 0              | 0              | 1                             | $Z_1$ |
66 | |     |        |              |       | 1   | 2              | 1              | 0              | 0                             | $Z_1$ |
67 | | 2   | 3      | 2            | $Z_2$ | 2   | 1              | 0              | 0              | 1                             | $Z_2$ |
68 | |     |        |              |       | 2   | 2              | 0              | 0              | 1                             | $Z_2$ |
69 | |     |        |              |       | 2   | 3              | 0              | 1              | 0                             | $Z_2$ |
70 | | 3   | 3      | 0            | $Z_3$ | 3   | 1              | 0              | 0              | 1                             | $Z_3$ |
71 | |     |        |              |       | 3   | 2              | 0              | 0              | 1                             | $Z_3$ |
72 | |     |        |              |       | 3   | 3              | 0              | 0              | 1                             | $Z_3$ |
73 |
74 |
75 |
76 |
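In PyDTS, this expansion step is implemented in `pydts.utils.get_expanded_df`. A small usage sketch (the toy values and the column names `pid`, `X`, `J` mirror Table 1 and the package defaults; the exact names of the generated indicator columns may differ):

```python
import pandas as pd
from pydts.utils import get_expanded_df

# Toy data in the spirit of Table 1: X is the observed time, J the event type
# (0 = right censored), pid the observation id, and Z1 a baseline covariate.
df = pd.DataFrame({'pid': [1, 2, 3],
                   'X': [2, 3, 3],
                   'J': [1, 2, 0],
                   'Z1': [0.5, -1.2, 0.3]})

# One row per (pid, t) with t <= X_i, plus per-cause event-indicator columns
expanded_df = get_expanded_df(df=df, event_type_col='J', duration_col='X', pid_col='pid')
```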
77 | Then, for estimating $\{\alpha_{11},\ldots,\alpha_{1d},\beta_1^T\}$, we combine $\delta_{2it}$ and $1-\delta_{1it}-\delta_{2it}$, and the collapsed log-likelihood for cause $J=1$ based on a binary regression model with $\delta_{1it}$ as the outcome is given by
78 | $$
79 | \log L_1 = \sum_{i=1}^n \sum_{m=1}^{X_i}\left[ \delta_{1im} \log \lambda_1(m|Z_i)+(1-\delta_{1im})\log \lbrace 1-\lambda_1(m|Z_i)\rbrace \right] \, .
80 | $$
81 | Similarly, the collapsed log-likelihood for cause $J=2$ based on a binary regression model with $\delta_{2it}$ as the outcome becomes
82 | $$
83 | \log L_2 = \sum_{i=1}^n \sum_{m=1}^{X_i}\left[ \delta_{2im} \log \lambda_2(m|Z_i)+(1-\delta_{2im})\log \lbrace 1-\lambda_2(m|Z_i)\rbrace \right]
84 | $$
85 | and one can fit the two models separately.
86 |
87 | In general, for $M$ competing events,
88 | the estimators of $\{\alpha_{j1},\ldots,\alpha_{jd},\beta_j^T\}$, $j=1,\ldots,M$, are the respective values that maximize
89 | $$
90 | \log L_j = \sum_{i=1}^n \sum_{m=1}^{X_i}\left[ \delta_{jim} \log \lambda_j(m|Z_i)+(1-\delta_{jim})\log \{1-\lambda_j(m|Z_i)\} \right] \, .
91 | $$
92 | Namely, each maximization $j$, $j=1,\ldots,M$, consists of maximizing $d + p$ parameters simultaneously.
93 |
94 | ### Proposed Estimators
95 | Alternatively, we propose the following simpler and faster estimation procedure, with a negligible efficiency loss, if any. Our idea exploits the close relationship between conditional logistic regression analysis and stratified Cox regression analysis [[5]](#5). We propose to estimate each $\beta_j$ separately, and given $\beta_j$, $\alpha_{jt}$, $t=1,\ldots,d$, are separately estimated. In particular, the proposed estimating procedure consists of the following two speedy steps:
96 | #### Step 1.
97 | Use the expanded dataset and estimate each vector $\beta_j$, $j \in \{1,\ldots, M\}$, by a simple conditional logistic regression, conditioning on the event time $X$, using a stratified Cox analysis.
98 |
99 | #### Step 2.
100 | Given the estimators $\widehat{\beta}_j$, $j \in \{1,\ldots, M\}$, of Step 1, use the original (un-expanded) data and estimate each $\alpha_{jt}$, $j \in \{1,\ldots,M\}$, $t=1,\ldots,d$, separately, by
101 |
102 | $$\widehat{\alpha}_{jt} = \operatorname{argmin}_{a} \left\lbrace \frac{1}{y_t} \sum_{i=1}^n I(X_i \geq t) \frac{ \exp(a+Z_i^T \widehat{\beta}_j)}{1 + \exp(a + Z_i^T \widehat{\beta}_j)} - \frac{n_{tj}}{y_t} \right\rbrace ^2 $$
103 |
104 | where $y_t=\sum_{i=1}^n I(X_i \geq t)$ and $n_{tj}=\sum_{i=1}^n I(X_i = t, J_i=j)$.
105 |
106 | The above equation consists of minimizing the squared distance between the observed proportion of failures of type $j$ at time $t$
107 | ($n_{tj}/y_t$) and the expected proportion of failures under the model defined above for $\lambda_j$, evaluated at $\widehat{\beta}_j$.
108 | The simulation results of the Simple Simulation section reveal that the above two-step procedure performs well in terms of bias and provides standard errors similar to those of [[3]](#3), while reducing the computational time by a factor of 1.5-3.5, depending on $d$. Standard errors of $\widehat{\beta}_j$, $j \in \{1,\ldots,M\}$, can be derived directly from the stratified Cox analysis.
109 |
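To make Step 2 concrete, here is a minimal sketch of the univariate minimization (illustrative only, not the package's internal implementation). It assumes `Z` is an $n \times p$ NumPy array and `X`, `J` are arrays of observed times and event types:

```python
import numpy as np
from scipy.optimize import minimize_scalar

def estimate_alpha_jt(Z, X, J, beta_j, j, t):
    at_risk = X >= t                        # I(X_i >= t)
    y_t = at_risk.sum()                     # number at risk at time t
    n_tj = ((X == t) & (J == j)).sum()      # number of type-j events at time t
    lin = Z[at_risk] @ beta_j               # Z_i^T beta_j over the risk set

    def objective(a):
        # squared distance between expected and observed proportions of type-j events
        expected = np.mean(np.exp(a + lin) / (1.0 + np.exp(a + lin)))
        return (expected - n_tj / y_t) ** 2

    return minimize_scalar(objective).x
```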
110 | ### Time-dependent covariates
111 |
112 | Similarly to the continuous-time Cox model, the simplest way to code time-dependent covariates uses intervals of time [[Therneau et al. (2000)]](#6). Then, the data are encoded by breaking the individual’s time into multiple time intervals, with one row of data for each interval. Hence, combining this data expansion step with the expansion demonstrated in [Table 1](#tbl:expanded) is straightforward.
113 |
114 | ### Regularized regression models
115 |
116 | Penalized regression methods, such as lasso, adaptive lasso, and elastic net [[Hastie et al. 2009]](#7), place a constraint on the size of the regression coefficients. The estimation procedure of [[Meir and Gorfine (2023)]](#8), which separates the estimation of $\beta_j$ and $\alpha_{jt}$, can easily incorporate such constraints in Lagrangian form by minimizing
117 |
118 | $$
119 | -\log L_j^c(\beta_j) + \eta_j P(\beta_j) \, , \quad j=1,\ldots,M \, ,
120 | $$
121 |
122 | where $P$ is a penalty function and $\eta_j \geq 0$ are shrinkage tuning parameters. The parameters $\alpha_{jt}$ are then estimated in the second step, once the regularized estimates of $\beta_j$ are obtained.
123 |
124 | Clearly, any regularized Cox regression routine can be used for estimating $\beta_j$, $j=1,\ldots,M$, based on the above equation; for example, the `CoxPHFitter` of the `lifelines` Python package [[Davidson-Pilon (2019)]](#9) with penalization.
125 |
126 |
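For example, with the PyDTS `TwoStagesFitter`, penalty parameters are passed to the underlying Cox routine through `fit_beta_kwargs`, as in the package's usage examples (here `train_df` is assumed to be a prepared dataset as in the Quick Start, and the penalizer value is arbitrary):

```python
from pydts.fitters import TwoStagesFitter

# l1_ratio=1 gives lasso, l1_ratio=0 gives ridge, and 0 < l1_ratio < 1 gives elastic net
L1_regularized_fitter = TwoStagesFitter()
fit_beta_kwargs = {'model_kwargs': {'penalizer': 0.003, 'l1_ratio': 1}}
L1_regularized_fitter.fit(df=train_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)
```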
127 | ### Sure Independence Screening
128 |
129 | When the number of available covariates greatly exceeds the number of observations (as is common in genetic datasets, for example), i.e., the ultra-high-dimensional setting, most regularized methods suffer from the curse of dimensionality, high variance, and overfitting [[Hastie et al. (2009)]](#7).
130 | Sure Independence Screening (SIS) is a marginal screening technique designed to filter out uninformative covariates.
131 | Penalized regression methods can be applied after the marginal screening process to the remaining covariates.
132 |
133 | We start the SIS procedure by ranking all the covariates using a utility measure between the response and each covariate, and then retaining only covariates whose estimated coefficients exceed a threshold value.
134 | We focus on SIS and SIS followed by lasso (SIS-L) [[Fan et al. (2010); Saldana and Feng (2018)]](#10) within the proposed two-step procedure.
135 |
136 | We start by fitting a marginal regression for each covariate by maximizing:
137 |
138 | $$
139 | L_j^{\mathcal{C}}(\beta_{jr}) \quad \text{for } j=1,\ldots,M, \quad r=1,\ldots,p
140 | $$
141 |
142 | where $\boldsymbol{\beta}_j = (\beta_{j1},\ldots,\beta_{jp})^T$.
143 | Then we rank the features based on the magnitude of their marginal regression coefficients.
144 | The selected sets of variables are given by:
145 |
146 | $$
147 | \widehat{\mathcal{M}}_{j,w_n} = \left\{1 \leq k \leq p \, : \, |\widehat{\beta}_{jk}| \geq w_n \right\}, \quad j=1,\ldots,M,
148 | $$
149 |
150 | where $w_n$ is a threshold value.
151 | We adopt the data-driven threshold of [[Saldana and Feng (2018)]](#11).
152 | Given data of the form $\{X_i, \delta_i, J_i, \mathbf{Z}_i \, ; \, i = 1, \ldots, n\}$, a random permutation $\pi$ of $\{1,\ldots,n\}$ is used to decouple $\mathbf{Z}_i$ and $(X_i, \delta_i, J_i)$, so that the permuted data $\{X_i, \delta_i, J_i, \mathbf{Z}_{\pi(i)}\}$ follow a model where the covariates have no predictive power over the survival time of any event type.
153 |
154 | For the permuted data, we re-estimate the individual regression coefficients and obtain $\widehat{\beta}^*_{jr}$. The data-driven threshold is defined as:
155 |
156 | $$
157 | w_n = \max_{1 \leq j \leq M, \, 1 \leq k \leq p} |\widehat{\beta}^*_{jk}|.
158 | $$
159 |
160 | For the SIS-L procedure, lasso regularization is then applied in the first step of the two-step procedure to the set of covariates selected by SIS.
161 |
162 |
163 |
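PyDTS implements this screening procedure in `pydts.screening` (`SISTwoStagesFitter` and `SISTwoStagesFitterExact`). The permutation-based threshold itself can be sketched as follows (illustrative only; `fit_marginal_beta` is a hypothetical helper standing in for the single-covariate marginal fit of the first step):

```python
import numpy as np

def permutation_threshold(Z, X, J, events, fit_marginal_beta, seed=0):
    rng = np.random.default_rng(seed)
    Z_perm = Z[rng.permutation(len(X))]   # decouple covariates from (X, delta, J)
    # w_n = max over event types j and covariates r of |beta*_{jr}|
    return max(abs(fit_marginal_beta(Z_perm[:, r], X, J, j))
               for j in events for r in range(Z.shape[1]))
```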
164 | ## References
165 | [1]
166 | Meir, Tomer\*, Gutman, Rom\*, and Gorfine, Malka
167 | "PyDTS: A Python Package for Discrete-Time Survival Analysis with Competing Risks"
168 | (2022)
169 |
170 | [2]
171 | Allison, Paul D.
172 | "Discrete-Time Methods for the Analysis of Event Histories"
173 | Sociological Methodology (1982),
174 | doi: 10.2307/270718
175 |
176 | [3]
177 | Cox, D. R.
178 | "Regression Models and Life-Tables"
179 | Journal of the Royal Statistical Society: Series B (Methodological) (1972)
180 | doi: 10.1111/j.2517-6161.1972.tb00899.x
181 |
182 | [4]
183 | Lee, Minjung and Feuer, Eric J. and Fine, Jason P.
184 | "On the analysis of discrete time competing risks data"
185 | Biometrics (2018)
186 | doi: 10.1111/biom.12881
187 |
188 | [5]
189 | Prentice, Ross L and Breslow, Norman E
190 | "Retrospective studies and failure time models"
191 | Biometrika (1978)
192 | doi: 10.1093/biomet/65.1.153
193 |
194 | [6]
195 | Therneau, Terry M and Grambsch, Patricia M,
196 | "Modeling Survival Data: Extending the Cox Model", Springer-Verlag,
197 | (2000)
198 |
199 | [7]
200 | Hastie, Trevor and Tibshirani, Robert and Friedman, Jerome H,
201 | "The Elements of Statistical Learning: Data Mining, Inference, and Prediction.", Springer-Verlag,
202 | (2009)
203 |
204 | [8]
205 | Meir, Tomer and Gorfine, Malka,
206 | "Discrete-time Competing-Risks Regression with or without Penalization", Biometrics (2025), doi: 10.1093/biomtc/ujaf040
207 |
208 |
209 | [9]
210 | Davidson-Pilon, Cameron,
211 | "lifelines: Survival Analysis in Python", Journal of Open Source Software,
212 | (2019)
213 |
214 | [10]
215 | Fan, J and Feng, Y and Wu, Y,
216 | "High-dimensional variable selection for Cox’s proportional hazards model",
217 | Institute of Mathematical Statistics,
218 | (2010)
219 |
220 | [11]
221 | Saldana, DF and Feng, Y,
222 | "SIS: An R package for sure independence screening in ultrahigh-dimensional statistical models",
223 | Journal of Statistical Software,
224 | (2018)
225 |
-------------------------------------------------------------------------------- /docs/methodsevaluation.md: --------------------------------------------------------------------------------
1 | # Evaluation Measures
2 | Let
3 |
4 | $$
5 | \pi_{ij}(t) = \widehat{\Pr}(T_i=t, J_i=j \mid Z_i) = \widehat{\lambda}_j (t \mid Z_i) \widehat{S}(t-1 \mid Z_i)
6 | $$
7 |
8 | and
9 |
10 | $$
11 | D_{ij} (t) = I(T_i = t, J_i = j) \, .
12 | $$
13 |
14 | The cause-specific incidence/dynamic area under the receiver operating characteristics curve (AUC) is defined and estimated in the spirit of Heagerty and Zheng (2005) and Blanche et al. (2015) as the probability that a random observation with observed event $j$ at time $t$ has a higher risk prediction for cause $j$ than a randomly selected observation $m$, at risk at time $t$, without the observed event $j$ at time $t$. Namely,
15 |
16 | $$
17 | \mbox{AUC}_j(t) = \Pr (\pi_{ij}(t) > \pi_{mj}(t) \mid D_{ij} (t) = 1, D_{mj} (t) = 0, T_m \geq t) \, .
18 | $$
19 |
20 | In the presence of censored data and under the assumption that the censoring is independent of the failure time and observed covariates, an inverse probability censoring weighting (IPCW) estimator of $\mbox{AUC}_j(t)$ becomes
21 |
22 | $$
23 | \widehat{\mbox{AUC}}_j (t) = \frac{\sum_{i=1}^{n}\sum_{m=1}^{n} D_{ij}(t)(1-D_{mj}(t))I(X_m \geq t) W_{ij}(t) W_{mj}(t) \{I(\pi_{ij}(t) > \pi_{mj}(t))+0.5I(\pi_{ij}(t)=\pi_{mj}(t))\}}{\sum_{i=1}^{n}\sum_{m=1}^{n} D_{ij}(t)(1-D_{mj}(t))I(X_m \geq t) W_{ij}(t) W_{mj}(t)}
24 | $$
25 |
26 | which can be simplified to
27 |
28 | $$
29 | \widehat{\mbox{AUC}}_j (t) = \frac{\sum_{i=1}^{n}\sum_{m=1}^{n} D_{ij}(t)(1-D_{mj}(t))I(X_m \geq t) \{I(\pi_{ij}(t) > \pi_{mj}(t))+0.5I(\pi_{ij}(t)=\pi_{mj}(t))\}}{\sum_{i=1}^{n}\sum_{m=1}^{n} D_{ij}(t)(1-D_{mj}(t))I(X_m \geq t)}
30 | $$
31 |
32 | where
33 |
34 | $$
35 | W_{ij}(t) = \frac{D_{ij}(t)}{\widehat{G}_C(T_i)} + I(X_i \geq t)\frac{1-D_{ij}(t)}{\widehat{G}_C(t)} = \frac{D_{ij}(t)}{\widehat{G}_C(t)} + I(X_i \geq t)\frac{1-D_{ij}(t)}{\widehat{G}_C(t)} = I(X_i \geq t) / \widehat{G}_C(t)
36 | $$
37 |
38 | and $\widehat{G}_C(\cdot)$ is the estimated survival function of the censoring (e.g., the Kaplan-Meier estimator). Interestingly, the IPCWs have no effect on $\widehat{\mbox{AUC}}_j (t)$.
39 |
40 | An integrated cause-specific AUC can be estimated as the weighted sum
41 |
42 | $$
43 | \widehat{\mbox{AUC}}_j = \sum_t \widehat{\mbox{AUC}}_j (t) w_j (t)
44 | $$
45 |
46 | and we adopt a simple data-driven weight function of the form
47 |
48 | $$
49 | w_j(t) = \frac{N_j(t)}{\sum_t N_j(t)}
50 | $$
51 | where $N_j(t) = \sum_{i=1}^{n} I(X_i = t, J_i = j)$ is the number of observed type-$j$ events at time $t$.
52 |
53 | A global AUC can be defined as
54 |
55 | $$
56 | \widehat{\mbox{AUC}} = \sum_j \widehat{\mbox{AUC}}_j v_j
57 | $$
58 |
59 | where
60 |
61 | $$
62 | v_j = \frac{\sum_{t} N_j(t)}{ \sum_{j=1}^M \sum_{t} N_j(t) } \, .
63 | $$
64 |
65 | Another well-known performance measure is the Brier Score (BS). In the spirit of Blanche et al. (2015) we define
66 |
67 | $$
68 | \widehat{\mbox{BS}}_{j}(t) = \frac{1}{Y_{\cdot}(t)} {\sum_{i=1}^n W_{ij}(t) \left( D_{ij}(t) - \pi_{ij}(t)\right)^2} \, ,
69 | $$
70 | where $Y_{\cdot}(t) = \sum_{i=1}^{n} I(X_i \geq t)$ denotes the number of observations at risk at time $t$.
71 |
72 | An integrated cause-specific BS can be estimated by the weighted sum
73 |
74 | $$
75 | \widehat{\mbox{BS}}_{j} = \sum_t \widehat{\mbox{BS}}_{j}(t) w_j(t)
76 | $$
77 |
78 | and an estimated global BS is given by
79 |
80 | $$
81 | \widehat{\mbox{BS}} = \sum_j \widehat{\mbox{BS}}_{j} v_j \, .
82 | $$
83 |
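These measures are implemented in the `pydts.evaluation` module. As an illustration, the simplified $\widehat{\mbox{AUC}}_j(t)$ above can be sketched in NumPy as follows (a sketch of the formula, not the package's code; `pi` holds the predictions $\pi_{ij}(t)$ for all observations):

```python
import numpy as np

def auc_jt(pi, X, J, j, t):
    case = (X == t) & (J == j)      # observations with D_ij(t) = 1
    ctrl = (X >= t) & ~case         # at risk at t, without a type-j event at t
    if case.sum() == 0 or ctrl.sum() == 0:
        return np.nan
    diffs = pi[case][:, None] - pi[ctrl][None, :]   # all (case, control) pairs
    return np.mean(diffs > 0) + 0.5 * np.mean(diffs == 0)
```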
-------------------------------------------------------------------------------- /docs/methodsintro.md: --------------------------------------------------------------------------------
1 | # Methods
2 |
3 | In this section, we outline the statistical background for the tools incorporated in PyDTS. We commence with some definitions, present the collapsed log-likelihood approach and the estimation procedure of Lee et al. (2018) [[4]](#4), introduce our estimation method [[1]](#1)-[[2]](#2), and conclude with evaluation metrics. For additional details, see the references.
4 |
5 |
6 | ## References
7 | [1]
8 | Meir, Tomer\*, Gutman, Rom\*, and Gorfine, Malka,
9 | "PyDTS: A Python Package for Discrete-Time Survival Analysis with Competing Risks"
10 | (2022)
11 |
12 | [2]
13 | Meir, Tomer and Gorfine, Malka,
14 | "Discrete-time Competing-Risks Regression with or without Penalization", Biometrics (2025), doi: 10.1093/biomtc/ujaf040
15 |
16 |
17 | [3]
18 | Allison, Paul D.
19 | "Discrete-Time Methods for the Analysis of Event Histories"
20 | Sociological Methodology (1982),
21 | doi: 10.2307/270718
22 |
23 | [4]
24 | Lee, Minjung and Feuer, Eric J. and Fine, Jason P.
25 | "On the analysis of discrete time competing risks data"
26 | Biometrics (2018)
27 | doi: 10.1111/biom.12881
28 |
29 | [5]
30 | Kalbfleisch, John D. and Prentice, Ross L.
31 | "The Statistical Analysis of Failure Time Data" 2nd Ed.,
32 | Wiley (2011)
33 | ISBN: 978-1-118-03123-0
34 |
35 | [6]
36 | Klein, John P. and Moeschberger, Melvin L.
37 | "Survival Analysis",
38 | Springer-Verlag (2003)
39 | ISBN: 978-0-387-95399-1
40 |
-------------------------------------------------------------------------------- /docs/models_params.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomer1812/pydts/578929a457c111efe009d3461aab531b793b33d0/docs/models_params.png
-------------------------------------------------------------------------------- /docs/requirements.txt: --------------------------------------------------------------------------------
1 | lifelines==0.26.4
2 | matplotlib==3.5.1
3 | numpy==1.23.4
4 | pandarallel==1.5.7
5 | pandas==1.4.1
6 | patsy==0.5.2
7 | pillow==9.0.1
8 | psutil==5.9.4
9 | scikit-learn==1.0.2
10 | scipy==1.8.0
11 | statsmodels==0.13.2
12 | threadpoolctl==3.1.0
13 | tqdm==4.63.0
14 | seaborn==0.11.2
15 | tableone==0.7.10
16 | pydts==0.7.8
-------------------------------------------------------------------------------- /mkdocs.yml: --------------------------------------------------------------------------------
1 | theme:
2 |   name: material
3 |   features:
4 |     - navigation.sections  # Sections are included in the navigation on the left.
5 |     - toc.integrate  # Table of contents is integrated on the left; does not appear separately on the right.
6 |     - header.autohide  # header disappears as you scroll
7 |
8 | site_name: PyDTS
9 | site_description: The documentation for the PyDTS software library.
10 | site_author: Tomer Meir, Rom Gutman
11 |
12 | repo_url: https://github.com/tomer1812/pydts/
13 | repo_name: tomer1812/pydts
14 | edit_uri: ""  # No edit button, as some of our pages are in /docs and some in /examples_utils via symlink, so it's impossible for them all to be accurate
15 |
16 | strict: false  # Allow warnings during the build process (set to true to fail the build on warnings)
17 |
18 | #markdown_extensions:
19 | #  - pymdownx.highlight:
20 | #      anchor_linenums: true
21 | #  - pymdownx.inlinehilite
22 | #  - pymdownx.arithmatex:  # Render LaTeX via MathJax
23 | #      generic: true
24 | #      inline_syntax: ['$', '$']
25 | #  - pymdownx.superfences  # Seems to enable syntax highlighting when used with the Material theme.
26 | #  - pymdownx.details  # Allowing hidden expandable regions denoted by ???
27 | #  - pymdownx.snippets:  # Include one Markdown file into another
28 | #      base_path: docs
29 | #  - admonition
30 | #  - toc:
31 | #      permalink: "¤"  # Adds a clickable permalink to each section heading
32 | #      toc_depth: 4  # Prevents h5, h6 from showing up in the TOC.
33 | # 34 | #extra_javascript: 35 | # - javascripts/mathjax.js 36 | # - https://polyfill.io/v3/polyfill.min.js?features=es6 37 | # - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 38 | 39 | markdown_extensions: 40 | - pymdownx.highlight: 41 | anchor_linenums: true 42 | - pymdownx.inlinehilite 43 | - pymdownx.arithmatex: 44 | generic: true 45 | - pymdownx.superfences 46 | - pymdownx.details 47 | - pymdownx.snippets: 48 | base_path: docs 49 | - admonition 50 | - toc: 51 | permalink: "¤" 52 | toc_depth: 4 53 | 54 | extra_javascript: 55 | - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 56 | 57 | nav: 58 | - Home: 'index.md' 59 | - Introduction: intro.md 60 | - Methods: 61 | - Introduction: methodsintro.md 62 | - Definitions and Estimation: methods.md 63 | - Evaluation Metrics: methodsevaluation.md 64 | - Examples: 65 | - Event Times Sampler: EventTimesSampler.ipynb 66 | - Estimation Example: 67 | - Introduction: UsageExample-Intro.ipynb 68 | - Data Preparation: UsageExample-DataPreparation.ipynb 69 | - Estimation with TwoStagesFitter: UsageExample-FittingTwoStagesFitter.ipynb 70 | - Estimation with DataExpansionFitter: UsageExample-FittingDataExpansionFitter.ipynb 71 | - Data Regrouping Example: UsageExample-RegroupingData.ipynb 72 | - Comparing the Estimation Methods: ModelsComparison.ipynb 73 | - Evaluation: PerformanceMeasures.ipynb 74 | - Regularization: Regularization.ipynb 75 | - Small Sample Size Example: UsageExample-FittingTwoStagesFitterExact-FULL.ipynb 76 | - Screening Example: UsageExample-SIS-SIS-L.ipynb 77 | - Hospitalization LOS Simulation: SimulatedDataset.ipynb 78 | - API: 79 | - The Two Stages Procedure of Meir and Gorfine (2023) - Efron: 'api/two_stages_fitter.md' 80 | - The Two Stages Procedure of Meir and Gorfine (2023) - Exact: 'api/two_stages_fitter_exact.md' 81 | - Data Expansion Procedure of Lee et al. 
(2018): 'api/data_expansion_fitter.md' 82 | - Event Times Sampler: 'api/event_times_sampler.md' 83 | - Evaluation: 'api/evaluation.md' 84 | - Cross Validation: 'api/cross_validation.md' 85 | - Model Selection: 'api/model_selection.md' 86 | - Sure Independent Screening: 'api/screening.md' 87 | - Utils: 'api/utils.md' 88 | 89 | plugins: 90 | - mknotebooks 91 | - search 92 | # - mkdocs-jupyter 93 | - mkdocstrings: 94 | handlers: 95 | python: 96 | options: 97 | members: true 98 | show_inheritance: true 99 | inherited_members: true 100 | show_root_heading: true 101 | show_source: true 102 | merge_init_into_class: true 103 | docstring_style: google 104 | 105 | extra: 106 | copyright: Copyright © 2022 Tomer Meir, Rom Gutman, Malka Gorfine 107 | analytics: 108 | provider: google 109 | property: G-Z0XYP3868P -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pydts" 3 | version = "0.9.7" 4 | description = "Discrete time survival analysis with competing risks" 5 | authors = ["Tomer Meir ", "Rom Gutman ", "Malka Gorfine "] 6 | license = "GNU GPLv3" 7 | readme = "README.md" 8 | homepage = "https://github.com/tomer1812/pydts" 9 | repository = "https://github.com/tomer1812/pydts" 10 | keywords = ["Discrete Time", "Time to Event" ,"Survival Analysis", "Competing Events"] 11 | documentation = "https://tomer1812.github.io/pydts" 12 | 13 | [tool.poetry.dependencies] 14 | python = ">=3.9,<3.11" 15 | pandas = "^1.4.1" 16 | mkdocs = "^1.4.3" 17 | mkdocs-material = "^9.0.0" 18 | #mknotebooks = "^0.7.1" 19 | lifelines = "^0.26.4" 20 | scipy = "^1.8.0" 21 | scikit-learn = "^1.0.2" 22 | tqdm = "^4.63.0" 23 | statsmodels = "^0.13.2" 24 | pandarallel = "^1.5.7" 25 | ipython = "^8.2.0" 26 | numpy = ">=1.23.4, <2.0.0" 27 | psutil = "^5.9.4" 28 | setuptools = "^68.0.0" 29 | seaborn = "^0.12.2" 30 | mkdocstrings = "^0.28" 31 | mkdocstrings-python = "^1.16" 32 | mknotebooks = "^0.8.0" 33 | 34 | [tool.poetry.dev-dependencies] 35 | pytest = "^7.0.1" 36 | coverage = {extras = ["toml"], version = "^6.3.2"} 37 | pytest-cov = "^3.0.0" 38 | mkdocs = "^1.2.3" 39 | mkdocs-material = "^9.0.0" 40 | #mknotebooks = "^0.7.1" 41 | jupyter = "^1.0.0" 42 | scikit-survival = "^0.17.1" 43 | tableone = "^0.7.10" 44 | 45 | [build-system] 46 | requires = ["poetry-core>=1.0.0"] 47 | build-backend = "poetry.core.masonry.api" 48 | 49 | [tool.coverage.paths] 50 | source = ["src", "*/site-packages"] 51 | 52 | [tool.coverage.run] 53 | branch = true 54 | source = ["pydts"] 55 | 56 | [tool.coverage.report] 57 | show_missing = true 58 | -------------------------------------------------------------------------------- /src/pydts/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" -------------------------------------------------------------------------------- /src/pydts/base_fitters.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Tuple, Union 2 | 3 | import pandas as pd 4 | import numpy as np 5 | from .utils import get_expanded_df 6 | 7 | 8 | 9 | class BaseFitter: 10 | """ 11 | This class implements the basic fitter methods and attributes api 12 | """ 13 | 14 | def __init__(self): 15 | self.event_models = {} 16 | self.expanded_df = pd.DataFrame() 17 | self.event_type_col = None 18 | self.duration_col = None 19 | self.pid_col = None 20 | self.events = None 21 | 
self.covariates = None 22 | self.formula = None 23 | self.times = None 24 | 25 | def fit(self, df: pd.DataFrame, event_type_col: str = 'J', duration_col: str = 'X', pid_col: str = 'pid', 26 | **kwargs) -> dict: 27 | raise NotImplementedError 28 | 29 | def predict(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame: 30 | raise NotImplementedError 31 | 32 | def evaluate(self, test_df: pd.DataFrame, oracle_col: str = 'T', **kwargs) -> float: 33 | raise NotImplementedError 34 | 35 | def print_summary(self, **kwargs) -> None: 36 | raise NotImplementedError 37 | 38 | def _validate_t(self, t, return_iter=True): 39 | _t = np.array([t]) if not isinstance(t, Iterable) else t 40 | t_i_not_fitted = [t_i for t_i in _t if (t_i not in self.times)] 41 | assert len(t_i_not_fitted) == 0, \ 42 | f"Cannot predict for times which were not included during .fit(): {t_i_not_fitted}" 43 | if return_iter: 44 | return _t 45 | return t 46 | 47 | def _validate_covariates_in_df(self, df): 48 | cov_not_fitted = [] 49 | if isinstance(self.covariates, list): 50 | cov_not_fitted = [cov for cov in self.covariates if cov not in df.columns] 51 | elif isinstance(self.covariates, dict): 52 | for event in self.events: 53 | event_cov_not_fitted = [cov for cov in self.covariates[event] if cov not in df.columns] 54 | cov_not_fitted.extend(event_cov_not_fitted) 55 | assert len(cov_not_fitted) == 0, \ 56 | f"Cannot predict - required covariates are missing from df: {cov_not_fitted}" 57 | 58 | def _validate_cols(self, df, event_type_col, duration_col, pid_col): 59 | assert event_type_col in df.columns, f'Event type column is missing from df: {event_type_col}' 60 | assert duration_col in df.columns, f'Duration column is missing from df: {duration_col}' 61 | assert pid_col in df.columns, f'Observation ID column is missing from df: {pid_col}' 62 | 63 | 64 | class ExpansionBasedFitter(BaseFitter): 65 | """ 66 | This class implements the data expansion method which is common for the existing fitters 67 | """ 68 | 69 | def _expand_data(self, 70 | df: pd.DataFrame, 71 | event_type_col: str, 72 | duration_col: str, 73 | pid_col: str) -> pd.DataFrame: 74 | """ 75 | This method expands the raw data as explained in Lee et al. 2018 76 | 77 | Args: 78 | df (pandas.DataFrame): Dataframe to expand. 79 | event_type_col (str): The event type column name (must be a column in df), 80 | Right censored sample (i) is indicated by event value 0, df.loc[i, event_type_col] = 0. 81 | duration_col (str): Last follow up time column name (must be a column in df). 82 | pid_col (str): Sample ID column name (must be a column in df). 83 | 84 | Returns: 85 | Expanded df (pandas.DataFrame): the expanded dataframe. 86 | """ 87 | self._validate_cols(df, event_type_col, duration_col, pid_col) 88 | return get_expanded_df(df=df, event_type_col=event_type_col, duration_col=duration_col, pid_col=pid_col) 89 | 90 | def predict_hazard_jt(self, df: pd.DataFrame, event: Union[str, int], t: Union[Iterable, int]) -> pd.DataFrame: 91 | """ 92 | This method calculates the hazard for the given event at the given time values if they were included in 93 | the training set of the event. 
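Predicted values are added to the returned dataframe as columns named hazard_j{event}_t{t}.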
94 | 
95 |         Args:
96 |             df (pd.DataFrame): samples to predict for
97 |             event (Union[str, int]): event name
98 |             t (Union[Iterable, int]): times to calculate the hazard for
99 | 
100 |         Returns:
101 |             df (pd.DataFrame): samples with the prediction columns
102 |         """
103 |         raise NotImplementedError
104 | 
105 |     def predict_hazard_t(self, df: pd.DataFrame, t: Union[int, np.array]) -> pd.DataFrame:
106 |         """
107 |         This function calculates the hazard for all the events at the requested time values if they were included in
108 |         the training set of each event.
109 | 
110 |         Args:
111 |             df (pd.DataFrame): samples to predict for
112 |             t (int, np.array): times to calculate the hazard for
113 | 
114 |         Returns:
115 |             df (pd.DataFrame): samples with the prediction columns
116 |         """
117 |         t = self._validate_t(t)
118 |         self._validate_covariates_in_df(df.head())
119 | 
120 |         for event, model in self.event_models.items():
121 |             df = self.predict_hazard_jt(df=df, event=event, t=t)
122 |         return df
123 | 
124 |     def predict_hazard_all(self, df: pd.DataFrame) -> pd.DataFrame:
125 |         """
126 |         This function calculates the hazard for all the events at all time values included in the training set for each
127 |         event.
128 | 
129 |         Args:
130 |             df (pd.DataFrame): samples to predict for
131 | 
132 |         Returns:
133 |             df (pd.DataFrame): samples with the prediction columns
134 | 
135 |         """
136 |         self._validate_covariates_in_df(df.head())
137 |         df = self.predict_hazard_t(df, t=self.times[:-1])
138 |         return df
139 | 
140 |     def predict_overall_survival(self,
141 |                                  df: pd.DataFrame,
142 |                                  t: int = None,
143 |                                  return_hazards: bool = False) -> pd.DataFrame:
144 |         """
145 |         This function adds columns of the overall survival until time t.
146 |         Args:
147 |             df (pandas.DataFrame): dataframe with covariates columns
148 |             t (int): time
149 |             return_hazards (bool): whether to keep the hazard columns
150 | 
151 |         Returns:
152 |             df (pandas.DataFrame): dataframe with the additional overall survival columns
153 | 
154 |         """
155 |         if t is not None:
156 |             self._validate_t(t, return_iter=False)
157 |         self._validate_covariates_in_df(df.head())
158 | 
159 |         all_hazards = self.predict_hazard_all(df)
160 |         _times = self.times[:-1] if t is None else [_t for _t in self.times[:-1] if _t <= t]
161 |         overall = pd.DataFrame()
162 |         for t_i in _times:
163 |             cols = [f'hazard_j{e}_t{t_i}' for e in self.events]
164 |             t_i_hazard = 1 - all_hazards[cols].sum(axis=1)
165 |             t_i_hazard.name = f'overall_survival_t{t_i}'
166 |             overall = pd.concat([overall, t_i_hazard], axis=1)
167 |         overall = pd.concat([df, overall.cumprod(axis=1)], axis=1)
168 | 
169 |         if return_hazards:
170 |             cols = all_hazards.columns[all_hazards.columns.str.startswith("hazard_")]
171 |             cols = cols.difference(overall.columns)
172 |             if len(cols) > 0:
173 |                 overall = pd.concat([overall, all_hazards[cols]], axis=1)
174 |         return overall
175 | 
176 |     def predict_prob_event_j_at_t(self, df: pd.DataFrame, event: Union[str, int], t: int) -> pd.DataFrame:
177 |         """
178 |         This function adds a column with the probability of occurrence of a specific event at a specific time.
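For t > 1, it is computed as the overall survival up to t-1 multiplied by the hazard of the event at t, and added as a column named prob_j{event}_at_t{t}; for t = 1 it equals the hazard at t = 1.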
179 | 
180 |         Args:
181 |             df (pandas.DataFrame): dataframe with covariates columns
182 |             event (Union[str, int]): event name
183 |             t (int): time
184 | 
185 |         Returns:
186 |             df (pandas.DataFrame): dataframe with an additional probability column
187 | 
188 |         """
189 |         assert event in self.events, \
190 |             f"Cannot predict for event {event} - it was not included during .fit()"
191 |         self._validate_t(t, return_iter=False)
192 |         self._validate_covariates_in_df(df.head())
193 | 
194 |         if f'prob_j{event}_at_t{t}' not in df.columns:
195 |             if t == 1:
196 |                 if f'hazard_j{event}_t{t}' not in df.columns:
197 |                     df = self.predict_hazard_jt(df=df, event=event, t=t)
198 |                 df[f'prob_j{event}_at_t{t}'] = df[f'hazard_j{event}_t{t}']
199 |                 return df
200 |             elif f'overall_survival_t{t - 1}' not in df.columns:
201 |                 df = self.predict_overall_survival(df, t=t, return_hazards=True)
202 |             elif f'hazard_j{event}_t{t}' not in df.columns:
203 |                 df = self.predict_hazard_t(df, t=np.array([_t for _t in self.times[:-1] if _t <= t]))
204 |             df[f'prob_j{event}_at_t{t}'] = df[f'overall_survival_t{t - 1}'] * df[f'hazard_j{event}_t{t}']
205 |         return df
206 | 
207 |     def predict_prob_event_j_all(self, df: pd.DataFrame, event: Union[str, int]) -> pd.DataFrame:
208 |         """
209 |         This function adds columns of the occurrence probabilities of a specific event.
210 | 
211 |         Args:
212 |             df (pandas.DataFrame): dataframe with covariates columns
213 |             event (Union[str, int]): event name
214 | 
215 |         Returns:
216 |             df (pandas.DataFrame): dataframe with probabilities columns
217 | 
218 |         """
219 |         assert event in self.events, \
220 |             f"Cannot predict for event {event} - it was not included during .fit()"
221 |         self._validate_covariates_in_df(df.head())
222 | 
223 |         if f'overall_survival_t{self.times[-2]}' not in df.columns:
224 |             df = self.predict_overall_survival(df, return_hazards=True)
225 |         for t in self.times[:-1]:
226 |             if f'prob_j{event}_at_t{t}' not in df.columns:
227 |                 df = self.predict_prob_event_j_at_t(df=df, event=event, t=t)
228 |         return df
229 | 
230 |     def predict_prob_events(self, df: pd.DataFrame) -> pd.DataFrame:
231 |         """
232 |         This function adds columns of the occurrence probabilities of all the events.
233 |         Args:
234 |             df (pandas.DataFrame): dataframe with covariates columns
235 | 
236 |         Returns:
237 |             df (pandas.DataFrame): dataframe with probabilities columns
238 | 
239 |         """
240 |         self._validate_covariates_in_df(df.head())
241 | 
242 |         for event in self.events:
243 |             df = self.predict_prob_event_j_all(df=df, event=event)
244 |         return df
245 | 
246 |     def predict_event_cumulative_incident_function(self, df: pd.DataFrame, event: Union[str, int]) -> pd.DataFrame:
247 |         """
248 |         This function adds, for a specific event, columns of the predicted hazard function, overall survival, probabilities
249 |         of event occurrence, and the cumulative incidence function (CIF) to the given dataframe.
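The CIF at time t is the cumulative sum of the event's occurrence probabilities up to t, added as columns named cif_j{event}_at_t{t}.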
250 | 
251 |         Args:
252 |             df (pandas.DataFrame): dataframe with covariates columns included
253 |             event (Union[str, int]): event name
254 | 
255 |         Returns:
256 |             df (pandas.DataFrame): dataframe with additional prediction columns
257 | 
258 |         """
259 |         assert event in self.events, \
260 |             f"Cannot predict for event {event} - it was not included during .fit()"
261 |         self._validate_covariates_in_df(df.head())
262 | 
263 |         if f'prob_j{event}_at_t{self.times[-2]}' not in df.columns:
264 |             df = self.predict_prob_events(df=df)
265 |         cols = [f'prob_j{event}_at_t{t}' for t in self.times[:-1]]
266 |         cif_df = df[cols].cumsum(axis=1)
267 |         cif_df.columns = [f'cif_j{event}_at_t{t}' for t in self.times[:-1]]
268 |         df = pd.concat([df, cif_df], axis=1)
269 |         return df
270 | 
271 |     def predict_cumulative_incident_function(self, df: pd.DataFrame) -> pd.DataFrame:
272 |         """
273 |         This function adds columns of the predicted hazard function, overall survival, probabilities of event occurrence,
274 |         and the cumulative incidence function (CIF) to the given dataframe.
275 | 
276 |         Args:
277 |             df (pandas.DataFrame): dataframe with covariates columns included
278 | 
279 |         Returns:
280 |             df (pandas.DataFrame): dataframe with additional prediction columns
281 | 
282 |         """
283 |         self._validate_covariates_in_df(df.head())
284 | 
285 |         for event in self.events:
286 |             if f'cif_j{event}_at_t{self.times[-2]}' not in df.columns:
287 |                 df = self.predict_event_cumulative_incident_function(df=df, event=event)
288 |         return df
289 | 
290 |     def predict_marginal_prob_event_j(self, df: pd.DataFrame, event: Union[str, int]) -> pd.DataFrame:
291 |         """
292 |         This function calculates the marginal probability of an event given the covariates.
293 | 
294 |         Args:
295 |             df (pandas.DataFrame): dataframe with covariates columns included
296 |             event (Union[str, int]): event name
297 | 
298 |         Returns:
299 |             df (pandas.DataFrame): dataframe with additional prediction columns
300 |         """
301 | 
302 |         assert event in self.events, \
303 |             f"Cannot predict for event {event} - it was not included during .fit()"
304 |         self._validate_covariates_in_df(df.head())
305 | 
306 |         if f'prob_j{event}_at_t{self.times[-2]}' not in df.columns:
307 |             df = self.predict_prob_event_j_all(df=df, event=event)
308 |         cols = [f'prob_j{event}_at_t{_t}' for _t in self.times[:-1]]
309 |         marginal_prob = df[cols].sum(axis=1)
310 |         marginal_prob.name = f'marginal_prob_j{event}'
311 |         return pd.concat([df, marginal_prob], axis=1)
312 | 
313 |     def predict_marginal_prob_all_events(self, df: pd.DataFrame) -> pd.DataFrame:
314 |         """
315 |         This function calculates the marginal probability per event given the covariates for all the events.
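For each event j, the marginal probability is obtained by summing prob_j{event}_at_t{t} over all fitted times, and is added as a column named marginal_prob_j{event}.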
316 | 
317 |         Args:
318 |             df (pandas.DataFrame): dataframe with covariates columns included
319 | 
320 |         Returns:
321 |             df (pandas.DataFrame): dataframe with additional prediction columns
322 |         """
323 |         self._validate_covariates_in_df(df.head())
324 |         for event in self.events:
325 |             df = self.predict_marginal_prob_event_j(df=df, event=event)
326 |         return df
327 | 
--------------------------------------------------------------------------------
/src/pydts/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | PROJECT_DIR = os.path.join(os.path.dirname(__file__))
4 | OUTPUT_DIR = os.path.join(PROJECT_DIR, '../../output')
5 | if not os.path.isdir(OUTPUT_DIR):
6 |     os.mkdir(OUTPUT_DIR)
7 | 
--------------------------------------------------------------------------------
/src/pydts/cross_validation.py:
--------------------------------------------------------------------------------
1 | __all__ = ["TwoStagesCV", "TwoStagesCVExact", "PenaltyGridSearchCV", "PenaltyGridSearchCVExact"]
2 | 
3 | 
4 | import pandas as pd
5 | import numpy as np
6 | from .fitters import TwoStagesFitter, TwoStagesFitterExact
7 | import warnings
8 | from copy import deepcopy
9 | from sklearn.model_selection import KFold
10 | pd.set_option("display.max_rows", 500)
11 | warnings.filterwarnings('ignore')
12 | slicer = pd.IndexSlice
13 | from typing import Optional, List, Union
14 | import psutil
15 | from .evaluation import events_brier_score_at_t, events_integrated_brier_score, global_brier_score, \
16 |     events_integrated_auc, global_auc, events_auc_at_t
17 | from .model_selection import PenaltyGridSearch, PenaltyGridSearchExact
18 | from time import time
19 | 
20 | WORKERS = psutil.cpu_count(logical=False)
21 | 
22 | 
23 | class BaseTwoStagesCV(object):
24 |     """
25 |     This class implements K-fold cross-validation using TwoStagesFitter and TwoStagesFitterExact fitters
26 |     """
27 | 
28 |     def __init__(self):
29 |         self.models = {}
30 |         self.test_pids = {}
31 |         self.results = pd.DataFrame()
32 |         self.global_auc = {}
33 |         self.integrated_auc = {}
34 |         self.global_bs = {}
35 |         self.integrated_bs = {}
36 |         self.TwoStagesFitter_type = 'CoxPHFitter'
37 | 
38 |     def cross_validate(self,
39 |                        full_df: pd.DataFrame,
40 |                        n_splits: int = 5,
41 |                        shuffle: bool = True,
42 |                        seed: Union[int, None] = None,
43 |                        fit_beta_kwargs: dict = {},
44 |                        covariates=None,
45 |                        event_type_col: str = 'J',
46 |                        duration_col: str = 'X',
47 |                        pid_col: str = 'pid',
48 |                        x0: Union[np.array, int] = 0,
49 |                        verbose: int = 2,
50 |                        nb_workers: int = WORKERS,
51 |                        metrics=['BS', 'IBS', 'GBS', 'AUC', 'IAUC', 'GAUC']):
52 | 
53 |         """
54 |         This method implements K-fold cross-validation using TwoStagesFitter instances and the full_df data.
55 | 
56 |         Args:
57 |             full_df (pd.DataFrame): Data to cross validate.
58 |             n_splits (int): Number of folds, defaults to 5.
59 |             shuffle (bool): Shuffle samples before splitting to folds. Defaults to True.
60 |             seed: Pseudo-random seed for the KFold instance. Defaults to None.
61 |             fit_beta_kwargs (dict, Optional): Keyword arguments to pass on to the estimation procedure. If a different model for beta is desired, it can be defined here.
62 |             covariates (list): list of covariates to be used in estimating the regression coefficients.
63 |             event_type_col (str): The event type column name (must be a column in df), Right-censored sample (i) is indicated by event value 0, df.loc[i, event_type_col] = 0.
64 |             duration_col (str): Last follow up time column name (must be a column in full_df).
65 |             pid_col (str): Sample ID column name (must be a column in full_df).
66 |             x0 (Union[numpy.array, int], Optional): initial guess to pass to scipy.optimize.minimize function
67 |             verbose (int, Optional): The verbosity level of pandarallel
68 |             nb_workers (int, Optional): The number of workers for pandarallel. If not specified, defaults to the number of workers available.
69 |             metrics (str, list): Evaluation metrics.
70 | 
71 |         Returns:
72 |             Results (pd.DataFrame): Cross validation metrics results
73 |         """
74 | 
75 |         if isinstance(metrics, str):
76 |             metrics = [metrics]
77 | 
78 |         self.models = {}
79 |         self.kfold_cv = KFold(n_splits=n_splits, shuffle=shuffle, random_state=seed)
80 | 
81 |         if 'C' in full_df.columns:
82 |             full_df = full_df.drop(['C'], axis=1)
83 |         if 'T' in full_df.columns:
84 |             full_df = full_df.drop(['T'], axis=1)
85 | 
86 |         for i_fold, (train_index, test_index) in enumerate(self.kfold_cv.split(full_df)):
87 |             self.test_pids[i_fold] = full_df.iloc[test_index][pid_col].values
88 |             train_df, test_df = full_df.iloc[train_index], full_df.iloc[test_index]
89 |             if self.TwoStagesFitter_type == 'Exact':
90 |                 fold_fitter = TwoStagesFitterExact()
91 |             else:
92 |                 fold_fitter = TwoStagesFitter()
93 |             print(f'Fitting fold {i_fold+1}/{n_splits}')
94 |             fold_fitter.fit(df=train_df,
95 |                             covariates=covariates,
96 |                             event_type_col=event_type_col,
97 |                             duration_col=duration_col,
98 |                             pid_col=pid_col,
99 |                             x0=x0,
100 |                             fit_beta_kwargs=fit_beta_kwargs,
101 |                             verbose=verbose,
102 |                             nb_workers=nb_workers)
103 | 
104 |             #self.models[i_fold] = deepcopy(fold_fitter)
105 |             self.models[i_fold] = fold_fitter
106 | 
107 |             pred_df = self.models[i_fold].predict_prob_events(test_df)
108 | 
109 |             for metric in metrics:
110 |                 if metric == 'IAUC':
111 |                     self.integrated_auc[i_fold] = events_integrated_auc(pred_df, event_type_col=event_type_col,
112 |                                                                         duration_col=duration_col)
113 |                 elif metric == 'GAUC':
114 |                     self.global_auc[i_fold] = global_auc(pred_df, event_type_col=event_type_col,
115 |                                                          duration_col=duration_col)
116 |                 elif metric == 'IBS':
117 |                     self.integrated_bs[i_fold] = events_integrated_brier_score(pred_df, event_type_col=event_type_col,
118 |                                                                                duration_col=duration_col)
119 |                 elif metric == 'GBS':
120 |                     self.global_bs[i_fold] = global_brier_score(pred_df, event_type_col=event_type_col,
121 |                                                                 duration_col=duration_col)
122 |                 elif metric == 'AUC':
123 |                     tmp_res = events_auc_at_t(pred_df, event_type_col=event_type_col,
124 |                                               duration_col=duration_col)
125 |                     tmp_res = pd.concat([tmp_res], keys=[i_fold], names=['fold'])
126 |                     tmp_res = pd.concat([tmp_res], keys=[metric], names=['metric'])
127 |                     self.results = pd.concat([self.results, tmp_res], axis=0)
128 |                 elif metric == 'BS':
129 |                     tmp_res = events_brier_score_at_t(pred_df, event_type_col=event_type_col,
130 |                                                       duration_col=duration_col)
131 |                     tmp_res = pd.concat([tmp_res], keys=[i_fold], names=['fold'])
132 |                     tmp_res = pd.concat([tmp_res], keys=[metric], names=['metric'])
133 |                     self.results = pd.concat([self.results, tmp_res], axis=0)
134 | 
135 |         return self.results
136 | 
137 | 
138 | class BasePenaltyGridSearchCV(object):
139 |     """
140 |     This class implements K-fold cross-validation of the PenaltyGridSearch
141 |     """
142 | 
143 |     def __init__(self):
144 |         self.folds_grids = {}
145 |         self.test_pids = {}
146 |         self.global_auc = {}
147 |         self.integrated_auc = {}
148 |         self.global_bs = {}
149 |         self.integrated_bs = {}
150 |         self.TwoStagesFitter_type = 'CoxPHFitter'
151 | 
152 |     def cross_validate(self,
153 |                        full_df: pd.DataFrame,
154 |                        l1_ratio: float,
155 |                        penalizers: list,
156 |                        n_splits: int = 5,
157 |                        shuffle: bool = True,
158 |                        seed: Union[int, None] = None,
159 |                        event_type_col: str = 'J',
160 |                        duration_col: str = 'X',
161 |                        pid_col: str = 'pid',
162 |                        twostages_fit_kwargs: dict = {'nb_workers': WORKERS},
163 |                        metrics=['IBS', 'GBS', 'IAUC', 'GAUC']) -> pd.DataFrame:
164 | 
165 |         """
166 |         This method implements K-fold cross-validation using PenaltyGridSearch and the full_df data.
167 | 
168 |         Args:
169 |             full_df (pd.DataFrame): Data to cross validate.
170 |             l1_ratio (float): regularization ratio for the CoxPHFitter (see lifelines.fitters.coxph_fitter.CoxPHFitter documentation).
171 |             penalizers (list): penalizer options for each event (see lifelines.fitters.coxph_fitter.CoxPHFitter documentation).
172 |             n_splits (int): Number of folds, defaults to 5.
173 |             shuffle (bool): Shuffle samples before splitting to folds. Defaults to True.
174 |             seed: Pseudo-random seed for the KFold instance. Defaults to None.
175 |             event_type_col (str): The event type column name (must be a column in df), Right-censored sample (i) is indicated by event value 0, df.loc[i, event_type_col] = 0.
176 |             duration_col (str): Last follow up time column name (must be a column in full_df).
177 |             pid_col (str): Sample ID column name (must be a column in full_df).
178 |             twostages_fit_kwargs (dict): keyword arguments to pass to each TwoStagesFitter.
179 |             metrics (str, list): Evaluation metrics.
180 | 
181 |         Returns:
182 |             gauc_output_df (pd.DataFrame): Global AUC k-fold mean and standard error for all possible combinations of the penalizers.
183 |         """
184 | 
185 |         if isinstance(metrics, str):
186 |             metrics = [metrics]
187 | 
188 |         self.folds_grids = {}
189 |         self.kfold_cv = KFold(n_splits=n_splits, shuffle=shuffle, random_state=seed)
190 | 
191 |         if 'C' in full_df.columns:
192 |             full_df = full_df.drop(['C'], axis=1)
193 |         if 'T' in full_df.columns:
194 |             full_df = full_df.drop(['T'], axis=1)
195 | 
196 |         for i_fold, (train_index, test_index) in enumerate(self.kfold_cv.split(full_df)):
197 |             print(f'Starting fold {i_fold+1}/{n_splits}')
198 |             start = time()
199 |             self.test_pids[i_fold] = full_df.iloc[test_index][pid_col].values
200 |             train_df, test_df = full_df.iloc[train_index], full_df.iloc[test_index]
201 |             if self.TwoStagesFitter_type == 'Exact':
202 |                 fold_pgs = PenaltyGridSearchExact()
203 |             else:
204 |                 fold_pgs = PenaltyGridSearch()
205 | 
206 |             fold_pgs.evaluate(train_df=train_df,
207 |                               test_df=test_df,
208 |                               l1_ratio=l1_ratio,
209 |                               penalizers=penalizers,
210 |                               metrics=metrics,
211 |                               seed=seed,
212 |                               event_type_col=event_type_col,
213 |                               duration_col=duration_col,
214 |                               pid_col=pid_col,
215 |                               twostages_fit_kwargs=twostages_fit_kwargs)
216 | 
217 |             self.folds_grids[i_fold] = fold_pgs
218 | 
219 |             for metric in metrics:
220 |                 if metric == 'GAUC':
221 |                     self.global_auc[i_fold] = fold_pgs.convert_results_dict_to_df(fold_pgs.global_auc)
222 |                 elif metric == 'IAUC':
223 |                     self.integrated_auc[i_fold] = fold_pgs.convert_results_dict_to_df(fold_pgs.integrated_auc)
224 |                 elif metric == 'GBS':
225 |                     self.global_bs[i_fold] = fold_pgs.convert_results_dict_to_df(fold_pgs.global_bs)
226 |                 elif metric == 'IBS':
227 |                     self.integrated_bs[i_fold] = fold_pgs.convert_results_dict_to_df(fold_pgs.integrated_bs)
228 | 
229 |             end = time()
230 |             print(f'Finished fold {i_fold+1}/{n_splits}, {int(end-start)} seconds')
231 | 
232 |         if 'GAUC' in metrics:
233 |             res = [v for k, v in self.global_auc.items()]
234 |             gauc_output_df = pd.concat([pd.concat(res, axis=1).mean(axis=1),
235 |                                         pd.concat(res, axis=1).std(axis=1)],
236 |                                        keys=['Mean', 'SE'], axis=1)
237 |         else:
238 |             gauc_output_df = pd.DataFrame()
239 |         return gauc_output_df
240 | 
241 | 
242 | class PenaltyGridSearchCV(BasePenaltyGridSearchCV):
243 | 
244 |     def __init__(self):
245 |         super().__init__()
246 |         self.TwoStagesFitter_type = 'CoxPHFitter'
247 | 
248 | 
249 | class PenaltyGridSearchCVExact(BasePenaltyGridSearchCV):
250 | 
251 |     def __init__(self):
252 |         super().__init__()
253 |         self.TwoStagesFitter_type = 'Exact'
254 | 
255 | 
256 | class TwoStagesCV(BaseTwoStagesCV):
257 | 
258 |     def __init__(self):
259 |         super().__init__()
260 |         self.TwoStagesFitter_type = 'CoxPHFitter'
261 | 
262 | 
263 | class TwoStagesCVExact(BaseTwoStagesCV):
264 | 
265 |     def __init__(self):
266 |         super().__init__()
267 |         self.TwoStagesFitter_type = 'Exact'
268 | 
--------------------------------------------------------------------------------
/src/pydts/data_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import expit
3 | from typing import Union
4 | import pandas as pd
5 | 
6 | 
7 | class EventTimesSampler(object):
8 | 
9 |     def __init__(self, d_times: int, j_event_types: int):
10 |         """
11 |         This class implements a sampling procedure for discrete event times and censoring times for given observations.
12 | 
13 |         Args:
14 |             d_times (int): number of possible event times
15 |             j_event_types (int): number of possible event types
16 |         """
17 | 
18 |         self.d_times = d_times
19 |         self.times = range(1, self.d_times + 2)  # d + 1 for administrative censoring
20 |         self.j_event_types = j_event_types
21 |         self.events = range(1, self.j_event_types + 1)
22 | 
23 |     def _validate_prob_dfs_list(self, dfs_list: list, numerical_error_tolerance: float = 0.001) -> list:
24 |         for df in dfs_list:
25 |             if (((df < (0-numerical_error_tolerance)).any().any()) or ((df > (1+numerical_error_tolerance)).any().any())):
26 |                 raise ValueError("The chosen sampling parameters result in invalid probabilities for event j at time t")
27 |             # Only fixes numerical errors smaller than the tolerance size
28 |             df.clip(0, 1, inplace=True)
29 |         return dfs_list
30 | 
31 |     def calculate_hazards(self, observations_df: pd.DataFrame, hazard_coefs: dict, events: list = None) -> list:
32 |         """
33 |         Calculates the hazard function for the observations given the hazard coefficients.
34 | 
35 |         Args:
36 |             observations_df (pd.DataFrame): Dataframe with observations covariates.
37 |             hazard_coefs (dict): time coefficients and covariates coefficients for each event type.
38 | 
39 |         Returns:
40 |             hazards_dfs (list): A list of dataframes, one for each event type, with the hazard function at time t for each of the observations.
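        Example:
            A minimal sketch (the coefficient values below are illustrative
            assumptions for demonstration, not package defaults)::

                ets = EventTimesSampler(d_times=3, j_event_types=1)
                observations_df = pd.DataFrame({'Z1': [0.1, 0.9]})
                coefs = {
                    'alpha': {1: lambda t: -1.0 - 0.1 * t},  # callable time coefficients per event
                    'beta': {1: [0.5]},                      # one coefficient per covariate column
                }
                # Returns a list with one hazards DataFrame, columns t = 1, 2, 3
                hazards = ets.calculate_hazards(observations_df, coefs)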
41 |         """
42 |         events = events if events is not None else self.events
43 |         a_t = {}
44 |         for event in events:
45 |             if callable(hazard_coefs['alpha'][event]):
46 |                 a_t[event] = {t: hazard_coefs['alpha'][event](t) for t in range(1, self.d_times + 1)}
47 |             else:
48 |                 a_t[event] = {t: hazard_coefs['alpha'][event][t-1] for t in range(1, self.d_times + 1)}
49 |         b = pd.concat([observations_df.dot(hazard_coefs['beta'][j]) for j in events], axis=1, keys=events)
50 |         hazards_dfs = [pd.concat([expit((a_t[j][t] + b[j]).astype(float)) for t in range(1, self.d_times + 1)],
51 |                                  axis=1, keys=(range(1, self.d_times + 1))) for j in events]
52 |         return hazards_dfs
53 | 
54 |     def calculate_overall_survival(self, hazards: list, numerical_error_tolerance: float = 0.001) -> pd.DataFrame:
55 |         """
56 |         Calculates the overall survival function given the hazards.
57 | 
58 |         Args:
59 |             hazards (list): A list of hazards dataframes for each event type (as returned from EventTimesSampler.calculate_hazards function).
60 |             numerical_error_tolerance (float): Tolerate numerical errors of probabilities up to this value.
61 | 
62 |         Returns:
63 |             overall_survival (pd.Dataframe): The overall survival functions.
64 |         """
65 |         if (((sum(hazards)) > (1 + numerical_error_tolerance)).sum().sum() > 0):
66 |             raise ValueError("The chosen sampling parameters result in negative values of the overall survival function")
67 |         sum_hazards = sum(hazards).clip(0, 1)
68 |         overall_survival = pd.concat([pd.Series(1, index=hazards[0].index),
69 |                                       (1 - sum_hazards).cumprod(axis=1).iloc[:, :-1]], axis=1)
70 |         overall_survival.columns += 1
71 |         return overall_survival
72 | 
73 |     def calculate_prob_event_at_t(self, hazards: list, overall_survival: pd.DataFrame,
74 |                                   numerical_error_tolerance: float = 0.001) -> list:
75 |         """
76 |         Calculates the probability for event j at time t.
77 | 
78 |         Args:
79 |             hazards (list): A list of hazards dataframes for each event type (as returned from EventTimesSampler.calculate_hazards function)
80 |             overall_survival (pd.Dataframe): The overall survival functions
81 |             numerical_error_tolerance (float): Tolerate numerical errors of probabilities up to this value.
82 | 
83 |         Returns:
84 |             prob_event_at_t (list): A list of dataframes, one for each event type, with the probability of event occurrence at time t for each of the observations.
85 |         """
86 |         prob_event_at_t = [hazard * overall_survival for hazard in hazards]
87 |         prob_event_at_t = self._validate_prob_dfs_list(prob_event_at_t, numerical_error_tolerance)
88 |         return prob_event_at_t
89 | 
90 |     def calculate_prob_event_j(self, prob_j_at_t: list, numerical_error_tolerance: float = 0.001) -> list:
91 |         """
92 |         Calculates the total probability for event j.
93 | 
94 |         Args:
95 |             prob_j_at_t (list): A list of dataframes, one for each event type, with the probability of event occurrence at time t for each of the observations.
96 |             numerical_error_tolerance (float): Tolerate numerical errors of probabilities up to this value.
97 | 
98 |         Returns:
99 |             total_prob_j (list): A list of dataframes, one for each event type, with the total probability of event occurrence for each of the observations.
100 |         """
101 |         total_prob_j = [prob.sum(axis=1) for prob in prob_j_at_t]
102 |         total_prob_j = self._validate_prob_dfs_list(total_prob_j, numerical_error_tolerance)
103 |         return total_prob_j
104 | 
105 |     def calc_prob_t_given_j(self, prob_j_at_t, total_prob_j, numerical_error_tolerance=0.001):
106 |         """
107 |         Calculates the conditional probability for event occurrence at time t given J_i=j.
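That is, Pr(T=t | J=j, Z) = Pr(J=j, T=t | Z) / Pr(J=j | Z); each dataframe of prob_j_at_t is divided column-wise by the corresponding total probability in total_prob_j.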
108 | 
109 |         Args:
110 |             prob_j_at_t (list): A list of dataframes, one for each event type, with the probability of event occurrence at time t for each of the observations.
111 |             total_prob_j (list): A list of dataframes, one for each event type, with the total probability of event occurrence for each of the observations.
112 |             numerical_error_tolerance (float): Tolerate numerical errors of probabilities up to this value.
113 | 
114 |         Returns:
115 |             conditional_prob (list): A list of dataframes, one for each event type, with the conditional probability of event occurrence at t given event type j for each of the observations.
116 |         """
117 |         conditional_prob = [prob.div(sumj, axis=0) for prob, sumj in zip(prob_j_at_t, total_prob_j)]
118 |         conditional_prob = self._validate_prob_dfs_list(conditional_prob, numerical_error_tolerance)
119 |         return conditional_prob
120 | 
121 |     def sample_event_times(self, observations_df: pd.DataFrame,
122 |                            hazard_coefs: dict,
123 |                            covariates: Union[list, None] = None,
124 |                            events: Union[list, None] = None,
125 |                            seed: Union[int, None] = None) -> pd.DataFrame:
126 |         """
127 |         Sample event type and event occurrence times.
128 | 
129 |         Args:
130 |             observations_df (pd.DataFrame): Dataframe with observations covariates.
131 |             covariates (list): list of covariate names, must be a subset of observations_df.columns
132 |             hazard_coefs (dict): time coefficients and covariates coefficients for each event type.
133 |             seed (int, None): numpy seed number for pseudo random sampling.
134 | 
135 |         Returns:
136 |             observations_df (pd.DataFrame): Dataframe with additional columns for sampled event time (T) and event type (J).
137 |         """
138 |         np.random.seed(seed)
139 |         if covariates is None:
140 |             covariates = [c for c in observations_df.columns if c not in ['X', 'T', 'C', 'J']]
141 |         events = events if events is not None else self.events
142 |         cov_df = observations_df[covariates]
143 |         hazards = self.calculate_hazards(cov_df, hazard_coefs, events=events)
144 |         overall_survival = self.calculate_overall_survival(hazards)
145 |         probs_j_at_t = self.calculate_prob_event_at_t(hazards, overall_survival)
146 |         total_prob_j = self.calculate_prob_event_j(probs_j_at_t)
147 |         probs_t_given_j = self.calc_prob_t_given_j(probs_j_at_t, total_prob_j)
148 |         sampled_jt = self.sample_jt(total_prob_j, probs_t_given_j)
149 |         if 'J' in observations_df.columns:
150 |             observations_df.drop('J', axis=1, inplace=True)
151 |         if 'T' in observations_df.columns:
152 |             observations_df.drop('T', axis=1, inplace=True)
153 |         observations_df = pd.concat([observations_df, sampled_jt], axis=1)
154 |         return observations_df
155 | 
156 |     def sample_jt(self, total_prob_j: list, probs_t_given_j: list, numerical_error_tolerance: float = 0.001) -> pd.DataFrame:
157 |         """
158 |         Sample event type and event time for each observation.
159 | 
160 |         Args:
161 |             total_prob_j (list): A list of dataframes, one for each event type, with the total probability of event occurrence for each of the observations.
162 |             probs_t_given_j (list): A list of dataframes, one for each event type, with the conditional probability of event occurrence at t given event type j for each of the observations.
163 | 
164 |         Returns:
165 |             sampled_df (pd.DataFrame): A dataframe with sampled event time and event type for each observation.
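        Example:
            Sampling relies on a vectorized inverse-CDF trick: each row's cumulative
            probabilities are compared against a single uniform draw, and idxmax picks
            the first column that exceeds it. A minimal sketch of the same idea
            (the probability values below are illustrative)::

                import numpy as np
                import pandas as pd

                # Each row holds the per-category probabilities of one observation.
                probs = pd.DataFrame([[0.2, 0.5, 0.3],
                                      [0.7, 0.2, 0.1]], columns=[0, 1, 2])
                u = np.random.rand(len(probs))[:, None]
                sampled = (probs.cumsum(axis=1) > u).idxmax(axis=1)  # one category per row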
166 |         """
167 | 
168 |         total_prob_j = self._validate_prob_dfs_list(total_prob_j, numerical_error_tolerance)
169 |         probs_t_given_j = self._validate_prob_dfs_list(probs_t_given_j, numerical_error_tolerance)
170 | 
171 |         # Add administrative censoring (no event occurred until Tmax) probability as J=0
172 |         temp_sums = pd.concat([1 - sum(total_prob_j), *total_prob_j], axis=1, keys=[0, *self.events])
173 |         if (((temp_sums < (0 - numerical_error_tolerance)).any().any()) or \
174 |             ((temp_sums > (1 + numerical_error_tolerance)).any().any())):
175 |             raise ValueError("The chosen sampling parameters result in invalid probabilities")
176 |         # Only fixes numerical errors smaller than the tolerance size
177 |         temp_sums.clip(0, 1, inplace=True)
178 | 
179 |         # Efficient way to sample j for each observation with different event probabilities
180 |         sampled_df = (temp_sums.cumsum(1) > np.random.rand(temp_sums.shape[0])[:, None]).idxmax(axis=1).to_frame('J')
181 | 
182 |         temp_ts = []
183 |         for j in self.events:
184 |             # Get the index of the observations with J_i = j
185 |             rel_j = sampled_df.query("J==@j").index
186 | 
187 |             # Get probs dataframe from the dfs list
188 |             prob_df = probs_t_given_j[j - 1]  # the prob j to sample from
189 | 
190 |             # Sample time of occurrence given J_i = j
191 |             temp_ts.append((prob_df.loc[rel_j].cumsum(1) >= np.random.rand(rel_j.shape[0])[:, None]).idxmax(axis=1))
192 | 
193 |         # Add Tmax+1 for observations with J_i = 0
194 |         temp_ts.append(pd.Series(self.d_times + 1, index=sampled_df.query('J==0').index))
195 |         sampled_df["T"] = pd.concat(temp_ts).sort_index()
196 |         return sampled_df
197 | 
198 |     def sample_independent_lof_censoring(self, observations_df: pd.DataFrame,
199 |                                          prob_lof_at_t: np.array, seed: Union[int, None] = None) -> pd.DataFrame:
200 |         """
201 |         Samples loss of follow-up censoring time from probabilities independent of covariates.
202 | 
203 |         Args:
204 |             observations_df (pd.DataFrame): Dataframe with observations covariates.
205 |             prob_lof_at_t (np.array): Array of probabilities for sampling each of the possible times.
206 |             seed (int): pseudo random seed number for numpy.random.seed()
207 | 
208 |         Returns:
209 |             observations_df (pd.DataFrame): Updated dataframe including sampled censoring time.
210 |         """
211 |         np.random.seed(seed)
212 |         administrative_censoring_prob = (1 - sum(prob_lof_at_t))
213 |         assert (administrative_censoring_prob >= 0), "Check the sum of prob_lof_at_t argument."
214 |         assert (administrative_censoring_prob <= 1), "Check the sum of prob_lof_at_t argument."
215 | 
216 |         prob_lof_at_t = np.append(prob_lof_at_t, administrative_censoring_prob)
217 |         sampled_df = pd.DataFrame(np.random.choice(a=self.times, size=len(observations_df), p=prob_lof_at_t),
218 |                                   index=observations_df.index, columns=['C'])
219 |         # No follow-up censoring, C=d+2 such that T wins when building X column:
220 |         #sampled_df.loc[sampled_df['C'] == self.times[-1], 'C'] = self.d_times + 2
221 |         if 'C' in observations_df.columns:
222 |             observations_df.drop('C', axis=1, inplace=True)
223 |         observations_df = pd.concat([observations_df, sampled_df], axis=1)
224 |         return observations_df
225 | 
226 |     def sample_hazard_lof_censoring(self, observations_df: pd.DataFrame, censoring_hazard_coefs: dict,
227 |                                     seed: Union[int, None] = None,
228 |                                     covariates: Union[list, None] = None) -> pd.DataFrame:
229 |         """
230 |         Samples loss of follow-up censoring time from hazard coefficients.
231 | 
232 |         Args:
233 |             observations_df (pd.DataFrame): Dataframe with observations covariates.
234 |             censoring_hazard_coefs (dict): time coefficients and covariates coefficients for the censoring hazard.
235 |             seed (int): pseudo random seed number for numpy.random.seed()
236 |             covariates (list): list of covariate names, must be a subset of observations_df.columns
237 | 
238 |         Returns:
239 |             observations_df (pd.DataFrame): Updated dataframe including sampled censoring time.
240 |         """
241 |         if covariates is None:
242 |             covariates = [c for c in observations_df.columns if c not in ['X', 'T', 'C', 'J']]
243 |         cov_df = observations_df[covariates]
244 |         tmp_ets = EventTimesSampler(d_times=self.d_times, j_event_types=1)
245 |         sampled_df = tmp_ets.sample_event_times(cov_df, censoring_hazard_coefs, seed=seed, covariates=covariates,
246 |                                                 events=[0])
247 | 
248 |         # No follow-up censoring, C=d+2 such that T wins when building X column:
249 |         #sampled_df.loc[sampled_df['J'] == 0, 'T'] = self.d_times + 2
250 |         sampled_df = sampled_df[['T']]
251 |         sampled_df.columns = ['C']
252 |         if 'C' in observations_df.columns:
253 |             observations_df.drop('C', axis=1, inplace=True)
254 |         observations_df = pd.concat([observations_df, sampled_df], axis=1)
255 |         return observations_df
256 | 
257 |     def update_event_or_lof(self, observations_df: pd.DataFrame) -> pd.DataFrame:
258 |         """
259 |         Updates time column 'X' to be the minimum between event time column 'T' and censoring time column 'C'.
260 |         Event type 'J' will be changed to 0 for observations with 'C' < 'T'.
261 | 
262 |         Args:
263 |             observations_df (pd.DataFrame): Dataframe with observations after sampling event times 'T' and censoring time 'C'.
264 | 
265 |         Returns:
266 |             observations_df (pd.DataFrame): Dataframe with updated time column 'X' and event type column 'J'
267 |         """
268 |         assert ('T' in observations_df.columns), "Trying to update event or censoring before sampling event times"
269 |         assert ('C' in observations_df.columns), "Trying to update event or censoring before sampling censoring time"
270 |         observations_df['X'] = observations_df[['T', 'C']].min(axis=1)
271 |         observations_df.loc[observations_df.loc[(observations_df['C'] < observations_df['T'])].index, 'J'] = 0
272 |         return observations_df
273 | 
--------------------------------------------------------------------------------
/src/pydts/examples_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomer1812/pydts/578929a457c111efe009d3461aab531b793b33d0/src/pydts/examples_utils/__init__.py
--------------------------------------------------------------------------------
/src/pydts/examples_utils/datasets.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from pydts.config import *
3 | 
4 | DATASETS_DIR = os.path.join(os.path.dirname((os.path.dirname(__file__))), 'datasets')
5 | 
6 | def load_LOS_simulated_data():
7 | 
8 |     return pd.read_csv(os.path.join(DATASETS_DIR, 'LOS_simulated_data.csv'))
--------------------------------------------------------------------------------
/src/pydts/examples_utils/generate_simulations_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pydts.examples_utils.simulations_data_config import *
3 | from pydts.config import *
4 | import pandas as pd
5 | from scipy.special import expit
6 | from pandarallel import pandarallel
7 | 
8 | 
9 | def sample_los(new_patient, age_mean, age_std, bmi_mean, bmi_std, coefs=COEFS, baseline_hazard_scale=8,
10 |                los_bounds=[1, 150]):
11 |     # Columns normalization:
12 |     new_patient[AGE_COL] = (new_patient[AGE_COL] - age_mean) / age_std
13 |     new_patient[GENDER_COL] = 2 * (new_patient[GENDER_COL] - 0.5)
14 |     new_patient[BMI_COL] = (new_patient[BMI_COL] - bmi_mean) / bmi_std
15 |     new_patient[SMOKING_COL] = new_patient[SMOKING_COL] - 1
16 |     new_patient[HYPERTENSION_COL] = 2 * (new_patient[HYPERTENSION_COL] - 0.5)
17 |     new_patient[DIABETES_COL] = 2 * (new_patient[DIABETES_COL] - 0.5)
18 |     new_patient[ART_FIB_COL] = 2 * (new_patient[ART_FIB_COL] - 0.5)
19 |     new_patient[COPD_COL] = 2 * (new_patient[COPD_COL] - 0.5)
20 |     new_patient[CRF_COL] = 2 * (new_patient[CRF_COL] - 0.5)
21 |     new_patient = pd.Series(new_patient)
22 | 
23 |     # Baseline hazard
24 |     baseline_hazard = np.random.exponential(scale=baseline_hazard_scale)
25 | 
26 |     # Patient's correction
27 |     beta_x = coefs.dot(new_patient[coefs.index])
28 | 
29 |     # Sample, round (for ties), and clip to bounds the patient's length of stay at the hospital
30 |     los = np.clip(np.round(baseline_hazard * np.exp(beta_x)), a_min=los_bounds[0], a_max=los_bounds[1])
31 |     los_death = np.nan if new_patient[IN_HOSPITAL_DEATH_COL] == 0 else los
32 |     return los, los_death
33 | 
34 | 
35 | def hide_weight_info(row):
36 |     #admyear = row[ADMISSION_YEAR_COL]
37 |     p_weight = 0.3  #+ int(admyear > (min_year + 3)) * 0.8 * ((admyear - min_year) / (max_year - min_year))
38 |     sample_weight = np.random.binomial(1, p=p_weight)
39 |     if sample_weight == 0:
40 |         row[WEIGHT_COL] = np.nan
41 |         row[BMI_COL] = np.nan
42 |     return row
43 | 
44 | 
45 | def main(seed=0, N_patients=DEFAULT_N_PATIENTS, output_dir=OUTPUT_DIR, filename=SIMULATED_DATA_FILENAME):
46 |     # Set random seed for consistent sampling
47 |     np.random.seed(seed)
48 | 
49 |     # Female - 1, Male - 0
50 |     gender = np.random.binomial(n=1, p=0.5, size=N_patients)
51 | 
52 |     simulated_patients_df = pd.DataFrame()
53 | 
54 |     for p in range(N_patients):
55 |         # Sample gender-dependent age for each patient
56 |         age = np.round(np.random.normal(loc=72 + 5 * gender[p], scale=12), decimals=1)
57 | 
58 |         # Randomly sample the admission year
59 |         admyear = np.random.randint(low=min_year, high=max_year)
60 | 
61 |         # Sample gender-dependent height
62 |         height = np.random.normal(loc=175 - 5 * gender[p], scale=7)
63 | 
64 |         # Sample height-, gender- and age-dependent weight
65 |         weight = np.random.normal(loc=(height / 175) * 80 - 5 * gender[p] + (age / 20), scale=8)
66 | 
67 |         # Calculate body mass index (BMI) from weight and height
68 |         bmi = weight / ((height / 100) ** 2)
69 | 
70 |         # Random sample of previous admissions
71 |         admserial = np.clip(np.round(np.random.lognormal(mean=0, sigma=0.75)), 1, 20)
72 | 
73 |         # Random sample of categorical smoking status: No - 0, Previously - 1, Currently - 2
74 |         smoking = np.random.choice([0, 1, 2], p=[0.5, 0.3, 0.2])
75 | 
76 |         # Sample patient's preconditions based on gender, age, BMI, and smoking status with limits on the value of p
77 |         pre_p = np.clip((bmi_coef * bmi + gender_coef * gender[p] + age_coef * age + smk_coef * smoking),
78 |                         a_min=0.05, a_max=max_p)
79 |         hypertension = np.random.binomial(n=1, p=pre_p)
80 |         diabetes = np.random.binomial(n=1, p=pre_p + bmi_coef * bmi)
81 |         artfib = np.random.binomial(n=1, p=pre_p)  # Atrial fibrillation
82 |         copd = np.random.binomial(n=1, p=pre_p + smk_coef * smoking)  # Chronic Obstructive Pulmonary Disease
83 |         crf = np.random.binomial(n=1, p=pre_p)  # Chronic Renal Failure
84 | 
85 |         new_patient = {
86 |             PATIENT_NO_COL: p,
87 |             AGE_COL: age,
88 |
GENDER_COL: gender[p], 89 | ADMISSION_YEAR_COL: int(admyear), 90 | FIRST_ADMISSION_COL: int(admserial == 1), 91 | ADMISSION_SERIAL_COL: int(admserial), 92 | WEIGHT_COL: weight, 93 | HEIGHT_COL: height, 94 | BMI_COL: bmi, 95 | SMOKING_COL: smoking, 96 | HYPERTENSION_COL: hypertension, 97 | DIABETES_COL: diabetes, 98 | ART_FIB_COL: artfib, 99 | COPD_COL: copd, 100 | CRF_COL: crf, 101 | } 102 | 103 | simulated_patients_df = simulated_patients_df.append(new_patient, ignore_index=True) 104 | 105 | # [age, gender, ADMISSION_YEAR, FIRST_ADMISSION, ADMISSION_SERIAL, WEIGHT_COL, HEIGHT_COL, BMI, SMOKING_COL, 5*preconditions] 106 | #b1 = [-0.003, 0.05, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 107 | b2 = [0.003, -0.05, 0, 0, 0, 0.003, 0.0001, 0.005, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02] 108 | 109 | real_coef_dict = { 110 | "alpha": { 111 | 1: lambda t: -2.2 - 0.2 * np.log(t), 112 | 2: lambda t: -2.3 - 0.2 * np.log(t) 113 | }, 114 | "beta": { 115 | 1: [-1*b for b in b2], 116 | 2: [b for b in b2] 117 | } 118 | } 119 | 120 | # Sample event type and relative time 121 | events_df = new_sample_logic(simulated_patients_df.set_index(PATIENT_NO_COL), j_events=2, d_times=30, 122 | real_coef_dict=real_coef_dict) 123 | simulated_patients_df = pd.concat([simulated_patients_df, events_df], axis=1) 124 | 125 | simulated_patients_df[DEATH_RELATIVE_COL] = np.nan 126 | simulated_patients_df.loc[simulated_patients_df.J == 1, DEATH_RELATIVE_COL] = simulated_patients_df.loc[ 127 | simulated_patients_df.J == 1, 'T'].values 128 | simulated_patients_df = simulated_patients_df.rename(columns={'T': DISCHARGE_RELATIVE_COL}) 129 | simulated_patients_df.loc[simulated_patients_df.J == 0, DISCHARGE_RELATIVE_COL] = 31 130 | simulated_patients_df = simulated_patients_df.drop('J', axis=1) 131 | 132 | # Remove weight and bmi based on admission year 133 | simulated_patients_df = simulated_patients_df.apply(hide_weight_info, axis=1) 134 | 135 | simulated_patients_df[DEATH_MISSING_COL] = simulated_patients_df[DEATH_RELATIVE_COL].isnull().astype(int) 136 | simulated_patients_df[IN_HOSPITAL_DEATH_COL] = simulated_patients_df[DEATH_RELATIVE_COL].notnull().astype(int) 137 | simulated_patients_df[RETURNING_PATIENT_COL] = pd.cut(simulated_patients_df[ADMISSION_SERIAL_COL], 138 | bins=ADMISSION_SERIAL_BINS, labels=ADMISSION_SERIAL_LABELS) 139 | 140 | simulated_patients_df.set_index(PATIENT_NO_COL).to_csv(os.path.join(output_dir, filename)) 141 | 142 | 143 | def default_sampling_logic(Z, d_times): 144 | alpha1t = -1 -0.3*np.log(np.arange(start=1, stop=d_times+1)) 145 | beta1 = -np.log([0.8, 3, 3, 2.5, 2]) 146 | alpha2t = -1.75 -0.15*np.log(np.arange(start=1, stop=d_times+1)) 147 | beta2 = -np.log([1, 3, 4, 3, 2]) 148 | hazard1 = expit(alpha1t+(Z*beta1).sum()) 149 | hazard2 = expit(alpha2t+(Z*beta2).sum()) 150 | surv_func = np.array([1, *np.cumprod(1-hazard1-hazard2)[:-1]]) 151 | proba1 = hazard1*surv_func 152 | proba2 = hazard2*surv_func 153 | sum1 = np.sum(proba1) 154 | sum2 = np.sum(proba2) 155 | probj1t = proba1 / sum1 156 | probj2t = proba2 / sum2 157 | j_i = np.random.choice(a=[0, 1, 2], p=[1-sum1-sum2, sum1, sum2]) 158 | if j_i == 0: 159 | T_i = d_times 160 | elif j_i == 1: 161 | T_i = np.random.choice(a=np.arange(1, d_times+1), p=probj1t) 162 | else: 163 | T_i = np.random.choice(a=np.arange(1, d_times+1), p=probj2t) 164 | return j_i, T_i 165 | 166 | 167 | def calculate_jt(sums, probs_jt, d_times, j_events): 168 | events = range(1, j_events + 1) 169 | temp_sums = pd.concat([1 - sum(sums), *sums], axis=1, keys=[0, *events]) 170 | 171 | j_df = 
(temp_sums.cumsum(1) > np.random.rand(temp_sums.shape[0])[:, None]).idxmax(axis=1).to_frame('J')
172 | 
173 |     temp_ts = []
174 |     for j in events:
175 |         rel_j = j_df.query("J==@j").index
176 |         prob_df = probs_jt[j - 1]  # the prob j to sample from
177 |         # sample T
178 |         temp_ts.append((prob_df.loc[rel_j].cumsum(1) >= np.random.rand(rel_j.shape[0])[:, None]).idxmax(axis=1))
179 | 
180 |     temp_ts.append(pd.Series(d_times+1, index=j_df.query('J==0').index))
181 | 
182 |     j_df["T"] = pd.concat(temp_ts).sort_index()
183 |     return j_df
184 | 
185 | 
186 | def new_sample_logic(patients_df: pd.DataFrame, j_events: int, d_times: int, real_coef_dict: dict) -> pd.DataFrame:
187 |     """
188 |     A quicker sampling logic that uses coefficients supplied by the user
189 |     """
190 |     events = range(1, j_events + 1)
191 |     a_t = {event: {t: real_coef_dict['alpha'][event](t) for t in range(1, d_times+1)} for event in events}
192 |     b = pd.concat([patients_df.dot(real_coef_dict['beta'][j]) for j in events], axis=1, keys=events)
193 | 
194 |     hazards = [pd.concat([expit(a_t[j][t] + b[j]) for t in range(1, d_times + 1)],
195 |                          axis=1, keys=(range(1, d_times + 1))) for j in events]
196 |     surv_func = pd.concat([pd.Series(1, index=hazards[0].index),
197 |                            (1 - sum(hazards)).cumprod(axis=1).iloc[:, :-1]], axis=1)
198 |     surv_func.columns += 1
199 | 
200 |     probs = [hazard * surv_func for hazard in hazards]
201 |     sums = [prob.sum(axis=1) for prob in probs]
202 |     probjt = [prob.div(sumj, axis=0) for prob, sumj in zip(probs, sums)]
203 | 
204 |     ret = calculate_jt(sums, probjt, d_times, j_events)
205 |     return ret
206 | 
207 | 
208 | def generate_quick_start_df(n_patients=10000, d_times=30, j_events=2, n_cov=5, seed=0, pid_col='pid',
209 |                             real_coef_dict: dict = DEFAULT_REAL_COEF_DICT, sampling_logic=new_sample_logic,
210 |                             censoring_prob=1.):
211 |     np.random.seed(seed)
212 |     assert real_coef_dict is not None, "The user should supply the coefficients of the experiment"
213 |     covariates = [f'Z{i + 1}' for i in range(n_cov)]
214 |     patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]),
215 |                                columns=covariates)
216 |     sampled = sampling_logic(patients_df, j_events, d_times, real_coef_dict)
217 |     patients_df = pd.concat([patients_df, sampled], axis=1)
218 |     patients_df.index.name = pid_col
219 |     patients_df['C'] = np.where(np.random.rand(n_patients) < censoring_prob,
220 |                                 np.random.randint(low=1, high=d_times+1,
221 |                                                   size=n_patients), d_times+1)
222 |     patients_df['X'] = patients_df[['T', 'C']].min(axis=1)
223 |     patients_df.loc[patients_df['C'] < patients_df['T'], 'J'] = 0
224 |     return patients_df.reset_index()
225 | 
226 | 
227 | if __name__ == "__main__":
228 |     main(N_patients=50000)
229 |     #generate_quick_start_df(n_patients=2)
230 | 
--------------------------------------------------------------------------------
/src/pydts/examples_utils/mimic_consts.py:
--------------------------------------------------------------------------------
1 | 
2 | ADMISSION_TIME_COL = 'admittime'
3 | DISCHARGE_TIME_COL = 'dischtime'
4 | DEATH_TIME_COL = 'deathtime'
5 | ED_REG_TIME = 'edregtime'
6 | ED_OUT_TIME = 'edouttime'
7 | AGE_COL = 'anchor_age'
8 | GENDER_COL = 'gender'
9 | 
10 | AGE_BINS = list(range(0, 125, 5))
11 | AGE_LABELS = [f'{AGE_BINS[a]}' for a in range(len(AGE_BINS) - 1)]
12 | 
13 | font_sz = 14
14 | title_sz = 18
15 | 
16 | YEAR_GROUP_COL = 'anchor_year_group'
17 | SUBSET_YEAR_GROUP = '2017 - 2019'
18 | SUBJECT_ID_COL = 'subject_id'
19 | ADMISSION_ID_COL = 'hadm_id'
20 | ADMISSION_TYPE_COL = 'admission_type'
21 |
CHART_TIME_COL = 'charttime' 22 | STORE_TIME_COL = 'storetime' 23 | LOS_EXACT_COL = 'LOS exact' 24 | LOS_DAYS_COL = 'LOS days' 25 | ADMISSION_LOCATION_COL = 'admission_location' 26 | DISCHARGE_LOCATION_COL = 'discharge_location' 27 | RACE_COL = 'race' 28 | INSURANCE_COL = 'insurance' 29 | ADMISSION_TO_RESULT_COL = 'admission_to_result_time' 30 | ADMISSION_AGE_COL = 'admission_age' 31 | ADMISSION_YEAR_COL = 'admission_year' 32 | ADMISSION_COUNT_COL = 'admissions_count' 33 | ITEM_ID_COL = 'itemid' 34 | NIGHT_ADMISSION_FLAG = 'night_admission' 35 | MARITAL_STATUS_COL = 'marital_status' 36 | STANDARDIZED_AGE_COL = 'standardized_age' 37 | COEF_COL = ' coef ' 38 | STDERR_COL = ' std err ' 39 | DIRECT_IND_COL = 'direct_emrgency_flag' 40 | PREV_ADMISSION_IND_COL = 'last_less_than_diff' 41 | ADMISSION_COUNT_GROUP_COL = ADMISSION_COUNT_COL + '_group' 42 | 43 | DISCHARGE_REGROUPING_DICT = { 44 | 'HOME': 'HOME', 45 | 'HOME HEALTH CARE': 'HOME', 46 | 'SKILLED NURSING FACILITY': 'FURTHER TREATMENT', 47 | 'DIED': 'DIED', 48 | 'REHAB': 'HOME', 49 | 'CHRONIC/LONG TERM ACUTE CARE': 'FURTHER TREATMENT', 50 | 'HOSPICE': 'FURTHER TREATMENT', 51 | 'AGAINST ADVICE': 'CENSORED', 52 | 'ACUTE HOSPITAL': 'FURTHER TREATMENT', 53 | 'PSYCH FACILITY': 'FURTHER TREATMENT', 54 | 'OTHER FACILITY': 'FURTHER TREATMENT', 55 | 'ASSISTED LIVING': 'HOME', 56 | 'HEALTHCARE FACILITY': 'FURTHER TREATMENT', 57 | } 58 | 59 | RACE_REGROUPING_DICT = { 60 | 'WHITE': 'WHITE', 61 | 'UNKNOWN': 'OTHER', 62 | 'BLACK/AFRICAN AMERICAN': 'BLACK', 63 | 'OTHER': 'OTHER', 64 | 'ASIAN': 'ASIAN', 65 | 'WHITE - OTHER EUROPEAN': 'WHITE', 66 | 'HISPANIC/LATINO - PUERTO RICAN': 'HISPANIC', 67 | 'HISPANIC/LATINO - DOMINICAN': 'HISPANIC', 68 | 'ASIAN - CHINESE': 'ASIAN', 69 | 'BLACK/CARIBBEAN ISLAND': 'BLACK', 70 | 'BLACK/AFRICAN': 'BLACK', 71 | 'BLACK/CAPE VERDEAN': 'BLACK', 72 | 'PATIENT DECLINED TO ANSWER': 'OTHER', 73 | 'WHITE - BRAZILIAN': 'WHITE', 74 | 'PORTUGUESE': 'HISPANIC', 75 | 'ASIAN - SOUTH EAST ASIAN': 'ASIAN', 76 | 'WHITE - RUSSIAN': 'WHITE', 77 | 'ASIAN - ASIAN INDIAN': 'ASIAN', 78 | 'WHITE - EASTERN EUROPEAN': 'WHITE', 79 | 'AMERICAN INDIAN/ALASKA NATIVE': 'OTHER', 80 | 'HISPANIC/LATINO - GUATEMALAN': 'HISPANIC', 81 | 'HISPANIC/LATINO - MEXICAN': 'HISPANIC', 82 | 'HISPANIC/LATINO - SALVADORAN': 'HISPANIC', 83 | 'SOUTH AMERICAN': 'HISPANIC', 84 | 'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER': 'OTHER', 85 | 'HISPANIC/LATINO - COLUMBIAN': 'HISPANIC', 86 | 'HISPANIC/LATINO - CUBAN': 'HISPANIC', 87 | 'ASIAN - KOREAN': 'ASIAN', 88 | 'HISPANIC/LATINO - HONDURAN': 'HISPANIC', 89 | 'HISPANIC/LATINO - CENTRAL AMERICAN': 'HISPANIC', 90 | 'UNABLE TO OBTAIN': 'OTHER', 91 | 'HISPANIC OR LATINO': 'HISPANIC' 92 | } 93 | 94 | # 'MCH': 'Mean Cell Hemoglobin', 95 | # 'MCHC': 'Mean Cell Hemoglobin Concentration', 96 | table1_rename_columns = { 97 | 'AnionGap': 'Anion gap', 98 | 'Bicarbonate': 'Bicarbonate', 99 | 'CalciumTotal': 'Calcium total', 100 | 'Chloride': 'Chloride', 101 | 'Creatinine': 'Creatinine', 102 | 'Glucose': 'Glucose', 103 | 'Magnesium': 'Magnesium', 104 | 'Phosphate': 'Phosphate', 105 | 'Potassium': 'Potassium', 106 | 'Sodium': 'Sodium', 107 | 'UreaNitrogen': 'Urea nitrogen', 108 | 'Hematocrit': 'Hematocrit', 109 | 'Hemoglobin': 'Hemoglobin', 110 | 'MCH': 'MCH', 111 | 'MCHC': 'MCHC', 112 | 'MCV': 'MCV', 113 | 'PlateletCount': 'Platelet count', 114 | 'RDW': 'RDW', 115 | 'RedBloodCells': 'Red blood cells', 116 | 'WhiteBloodCells': 'White blood cells', 117 | NIGHT_ADMISSION_FLAG: 'Night admission', 118 | GENDER_COL: 'Sex', 119 | DIRECT_IND_COL: 
'Direct emergency',
120 |     PREV_ADMISSION_IND_COL: 'Previous admission this month',
121 |     ADMISSION_AGE_COL: 'Admission age',
122 |     INSURANCE_COL: 'Insurance',
123 |     MARITAL_STATUS_COL: 'Marital status',
124 |     RACE_COL: 'Race',
125 |     ADMISSION_COUNT_GROUP_COL: 'Admissions number',
126 |     LOS_DAYS_COL: 'LOS (days)',
127 |     DISCHARGE_LOCATION_COL: 'Discharge location'
128 | }
129 | 
130 | table1_rename_sex = {0: 'Male', 1: 'Female'}
131 | table1_rename_race = {'ASIAN': 'Asian', 'BLACK': 'Black', 'HISPANIC': 'Hispanic', 'OTHER': 'Other',
132 |                       'WHITE': 'White'}
133 | table1_rename_marital = {'SINGLE': 'Single', 'MARRIED': 'Married', 'DIVORCED': 'Divorced', 'WIDOWED': 'Widowed'}
134 | table1_rename_yes_no = {0: 'No', 1: 'Yes'}
135 | table1_rename_normal_abnormal = {0: 'Normal', 1: 'Abnormal'}
136 | table1_rename_discharge = {1: 'Home', 2: 'Further Treatment', 3: 'Died', 0: 'Censored'}
137 | 
138 | 
139 | 
140 | 
141 | rename_beta_index = {
142 |     'AdmsCount 2': 'Admissions number 2',
143 |     'AdmsCount 3up': 'Admissions number 3+',
144 |     'AnionGap': 'Anion gap',
145 |     'Bicarbonate': 'Bicarbonate',
146 |     'CalciumTotal': 'Calcium total',
147 |     'Chloride': 'Chloride',
148 |     'Creatinine': 'Creatinine',
149 |     'Ethnicity BLACK': 'Ethnicity black',
150 |     'Ethnicity HISPANIC': 'Ethnicity hispanic',
151 |     'Ethnicity OTHER': 'Ethnicity other',
152 |     'Ethnicity WHITE': 'Ethnicity white',
153 |     'Glucose': 'Glucose',
154 |     'Hematocrit': 'Hematocrit',
155 |     'Hemoglobin': 'Hemoglobin',
156 |     'Insurance Medicare': 'Insurance medicare',
157 |     'Insurance Other': 'Insurance other',
158 |     'MCH': 'MCH',
159 |     'MCHC': 'MCHC',
160 |     'MCV': 'MCV',
161 |     'Magnesium': 'Magnesium',
162 |     'Marital MARRIED': 'Marital married',
163 |     'Marital SINGLE': 'Marital single',
164 |     'Marital WIDOWED': 'Marital widowed',
165 |     'Phosphate': 'Phosphate',
166 |     'PlateletCount': 'Platelet count',
167 |     'Potassium': 'Potassium',
168 |     'RDW': 'RDW',
169 |     'RedBloodCells': 'Red blood cells',
170 |     'Sodium': 'Sodium',
171 |     'UreaNitrogen': 'Urea nitrogen',
172 |     'WhiteBloodCells': 'White blood cells',
173 |     'direct emrgency flag': 'Direct emergency',
174 |     'gender': 'Sex',
175 |     'last less than diff': 'Recent admission',
176 |     'night admission': 'Night admission',
177 |     'standardized age': 'Standardized age',
178 | }
179 | 
180 | beta_units = {
181 |     'Admissions number 2': '2',
182 |     'Admissions number 3+': '3+',
183 |     'Anion gap': 'Abnormal',
184 |     'Bicarbonate': 'Abnormal',
185 |     'Calcium total': 'Abnormal',
186 |     'Chloride': 'Abnormal',
187 |     'Creatinine': 'Abnormal',
188 |     'Ethnicity black': 'Black',
189 |     'Ethnicity hispanic': 'Hispanic',
190 |     'Ethnicity other': 'Other',
191 |     'Ethnicity white': 'White',
192 |     'Glucose': 'Abnormal',
193 |     'Hematocrit': 'Abnormal',
194 |     'Hemoglobin': 'Abnormal',
195 |     'Insurance medicare': 'Medicare',
196 |     'Insurance other': 'Other',
197 |     'MCH': 'Abnormal',
198 |     'MCHC': 'Abnormal',
199 |     'MCV': 'Abnormal',
200 |     'Magnesium': 'Abnormal',
201 |     'Marital married': 'Married',
202 |     'Marital single': 'Single',
203 |     'Marital widowed': 'Widowed',
204 |     'Phosphate': 'Abnormal',
205 |     'Platelet count': 'Abnormal',
206 |     'Potassium': 'Abnormal',
207 |     'RDW': 'Abnormal',
208 |     'Red blood cells': 'Abnormal',
209 |     'Sodium': 'Abnormal',
210 |     'Urea nitrogen': 'Abnormal',
211 |     'White blood cells': 'Abnormal',
212 |     'Direct emergency': 'Yes',
213 |     'Sex': 'Female',
214 |     'Recent admission': 'Yes',
215 |     'Night admission': 'Yes',
216 |     'Standardized age': '',
217 | }
DISCHARGE_TIME_COL = 'dischtime' 221 | # DEATH_TIME_COL = 'deathtime' 222 | # ED_REG_TIME = 'edregtime' 223 | # ED_OUT_TIME = 'edouttime' 224 | # AGE_COL = 'anchor_age' 225 | # GENDER_COL = 'gender' 226 | # 227 | # AGE_BINS = list(range(0, 125, 5)) 228 | # AGE_LABELS = [f'{AGE_BINS[a]}' for a in range(len(AGE_BINS) - 1)] 229 | # 230 | # font_sz = 14 231 | # title_sz = 18 232 | # 233 | # YEAR_GROUP_COL = 'anchor_year_group' 234 | # SUBSET_YEAR_GROUP = '2017 - 2019' 235 | # SUBJECT_ID_COL = 'subject_id' 236 | # ADMISSION_ID_COL = 'hadm_id' 237 | # ADMISSION_TYPE_COL = 'admission_type' 238 | # CHART_TIME_COL = 'charttime' 239 | # STORE_TIME_COL = 'storetime' 240 | # LOS_EXACT_COL = 'LOS exact' 241 | # LOS_DAYS_COL = 'LOS days' 242 | # ADMISSION_LOCATION_COL = 'admission_location' 243 | # DISCHARGE_LOCATION_COL = 'discharge_location' 244 | # RACE_COL = 'race' 245 | # INSURANCE_COL = 'insurance' 246 | # ADMISSION_TO_RESULT_COL = 'admission_to_result_time' 247 | # ADMISSION_AGE_COL = 'admission_age' 248 | # ADMISSION_YEAR_COL = 'admission_year' 249 | # ADMISSION_COUNT_COL = 'admissions_count' 250 | # ITEM_ID_COL = 'itemid' 251 | # NIGHT_ADMISSION_FLAG = 'night_admission' 252 | # MARITAL_STATUS_COL = 'marital_status' 253 | # STANDARDIZED_AGE_COL = 'standardized_age' 254 | # COEF_COL = ' coef ' 255 | # STDERR_COL = ' std err ' 256 | # DIRECT_IND_COL = 'direct_emrgency_flag' 257 | # PREV_ADMISSION_IND_COL = 'last_less_than_diff' 258 | # ADMISSION_COUNT_GROUP_COL = ADMISSION_COUNT_COL + '_group' 259 | # 260 | # DISCHARGE_REGROUPING_DICT = { 261 | # 'HOME': 'HOME', 262 | # 'HOME HEALTH CARE': 'HOME', 263 | # 'SKILLED NURSING FACILITY': 'FURTHER TREATMENT', 264 | # 'DIED': 'DIED', 265 | # 'REHAB': 'HOME', 266 | # 'CHRONIC/LONG TERM ACUTE CARE': 'FURTHER TREATMENT', 267 | # 'HOSPICE': 'FURTHER TREATMENT', 268 | # 'AGAINST ADVICE': 'CENSORED', 269 | # 'ACUTE HOSPITAL': 'FURTHER TREATMENT', 270 | # 'PSYCH FACILITY': 'FURTHER TREATMENT', 271 | # 'OTHER FACILITY': 'FURTHER TREATMENT', 272 | # 'ASSISTED LIVING': 'HOME', 273 | # 'HEALTHCARE FACILITY': 'FURTHER TREATMENT', 274 | # } 275 | # 276 | # RACE_REGROUPING_DICT = { 277 | # 'WHITE': 'WHITE', 278 | # 'UNKNOWN': 'OTHER', 279 | # 'BLACK/AFRICAN AMERICAN': 'BLACK', 280 | # 'OTHER': 'OTHER', 281 | # 'ASIAN': 'ASIAN', 282 | # 'WHITE - OTHER EUROPEAN': 'WHITE', 283 | # 'HISPANIC/LATINO - PUERTO RICAN': 'HISPANIC', 284 | # 'HISPANIC/LATINO - DOMINICAN': 'HISPANIC', 285 | # 'ASIAN - CHINESE': 'ASIAN', 286 | # 'BLACK/CARIBBEAN ISLAND': 'BLACK', 287 | # 'BLACK/AFRICAN': 'BLACK', 288 | # 'BLACK/CAPE VERDEAN': 'BLACK', 289 | # 'PATIENT DECLINED TO ANSWER': 'OTHER', 290 | # 'WHITE - BRAZILIAN': 'WHITE', 291 | # 'PORTUGUESE': 'HISPANIC', 292 | # 'ASIAN - SOUTH EAST ASIAN': 'ASIAN', 293 | # 'WHITE - RUSSIAN': 'WHITE', 294 | # 'ASIAN - ASIAN INDIAN': 'ASIAN', 295 | # 'WHITE - EASTERN EUROPEAN': 'WHITE', 296 | # 'AMERICAN INDIAN/ALASKA NATIVE': 'OTHER', 297 | # 'HISPANIC/LATINO - GUATEMALAN': 'HISPANIC', 298 | # 'HISPANIC/LATINO - MEXICAN': 'HISPANIC', 299 | # 'HISPANIC/LATINO - SALVADORAN': 'HISPANIC', 300 | # 'SOUTH AMERICAN': 'HISPANIC', 301 | # 'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER': 'OTHER', 302 | # 'HISPANIC/LATINO - COLUMBIAN': 'HISPANIC', 303 | # 'HISPANIC/LATINO - CUBAN': 'HISPANIC', 304 | # 'ASIAN - KOREAN': 'ASIAN', 305 | # 'HISPANIC/LATINO - HONDURAN': 'HISPANIC', 306 | # 'HISPANIC/LATINO - CENTRAL AMERICAN': 'HISPANIC', 307 | # 'UNABLE TO OBTAIN': 'OTHER', 308 | # 'HISPANIC OR LATINO': 'HISPANIC' 309 | # } 310 | # 311 | # # 'MCH': 'Mean Cell 
Hemoglobin', 312 | # # 'MCHC': 'Mean Cell Hemoglobin Concentration',
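313 | 314 | # A minimal usage sketch of the rename maps above (summary_df below is a 315 | # hypothetical DataFrame of fitted coefficients indexed by the raw labels): 316 | # 317 | #     summary_df = summary_df.rename(index=rename_beta_index) 318 | #     summary_df['Unit'] = summary_df.index.map(beta_units) 319 | # 320 | # table1_rename_columns and the table1_rename_* value maps can be applied in 321 | # the same way when building a descriptive table from the raw MIMIC columns.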
-------------------------------------------------------------------------------- /src/pydts/examples_utils/simulations_data_config.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | DEFAULT_N_PATIENTS = 10000 5 | 6 | min_year = 2000 7 | max_year = 2015 8 | 9 | bmi_coef = 0.003 10 | age_coef = 0.002 11 | gender_coef = -0.15 12 | smk_coef = 0.1 13 | max_p = 0.65 14 | 15 | PATIENT_NO_COL = 'ID' 16 | AGE_COL = 'Age' 17 | GENDER_COL = 'Sex' 18 | ADMISSION_YEAR_COL = 'Admyear' 19 | FIRST_ADMISSION_COL = 'Firstadm' 20 | ADMISSION_SERIAL_COL = 'Admserial' 21 | WEIGHT_COL = 'Weight' 22 | HEIGHT_COL = 'Height' 23 | BMI_COL = 'BMI' 24 | SMOKING_COL = 'Smoking' 25 | HYPERTENSION_COL = 'Hypertension' 26 | DIABETES_COL = 'Diabetes' 27 | ART_FIB_COL = 'AF' # Arterial Fibrillation 28 | COPD_COL = 'COPD' # Chronic Obstructive Pulmonary Disease 29 | CRF_COL = 'CRF' # Chronic Renal Failure 30 | IN_HOSPITAL_DEATH_COL = 'In_hospital_death' 31 | DISCHARGE_RELATIVE_COL = 'Discharge_relative_date' 32 | DEATH_RELATIVE_COL = 'Death_relative_date_in_hosp' 33 | DEATH_MISSING_COL = 'Death_date_in_hosp_missing' 34 | RETURNING_PATIENT_COL = 'Returning_patient' 35 | 36 | COEFS = pd.Series({ 37 | AGE_COL: 0.1, 38 | GENDER_COL: -0.1, 39 | BMI_COL: 0.2, 40 | SMOKING_COL: 0.2, 41 | HYPERTENSION_COL: 0.2, 42 | DIABETES_COL: 0.2, 43 | ART_FIB_COL: 0.2, 44 | COPD_COL: 0.2, 45 | CRF_COL: 0.2, 46 | }) 47 | 48 | ADMISSION_SERIAL_BINS = [0, 1.5, 4.5, 8.5, 21] 49 | ADMISSION_SERIAL_LABELS = [0, 1, 2, 3] 50 | SIMULATED_DATA_FILENAME = 'simulated_data.csv' 51 | 52 | preconditions = [SMOKING_COL, HYPERTENSION_COL, DIABETES_COL, ART_FIB_COL, COPD_COL, CRF_COL] 53 | font_sz = 14 54 | title_sz = 18 55 | AGE_BINS = list(range(0, 125, 5)) 56 | AGE_LABELS = [f'{AGE_BINS[a]}' for a in range(len(AGE_BINS)-1)] 57 | 58 | DEFAULT_REAL_COEF_DICT = { 59 | "alpha": { 60 | 1: lambda t: -1 - 0.3 * np.log(t), 61 | 2: lambda t: -1.75 - 0.15 * np.log(t) 62 | }, 63 | "beta": { 64 | 1: -np.log([0.8, 3, 3, 2.5, 2]), 65 | 2: -np.log([1, 3, 4, 3, 2]) 66 | } 67 | } 68 |
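69 | 70 | # The coefficients above define, for each event j and discrete time t, the 71 | # hazard used by the simulation utilities: expit(alpha_j(t) + Z @ beta_j), 72 | # as computed by get_real_hazard in pydts.utils. A minimal sketch of 73 | # evaluating it for a single covariate vector (the values of z and t below 74 | # are arbitrary illustration only): 75 | if __name__ == '__main__': 76 |     from scipy.special import expit 77 |     z = np.full(5, 0.5)  # five covariates, matching the length of each beta vector 78 |     t = 3 79 |     for j in (1, 2): 80 |         h_jt = expit(DEFAULT_REAL_COEF_DICT['alpha'][j](t) + z @ DEFAULT_REAL_COEF_DICT['beta'][j]) 81 |         print(f'hazard of event {j} at time {t}: {h_jt:.4f}')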
-------------------------------------------------------------------------------- /src/pydts/model_selection.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from itertools import product 4 | from .fitters import TwoStagesFitter, TwoStagesFitterExact 5 | import warnings 6 | from copy import deepcopy 7 | pd.set_option("display.max_rows", 500) 8 | warnings.filterwarnings('ignore') 9 | from time import time 10 | from typing import Union 11 | slicer = pd.IndexSlice 12 | from .evaluation import events_integrated_brier_score, global_brier_score, events_integrated_auc, global_auc 13 | 14 | 15 | class BasePenaltyGridSearch(object): 16 | 17 | """ This class implements the penalty parameter grid search. """ 18 | 19 | def __init__(self): 20 | self.l1_ratio = None 21 | self.penalizers = [] 22 | self.seed = None 23 | self.meta_models = {} 24 | self.train_df = None 25 | self.test_df = None 26 | self.global_auc = {} 27 | self.integrated_auc = {} 28 | self.global_bs = {} 29 | self.integrated_bs = {} 30 | self.TwoStagesFitter_type = 'CoxPHFitter' 31 | 32 | def evaluate(self, 33 | train_df: pd.DataFrame, 34 | test_df: pd.DataFrame, 35 | l1_ratio: float, 36 | penalizers: list, 37 | metrics: Union[list, str] = ['IBS', 'GBS', 'IAUC', 'GAUC'], 38 | seed: Union[None, int] = None, 39 | event_type_col: str = 'J', 40 | duration_col: str = 'X', 41 | pid_col: str = 'pid', 42 | twostages_fit_kwargs: dict = {}) -> tuple: 43 | 44 | """ 45 | This function fits the model to train_df and evaluates the metrics on test_df for every possible combination of the penalizers. 46 | 47 | Args: 48 | train_df (pd.DataFrame): training data for fitting the model. 49 | test_df (pd.DataFrame): testing data for evaluating the estimated model's performance. 50 | l1_ratio (float): regularization ratio for the CoxPHFitter (see lifelines.fitters.coxph_fitter.CoxPHFitter documentation). 51 | penalizers (list): penalizer options for each event (see lifelines.fitters.coxph_fitter.CoxPHFitter documentation). 52 | metrics (str, list): Evaluation metrics. 53 | seed (int): pseudo random seed number for numpy.random.seed(). 54 | event_type_col (str): The event type column name (must be a column in df). A right-censored sample (i) is indicated by event value 0, df.loc[i, event_type_col] = 0. 55 | duration_col (str): Last follow up time column name (must be a column in df). 56 | pid_col (str): Sample ID column name (must be a column in df). 57 | twostages_fit_kwargs (dict): keyword arguments to pass to the TwoStagesFitter. 58 | 59 | Returns: 60 | output (Tuple): Penalizers with best performance in terms of Global-AUC, if 'GAUC' is in metrics.
61 | 62 | """ 63 | 64 | self.l1_ratio = l1_ratio 65 | self.penalizers = penalizers 66 | self.seed = seed 67 | np.random.seed(seed) 68 | 69 | for idp, penalizer in enumerate(penalizers): 70 | 71 | fit_beta_kwargs = self._get_model_fit_kwargs(penalizer, l1_ratio) 72 | 73 | if self.TwoStagesFitter_type == 'Exact': 74 | self.meta_models[penalizer] = TwoStagesFitterExact() 75 | else: 76 | self.meta_models[penalizer] = TwoStagesFitter() 77 | print(f"Started estimating the coefficients for penalizer {penalizer} ({idp+1}/{len(penalizers)})") 78 | start = time() 79 | self.meta_models[penalizer].fit(df=train_df, fit_beta_kwargs=fit_beta_kwargs, 80 | pid_col=pid_col, event_type_col=event_type_col, duration_col=duration_col, 81 | **twostages_fit_kwargs) 82 | end = time() 83 | print(f"Finished estimating the coefficients for penalizer {penalizer} ({idp+1}/{len(penalizers)}), {int(end - start)} seconds") 84 | 85 | events = [j for j in sorted(train_df[event_type_col].unique()) if j != 0] 86 | grid = [penalizers for e in events] 87 | penalizers_combinations = list(product(*grid)) 88 | 89 | for idc, combination in enumerate(penalizers_combinations): 90 | mixed_two_stages = self.get_mixed_two_stages_fitter(combination) 91 | 92 | pred_df = mixed_two_stages.predict_prob_events(test_df) 93 | 94 | for metric in metrics: 95 | if metric == 'IAUC': 96 | self.integrated_auc[combination] = events_integrated_auc(pred_df, event_type_col=event_type_col, 97 | duration_col=duration_col) 98 | elif metric == 'GAUC': 99 | self.global_auc[combination] = global_auc(pred_df, event_type_col=event_type_col, 100 | duration_col=duration_col) 101 | elif metric == 'IBS': 102 | self.integrated_bs[combination] = events_integrated_brier_score(pred_df, 103 | event_type_col=event_type_col, 104 | duration_col=duration_col) 105 | elif metric == 'GBS': 106 | self.global_bs[combination] = global_brier_score(pred_df, event_type_col=event_type_col, 107 | duration_col=duration_col) 108 | 109 | output = self.convert_results_dict_to_df(self.global_auc).idxmax().values[0] if 'GAUC' in metrics else [] 110 | return output 111 | 112 | def convert_results_dict_to_df(self, results_dict): 113 | """ 114 | This function converts a results dictionary to a pd.DataFrame format. 115 | 116 | Args: 117 | results_dict: one of the class attributes: global_auc, integrated_auc, global_bs, integrated_bs. 118 | 119 | Returns: 120 | df (pd.DataFrame): Results in a pd.DataFrame format. 121 | """ 122 | df = pd.DataFrame(results_dict.values(), index=pd.MultiIndex.from_tuples(results_dict.keys())) 123 | return df 124 | 125 | def get_mixed_two_stages_fitter(self, penalizers_combination: list) -> TwoStagesFitter: 126 | """ 127 | This function creates a mixed TwoStagesFitter from the estimated meta models for a specific penalizers combination. 128 | 129 | Args: 130 | penalizers_combination (list): List whose length equals the number of competing events, giving the penalizer value for each event. Each of the values must be one of the values that was previously passed to the evaluate() method. 131 | 132 | Returns: 133 | mixed_two_stages (pydts.fitters.TwoStagesFitter): TwoStagesFitter for the required penalty combination. 134 | """ 135 | _validate_estimated_value = [p for p in penalizers_combination if p not in list(self.meta_models.keys())] 136 | assert len(_validate_estimated_value) == 0, \ 137 | f"Values {_validate_estimated_value} were not estimated. All the penalizers in penalizers_combination must be estimated using evaluate() before a mixed combination can be generated." 138 | 139 | events = self.meta_models[penalizers_combination[0]].events 140 | event_type_col = self.meta_models[penalizers_combination[0]].event_type_col 141 | if self.TwoStagesFitter_type == 'Exact': 142 | mixed_two_stages = TwoStagesFitterExact() 143 | else: 144 | mixed_two_stages = TwoStagesFitter() 145 | 146 | for ide, event in enumerate(sorted(events)): 147 | if ide == 0: 148 | mixed_two_stages.covariates = self.meta_models[penalizers_combination[ide]].covariates 149 | mixed_two_stages.duration_col = self.meta_models[penalizers_combination[ide]].duration_col 150 | mixed_two_stages.event_type_col = self.meta_models[penalizers_combination[ide]].event_type_col 151 | mixed_two_stages.events = self.meta_models[penalizers_combination[ide]].events 152 | mixed_two_stages.pid_col = self.meta_models[penalizers_combination[ide]].pid_col 153 | mixed_two_stages.times = self.meta_models[penalizers_combination[ide]].times 154 | 155 | mixed_two_stages.beta_models[event] = self.meta_models[penalizers_combination[ide]].beta_models[event] 156 | mixed_two_stages.event_models[event] = [] 157 | mixed_two_stages.event_models[event].append(self.meta_models[penalizers_combination[ide]].beta_models[event]) 158 | 159 | event_alpha = self.meta_models[penalizers_combination[ide]].alpha_df.copy() 160 | event_alpha = event_alpha[event_alpha[event_type_col] == event] 161 | mixed_two_stages.alpha_df = pd.concat([mixed_two_stages.alpha_df, event_alpha]) 162 | mixed_two_stages.event_models[event].append(event_alpha) 163 | 164 | return mixed_two_stages 165 | 166 | def _get_model_fit_kwargs(self, penalizer, l1_ratio): 167 | if self.TwoStagesFitter_type == 'Exact': 168 | fit_beta_kwargs = { 169 | 'model_fit_kwargs': { 170 | 'alpha': penalizer, 171 | 'L1_wt': l1_ratio 172 | } 173 | } 174 | else: 175 | fit_beta_kwargs = { 176 | 'model_kwargs': { 177 | 'penalizer': penalizer, 178 | 'l1_ratio': l1_ratio 179 | }, 180 | } 181 | return fit_beta_kwargs 182 | 183 | 184 | class PenaltyGridSearch(BasePenaltyGridSearch): 185 | 186 | def __init__(self): 187 | super().__init__() 188 | self.TwoStagesFitter_type = 'CoxPHFitter' 189 | 190 | 191 | class PenaltyGridSearchExact(BasePenaltyGridSearch): 192 | 193 | def __init__(self): 194 | super().__init__() 195 | self.TwoStagesFitter_type = 'Exact' 196 |
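197 | 198 | if __name__ == '__main__': 199 |     # A minimal usage sketch (run as a module, e.g. python -m pydts.model_selection, 200 |     # so the relative imports resolve). The simulated-data helper, the grid values 201 |     # and the 80/20 split below are assumptions for this example, not part of the API. 202 |     from .examples_utils.generate_simulations_data import generate_quick_start_df 203 |     from .examples_utils.simulations_data_config import DEFAULT_REAL_COEF_DICT 204 |     df = generate_quick_start_df(n_patients=1000, n_cov=5, d_times=10, j_events=2, pid_col='pid', seed=0, 205 |                                  real_coef_dict=DEFAULT_REAL_COEF_DICT, censoring_prob=0.8) 206 |     df = df.drop(['C', 'T'], axis=1)  # keep only pid, the covariates, observed time X and event type J 207 |     train_df = df.sample(frac=0.8, random_state=0) 208 |     test_df = df.drop(train_df.index) 209 |     pgs = PenaltyGridSearch() 210 |     best = pgs.evaluate(train_df, test_df, l1_ratio=1.0, penalizers=[0.005, 0.05], metrics=['GAUC']) 211 |     print(f'Best penalizer per event by Global-AUC: {best}')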
-------------------------------------------------------------------------------- /src/pydts/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Optional 2 | 3 | import pandas as pd 4 | import numpy as np 5 | from scipy.special import expit 6 | 7 | 8 | # def get_expanded_df(df, event_type_col='J', duration_col='X', pid_col='pid'): 9 | # """ 10 | # This function gets a dataframe describing, for each sample, the time of the observed event, 11 | # and returns an expanded dataframe as explained in [1]-[2]. 12 | # Right censoring is allowed and must be marked as event type 0. 13 | # 14 | # :param df: original dataframe (pd.DataFrame) 15 | # :param event_type_col: event type column name (str) 16 | # :param duration_col: time column name (str) 17 | # :param pid_col: patient id column name (str) 18 | # 19 | # :return: result_df: expanded dataframe 20 | # 21 | # References: 22 | # [1] Meir, Tomer and Gorfine, Malka, "Discrete-time Competing-Risks Regression with or without Penalization", https://arxiv.org/abs/2303.01186 23 | # [2] Meir, Tomer and Gutman, Rom and Gorfine, Malka "PyDTS: A Python Package for Discrete-Time Survival (Regularized) Regression with Competing Risks", https://arxiv.org/abs/2204.05731 24 | # """ 25 | 26 | def get_expanded_df( 27 | df: pd.DataFrame, 28 | event_type_col: str = 'J', 29 | duration_col: str = 'X', 30 | pid_col: str = 'pid') -> pd.DataFrame: 31 | """ 32 | Expands a discrete-time survival dataset into a long-format dataframe suitable for modeling. This function receives a dataframe where each row corresponds to a subject with observed event type and duration. It returns an expanded dataframe where each subject is represented by multiple rows, one for each time point up to their observed time. Right censoring is allowed and should be indicated by event type 0. 33 | 34 | Args: 35 | df (pd.DataFrame): Original input dataframe containing one row per subject. 36 | event_type_col (str): Name of the column indicating event type. Censoring is marked by 0. 37 | duration_col (str): Name of the column indicating event or censoring time. 38 | pid_col (str): Name of the column indicating subject/patient ID. 39 | 40 | Returns: 41 | pd.DataFrame: Expanded dataframe in long format, with one row per subject-time pair. 42 | """ 43 | unique_times = df[duration_col].sort_values().unique() 44 | result_df = df.reindex(df.index.repeat(df[duration_col])) 45 | result_df[duration_col] = result_df.groupby(pid_col).cumcount() + 1 46 | # drop times that didn't happen 47 | result_df.drop(index=result_df.loc[~result_df[duration_col].isin(unique_times)].index, inplace=True) 48 | result_df.reset_index(drop=True, inplace=True) 49 | last_idx = result_df.drop_duplicates(subset=[pid_col], keep='last').index 50 | events = sorted(df[event_type_col].unique()) 51 | result_df.loc[last_idx, [f'j_{e}' for e in events]] = pd.get_dummies( 52 | result_df.loc[last_idx, event_type_col]).values 53 | result_df[[f'j_{e}' for e in events]] = result_df[[f'j_{e}' for e in events]].fillna(0) 54 | result_df[f'j_0'] = 1 - result_df[[f'j_{e}' for e in events if e > 0]].sum(axis=1) 55 | return result_df 56 | 57 | 58 | def compare_models_coef_per_event(first_model: pd.Series, 59 | second_model: pd.Series, 60 | real_values: np.ndarray, 61 | event: int, 62 | first_model_label: str = "first", 63 | second_model_label: str = "second" 64 | ) -> pd.DataFrame: 65 | event_suffix = f"_{event}" 66 | assert (first_model.index == second_model.index).all(), "first_model and second_model must share the same index" 67 | models = pd.concat([first_model.to_frame(first_model_label), 68 | second_model.to_frame(second_model_label)], axis=1) 69 | models.index += event_suffix 70 | real_values_s = pd.Series(real_values, index=models.index) 71 | 72 | return pd.concat([models, real_values_s.to_frame("real")], axis=1) 73 | 74 | 75 | def present_coefs(res_dict): 76 | from IPython.display import display 77 | for coef_type, events_dict in res_dict.items(): 78 | print(f"for coef: {coef_type.capitalize()}") 79 | df = pd.concat([temp_df for temp_df in events_dict.values()]) 80 | display(df) 81 | 82 | 83 | def get_real_hazard(df, real_coef_dict, times,
events):  # adds the true hazard columns implied by real_coef_dict: expit(alpha_j(t) + Z @ beta_j) 84 | a_t = {event: {t: real_coef_dict['alpha'][event](t) for t in times} for event in events} 85 | b = pd.concat([df.dot(real_coef_dict['beta'][j]) for j in events], axis=1, keys=events) 86 | 87 | for j in events: 88 | df[[f'hazard_j{j}_t{t}' for t in times]] = pd.concat([expit(a_t[j][t] + b[j]) for t in times], 89 | axis=1).values 90 | return df 91 | 92 | 93 | def assert_fit(event_df, times, event_type_col='J', duration_col='X'): 94 | if not event_df['success'].all(): 95 | problematic_times = event_df.loc[~event_df['success'], duration_col].tolist() 96 | event = event_df[event_type_col].max() # all the events in the dataframe are the same 97 | raise RuntimeError(f"Event {event}: the number of observed events at time points {problematic_times} is too small. Consider collapsing neighboring time points." 98 | f"\n See https://tomer1812.github.io/pydts/UsageExample-RegroupingData/ for more details.") 99 | if not np.all([t in event_df[duration_col].values for t in times]): 100 | event = event_df[event_type_col].max() # all the events in the dataframe are the same 101 | problematic_times = pd.Index(event_df[duration_col]).symmetric_difference(times).tolist() 102 | raise RuntimeError(f"Event {event}: the number of observed events at time points {problematic_times} is too small. Consider collapsing neighboring time points." 103 | f"\n See https://tomer1812.github.io/pydts/UsageExample-RegroupingData/ for more details.") 104 | 105 | 106 | def create_df_for_cif_plots(df: pd.DataFrame, field: str, 107 | covariates: Iterable, 108 | vals: Optional[Iterable] = None, 109 | quantiles: Optional[Iterable] = None, 110 | zero_others: Optional[bool] = False 111 | ) -> pd.DataFrame: 112 | """ 113 | This method creates a dataframe for CIF plotting, with one record per value of the chosen field. 114 | 115 | Args: 116 | df (pd.DataFrame): Dataframe from which the statistical properties (means, quantiles, etc.) and structure are taken 117 | field (str): The field whose values are varied across the records 118 | covariates (Iterable): The covariates of the given model 119 | vals (Optional[Iterable]): The values to use for the field 120 | quantiles (Optional[Iterable]): The quantiles of the field to use as its values 121 | zero_others (bool): Whether to set the other covariates to zero (True) or to their mean values (False) 122 | 123 | Returns: 124 | df (pd.DataFrame): A dataframe that contains one record per value, for CIF plotting 125 | """ 126 | 127 | cov_not_fitted = [cov for cov in covariates if cov not in df.columns] 128 | assert len(cov_not_fitted) == 0, \ 129 | f"Required covariates are missing from df: {cov_not_fitted}" 130 | 131 | df_for_plotting = df.copy() 132 | if vals is not None: 133 | pass 134 | elif quantiles is not None: 135 | vals = df_for_plotting[field].quantile(quantiles).values 136 | else: 137 | raise NotImplementedError("Only quantiles or specific values are supported") 138 | temp_series = [] 139 | template_s = df_for_plotting.iloc[0][covariates].copy() 140 | if zero_others: 141 | impute_val = 0 142 | else: 143 | impute_val = df_for_plotting[covariates].mean().values 144 | for val in vals: 145 | temp_s = template_s.copy() 146 | temp_s[covariates] = impute_val 147 | temp_s[field] = val 148 | temp_series.append(temp_s) 149 | 150 | return pd.concat(temp_series, axis=1).T
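151 | 152 | 153 | if __name__ == '__main__': 154 |     # Minimal sketch of get_expanded_df on a toy dataset (illustrative values only). 155 |     # Each subject is repeated once per observed time point up to its own duration, 156 |     # and the event indicator columns j_0, j_1, j_2 are set on each subject's last row. 157 |     toy = pd.DataFrame({'pid': [1, 2, 3], 'X': [1, 2, 2], 'J': [1, 2, 0]}) 158 |     print(get_expanded_df(toy))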
-------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomer1812/pydts/578929a457c111efe009d3461aab531b793b33d0/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_DataExpansionFitter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from src.pydts.examples_utils.generate_simulations_data import generate_quick_start_df 3 | from src.pydts.fitters import DataExpansionFitter 4 | import numpy as np 5 | from src.pydts.utils import get_real_hazard 6 | 7 | 8 | class TestDataExpansionFitter(unittest.TestCase): 9 | def setUp(self): 10 | self.real_coef_dict = { 11 | "alpha": { 12 | 1: lambda t: -1 - 0.3 * np.log(t), 13 | 2: lambda t: -1.75 - 0.15 * np.log(t) 14 | }, 15 | "beta": { 16 | 1: -np.log([0.8, 3, 3, 2.5, 2]), 17 | 2: -np.log([1, 3, 4, 3, 2]) 18 | } 19 | } 20 | self.df = generate_quick_start_df(n_patients=1000, n_cov=5, d_times=10, j_events=2, pid_col='pid', seed=0, 21 | real_coef_dict=self.real_coef_dict, censoring_prob=0.8) 22 | self.m = DataExpansionFitter() 23 | self.fitted_model = DataExpansionFitter() 24 | self.fitted_model.fit(df=self.df.drop(['C', 'T'], axis=1)) 25 | 26 | def test_fit_case_C_in_df(self): 27 | # 'C' named column cannot be passed in df to .fit() 28 | with self.assertRaises(ValueError): 29 | m = DataExpansionFitter() 30 | m.fit(df=self.df) 31 | 32 | def test_fit_case_event_col_not_in_df(self): 33 | # Event column (here 'J') must be passed in df to .fit() 34 | with self.assertRaises(AssertionError): 35 | self.m.fit(df=self.df.drop(['C', 'J', 'T'], axis=1)) 36 | 37 | def test_fit_case_duration_col_not_in_df(self): 38 | # Duration column (here 'X') must be passed in df to .fit() 39 | with self.assertRaises(AssertionError): 40 | self.m.fit(df=self.df.drop(['C', 'X', 'T'], axis=1)) 41 | 42 | def test_fit_case_pid_col_not_in_df(self): 43 | # Patient id column (here 'pid') must be passed in df to .fit() 44 | with self.assertRaises(AssertionError): 45 | self.m.fit(df=self.df.drop(['C', 'pid', 'T'], axis=1)) 46 | 47 | def test_fit_case_cov_col_not_in_df(self): 48 | # Covariates columns (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .fit() 49 | with self.assertRaises(ValueError): 50 | self.m.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=['Z6']) 51 | 52 | def test_fit_case_correct_fit(self): 53 | # Fit should be successful 54 | m = DataExpansionFitter() 55 | m.fit(df=self.df.drop(['C', 'T'], axis=1)) 56 | 57 | def test_fit_with_kwargs(self): 58 | import statsmodels.api as sm 59 | m = DataExpansionFitter() 60 | m.fit(df=self.df.drop(columns=['C', 'T']), models_kwargs=dict(family=sm.families.Binomial())) 61 | 62 | def test_print_summary(self): 63 | self.fitted_model.print_summary() 64 | 65 | def test_get_beta_SE(self): 66 | self.fitted_model.get_beta_SE() 67 | 68 | def test_get_alpha_df(self): 69 | self.fitted_model.get_alpha_df() 70 | 71 | def test_predict_hazard_jt_case_covariate_not_in_df(self): 72 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict() 73 | with self.assertRaises(AssertionError): 74 | self.fitted_model.predict_hazard_jt( 75 | df=self.df.drop(['C', 'T', 'Z1'], axis=1), 76 | event=self.fitted_model.events[0], 77 | t=self.fitted_model.times[0]) 78 | 79 | def test_predict_hazard_jt_case_event_not_in_events(self): 80 | # Event passed to .predict() must be in fitted events 81 | with self.assertRaises(AssertionError): 82 | self.fitted_model.predict_hazard_jt( 83 | df=self.df.drop(['C', 'T'], axis=1), event=100, t=self.fitted_model.times[0]) 84 | 85 | def test_predict_hazard_jt_case_time_not_in_times(self): 86 | # Time passed to .predict() must be in fitted times 87 | with
self.assertRaises(AssertionError): 88 | self.fitted_model.predict_hazard_jt( 89 | df=self.df.drop(['C', 'T'], axis=1), event=self.fitted_model.events[0], t=1000) 90 | 91 | def test_predict_hazard_jt_case_successful_predict(self): 92 | self.fitted_model.predict_hazard_jt( 93 | df=self.df.drop(['C', 'T'], axis=1), 94 | event=self.fitted_model.events[0], t=self.fitted_model.times[0]) 95 | 96 | def test_predict_hazard_t_case_successful_predict(self): 97 | self.fitted_model.predict_hazard_t(df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[:3]) 98 | 99 | def test_predict_hazard_all_case_successful_predict(self): 100 | self.fitted_model.predict_hazard_all(df=self.df.drop(['C', 'T'], axis=1)) 101 | 102 | def test_predict_overall_survival_case_successful_predict(self): 103 | self.fitted_model.predict_overall_survival( 104 | df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[5], return_hazards=True) 105 | 106 | def test_predict_prob_event_j_at_t_case_successful_predict(self): 107 | self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1), 108 | event=self.fitted_model.events[0], 109 | t=self.fitted_model.times[3]) 110 | 111 | def test_predict_prob_event_j_all_case_successful_predict(self): 112 | self.fitted_model.predict_prob_event_j_all(df=self.df.drop(['C', 'T'], axis=1), 113 | event=self.fitted_model.events[0]) 114 | 115 | def test_predict_prob_events_case_successful_predict(self): 116 | self.fitted_model.predict_prob_events(df=self.df.drop(['C', 'T'], axis=1)) 117 | 118 | def test_predict_event_cumulative_incident_function_case_successful_predict(self): 119 | self.fitted_model.predict_event_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1), 120 | event=self.fitted_model.events[0]) 121 | 122 | def test_predict_cumulative_incident_function_case_successful_predict(self): 123 | self.fitted_model.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1)) 124 | 125 | def test_predict_hazard_jt_case_hazard_already_on_df(self): 126 | df_temp = get_real_hazard(self.df.drop(['C', 'T', 'X', 'J'], axis=1).set_index('pid').copy(), 127 | real_coef_dict=self.real_coef_dict, 128 | times=self.fitted_model.times, 129 | events=self.fitted_model.events) 130 | assert (df_temp == self.fitted_model.predict_hazard_jt(df=df_temp, 131 | event=self.fitted_model.events[0], 132 | t=self.fitted_model.times 133 | ) 134 | ).all().all() -------------------------------------------------------------------------------- /tests/test_EventTimesSampler.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from src.pydts.data_generation import EventTimesSampler 3 | from src.pydts.fitters import TwoStagesFitter, DataExpansionFitter 4 | from time import time 5 | import numpy as np 6 | import pandas as pd 7 | slicer = pd.IndexSlice 8 | COEF_COL = ' coef ' 9 | STDERR_COL = ' std err ' 10 | 11 | 12 | class TestEventTimesSampler(unittest.TestCase): 13 | 14 | def test_sample_2_events_5_covariates(self): 15 | n_cov = 5 16 | n_patients = 1000 17 | seed = 0 18 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 19 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]), 20 | columns=covariates) 21 | 22 | ets = EventTimesSampler(d_times=15, j_event_types=2) 23 | real_coef_dict = { 24 | "alpha": { 25 | 1: lambda t: -1 - 0.3 * np.log(t), 26 | 2: lambda t: -1.75 - 0.15 * np.log(t), 27 | }, 28 | "beta": { 29 | 1: -np.log([0.8, 3, 3, 2.5, 2]), 30 | 2: -np.log([1, 3, 4, 3, 2]), 31 
| } 32 | } 33 | ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 34 | 35 | def test_sample_3_events_4_covariates(self): 36 | n_cov = 4 37 | n_patients = 1000 38 | seed = 0 39 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 40 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]), 41 | columns=covariates) 42 | 43 | ets = EventTimesSampler(d_times=15, j_event_types=3) 44 | real_coef_dict = { 45 | "alpha": { 46 | 1: lambda t: -1 - 0.3 * np.log(t), 47 | 2: lambda t: -1.75 - 0.15 * np.log(t), 48 | 3: lambda t: -1.5 - 0.2 * np.log(t), 49 | }, 50 | "beta": { 51 | 1: -np.log([0.8, 3, 3, 2.5]), 52 | 2: -np.log([1, 3, 4, 3]), 53 | 3: -np.log([0.5, 2, 3, 4]), 54 | } 55 | } 56 | ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 57 | 58 | def test_sample_hazard_censoring(self): 59 | seed = 0 60 | n_cov = 5 61 | n_patients = 10000 62 | np.random.seed(seed) 63 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 64 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]), 65 | columns=covariates) 66 | 67 | ets = EventTimesSampler(d_times=15, j_event_types=4) 68 | censoring_coef_dict = { 69 | "alpha": { 70 | 0: lambda t: -1 - 0.3 * np.log(t), 71 | }, 72 | "beta": { 73 | 0: -np.log([0.8, 3, 3, 2.5, 2]), 74 | } 75 | } 76 | ets.sample_hazard_lof_censoring(patients_df, censoring_coef_dict, seed) 77 | 78 | def test_sample_independent_censoring(self): 79 | n_cov = 4 80 | n_patients = 1000 81 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 82 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]), 83 | columns=covariates) 84 | 85 | d_times = 15 86 | ets = EventTimesSampler(d_times=d_times, j_event_types=2) 87 | ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.03 * np.ones(d_times)) 88 | 89 | def test_update_event_or_lof(self): 90 | n_cov = 5 91 | n_patients = 1000 92 | seed = 0 93 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 94 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]), 95 | columns=covariates) 96 | 97 | ets = EventTimesSampler(d_times=15, j_event_types=3) 98 | real_coef_dict = { 99 | "alpha": { 100 | 1: lambda t: -1 - 0.3 * np.log(t), 101 | 2: lambda t: -1.75 - 0.15 * np.log(t), 102 | 3: lambda t: -1.75 - 0.15 * np.log(t), 103 | }, 104 | "beta": { 105 | 1: -np.log([0.8, 3, 3, 2.5, 2]), 106 | 2: -np.log([1, 3, 4, 3, 2]), 107 | 3: -np.log([1, 3, 4, 3, 2]), 108 | } 109 | } 110 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 111 | censoring_coef_dict = { 112 | "alpha": { 113 | 0: lambda t: -1 - 0.3 * np.log(t), 114 | }, 115 | "beta": { 116 | 0: -np.log([0.8, 3, 3, 2.5, 2]), 117 | } 118 | } 119 | 120 | patients_df = ets.sample_hazard_lof_censoring(patients_df, censoring_coef_dict, seed=seed) 121 | patients_df = ets.update_event_or_lof(patients_df) 122 | 123 | def test_update_event_or_lof_T_assertion(self): 124 | with self.assertRaises(AssertionError): 125 | seed = 0 126 | n_cov = 5 127 | n_patients = 10000 128 | np.random.seed(seed) 129 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 130 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]), 131 | columns=covariates) 132 | 133 | ets = EventTimesSampler(d_times=15, j_event_types=3) 134 | censoring_coef_dict = { 135 | "alpha": { 136 | 0: lambda t: -1 - 0.3 * np.log(t), 137 | }, 138 | "beta": { 139 | 0: -np.log([0.8, 3, 3, 2.5,
2]), 140 | } 141 | } 142 | 143 | patients_df = ets.sample_hazard_lof_censoring(patients_df, censoring_coef_dict, seed=seed) 144 | patients_df = ets.update_event_or_lof(patients_df) 145 | 146 | def test_update_event_or_lof_C_assertion(self): 147 | with self.assertRaises(AssertionError): 148 | n_cov = 5 149 | n_patients = 1000 150 | seed = 0 151 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 152 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]), 153 | columns=covariates) 154 | 155 | ets = EventTimesSampler(d_times=15, j_event_types=3) 156 | real_coef_dict = { 157 | "alpha": { 158 | 1: lambda t: -1 - 0.3 * np.log(t), 159 | 2: lambda t: -1.75 - 0.15 * np.log(t), 160 | 3: lambda t: -1.75 - 0.15 * np.log(t), 161 | }, 162 | "beta": { 163 | 1: -np.log([0.8, 3, 3, 2.5, 2]), 164 | 2: -np.log([1, 3, 4, 3, 2]), 165 | 3: -np.log([1, 3, 4, 3, 2]), 166 | } 167 | } 168 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 169 | patients_df = ets.update_event_or_lof(patients_df) 170 | 171 | # def test_sample_and_fit_from_multinormal(self): 172 | # # real_coef_dict = { 173 | # # "alpha": { 174 | # # 1: lambda t: -3 - 3 * np.log(t), 175 | # # 2: lambda t: -3 - 0.15 * np.log(t) 176 | # # }, 177 | # # "beta": { 178 | # # # 1: [0.3, -1, -0.5, 0.7, 0.9], 179 | # # # 2: [-0.5, 0.5, 0.7, -0.7, 0.5] 180 | # # 1: -np.log([0.8, 3, 3, 2.5, 2]), 181 | # # 2: -np.log([1, 3, 4, 3, 2]) 182 | # # } 183 | # # } 184 | # 185 | # real_coef_dict = { 186 | # "alpha": { 187 | # 1: lambda t: -1.75 - 0.3 * np.log(t), 188 | # 2: lambda t: -1.5 - 0.15 * np.log(t) 189 | # }, 190 | # "beta": { 191 | # 1: -0.5*np.log([0.8, 3, 3, 2.5, 2]), 192 | # 2: -0.5*np.log([1, 3, 4, 3, 2]) 193 | # } 194 | # } 195 | # 196 | # censoring_hazard_coef_dict = { 197 | # "alpha": { 198 | # 0: lambda t: -1.75 - 0.3 * np.log(t), 199 | # }, 200 | # "beta": { 201 | # 0: -0.5*np.log([1, 3, 4, 3, 2]), 202 | # } 203 | # } 204 | # 205 | # n_patients = 15000 206 | # n_cov = 5 207 | # d_times = 50 208 | # j_events = 2 209 | # clip_value = 1 210 | # means_vector = np.zeros(n_cov) 211 | # covariance_matrix = 0.4*np.identity(n_cov) 212 | # 213 | # ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 214 | # seed = 0 215 | # covariates = [f'Z{i + 1}' for i in range(n_cov)] 216 | # 217 | # np.random.seed(seed) 218 | # 219 | # patients_df = pd.DataFrame(data=pd.DataFrame( 220 | # data=np.random.multivariate_normal(means_vector, covariance_matrix, size=n_patients), 221 | # columns=covariates)) 222 | # # patients_df = pd.DataFrame(data=pd.DataFrame( 223 | # # data=np.random.uniform(0, 1, size=[n_patients, n_cov]), 224 | # # columns=covariates)) 225 | # 226 | # patients_df.clip(lower=-1*clip_value, upper=clip_value, inplace=True) 227 | # patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 228 | # patients_df = ets.sample_hazard_lof_censoring(patients_df, 229 | # censoring_hazard_coefs=censoring_hazard_coef_dict, 230 | # seed=seed + 1, events=[0]) 231 | # patients_df = ets.update_event_or_lof(patients_df) 232 | # patients_df.index.name = 'pid' 233 | # patients_df = patients_df.reset_index() 234 | # 235 | # # Two step fitter 236 | # new_fitter = TwoStagesFitter() 237 | # two_step_start = time() 238 | # new_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), nb_workers=1) 239 | # two_step_end = time() 240 | # 241 | # # Lee et al fitter 242 | # lee_fitter = DataExpansionFitter() 243 | # lee_start = time() 244 | # 
lee_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1)) 245 | # lee_end = time() 246 | # lee_alpha_results = lee_fitter.get_alpha_df().loc[:, 247 | # slicer[:, [COEF_COL, STDERR_COL]]].unstack().to_frame() 248 | # lee_beta_results = lee_fitter.get_beta_SE().loc[:, slicer[:, [COEF_COL, STDERR_COL]]].unstack().to_frame() 249 | # 250 | # # Save results only if both fitters were successful 251 | # two_step_fit_time = two_step_end - two_step_start 252 | # lee_fit_time = lee_end - lee_start 253 | # 254 | # two_step_alpha_results = new_fitter.alpha_df[['J', 'X', 'alpha_jt']].set_index(['J', 'X']) 255 | # two_step_beta_results = new_fitter.get_beta_SE().unstack().to_frame() 256 | # print('x') 257 | 258 | # def test_sample_and_fit_normal(self): 259 | # 260 | # real_coef_dict = { 261 | # "alpha": { 262 | # 1: lambda t: -3 - 0.3 * np.log(t), 263 | # 2: lambda t: -3 - 0.15 * np.log(t) 264 | # }, 265 | # "beta": { 266 | # 1: -np.log([0.8, 3, 3, 2.5, 2]), 267 | # 2: -np.log([1, 3, 4, 3, 2]) 268 | # } 269 | # } 270 | # 271 | # n_patients = 10000 272 | # n_cov = 5 273 | # d_times = 50 274 | # j_events = 2 275 | # 276 | # ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 277 | # seed = 0 278 | # covariates = [f'Z{i + 1}' for i in range(n_cov)] 279 | # 280 | # np.random.seed(seed) 281 | # 282 | # patients_df = pd.DataFrame(data=pd.DataFrame( 283 | # data=np.random.uniform(0, 1, size=[n_patients, n_cov]), 284 | # columns=covariates)) 285 | # patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 286 | # patients_df['X'] = patients_df['T'] 287 | # patients_df['C'] = 51 288 | # patients_df.index.name = 'pid' 289 | # patients_df = patients_df.reset_index() 290 | # 291 | # # Two step fitter 292 | # new_fitter = TwoStagesFitter() 293 | # two_step_start = time() 294 | # new_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), nb_workers=1) # , x0=-3 295 | # two_step_end = time() 296 | # 297 | # # Lee et al fitter 298 | # lee_fitter = DataExpansionFitter() 299 | # lee_start = time() 300 | # lee_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1)) 301 | # lee_end = time() 302 | # lee_alpha_results = lee_fitter.get_alpha_df().loc[:, 303 | # slicer[:, [COEF_COL, STDERR_COL]]].unstack().to_frame() 304 | # lee_beta_results = lee_fitter.get_beta_SE().loc[:, slicer[:, [COEF_COL, STDERR_COL]]].unstack().to_frame() 305 | # 306 | # # Save results only if both fitters were successful 307 | # two_step_fit_time = two_step_end - two_step_start 308 | # lee_fit_time = lee_end - lee_start 309 | # 310 | # two_step_alpha_results = new_fitter.alpha_df[['J', 'X', 'alpha_jt']].set_index(['J', 'X']) 311 | # two_step_beta_results = new_fitter.get_beta_SE().unstack().to_frame() 312 | # print('x') 313 | 314 | 315 | def test_raise_negative_values_overall_survival_assertion(self): 316 | with self.assertRaises(ValueError): 317 | real_coef_dict = { 318 | "alpha": { 319 | 1: lambda t: -9 + 3 * np.log(t), 320 | 2: lambda t: -7 + 2.5 * np.log(t) 321 | }, 322 | "beta": { 323 | 1: [1.3, 1.7, -1.5, 0.5, 1.6], 324 | 2: [-1.5, 1.5, 1.8, -1, 1.2] 325 | } 326 | } 327 | 328 | censoring_hazard_coef_dict = { 329 | "alpha": { 330 | 0: lambda t: -8 + 2.1 * np.log(t), 331 | }, 332 | "beta": { 333 | 0: [2, 1, -1.5, 1.5, -1.3], 334 | } 335 | } 336 | 337 | n_patients = 25000 338 | n_cov = 5 339 | d_times = 12 340 | j_events = 2 341 | means_vector = np.zeros(n_cov) 342 | covariance_matrix = np.identity(n_cov) 343 | 344 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 345 | seed = 0 346 | covariates 
= [f'Z{i + 1}' for i in range(n_cov)] 347 | 348 | np.random.seed(seed) 349 | 350 | patients_df = pd.DataFrame(data=pd.DataFrame( 351 | data=np.random.multivariate_normal(means_vector, covariance_matrix, size=n_patients), 352 | columns=covariates)) 353 | 354 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 355 | 356 | -------------------------------------------------------------------------------- /tests/test_TwoStagesFitter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import matplotlib.pyplot as plt 4 | from unittest.mock import patch 5 | from src.pydts.examples_utils.generate_simulations_data import generate_quick_start_df 6 | from src.pydts.fitters import TwoStagesFitter 7 | from src.pydts.utils import get_real_hazard 8 | import numpy as np 9 | 10 | 11 | class TestTwoStagesFitter(unittest.TestCase): 12 | 13 | def setUp(self): 14 | self.real_coef_dict = { 15 | "alpha": { 16 | 1: lambda t: -1 - 0.3 * np.log(t), 17 | 2: lambda t: -1.75 - 0.15 * np.log(t) 18 | }, 19 | "beta": { 20 | 1: -np.log([0.8, 3, 3, 2.5, 2]), 21 | 2: -np.log([1, 3, 4, 3, 2]) 22 | } 23 | } 24 | self.df = generate_quick_start_df(n_patients=5000, n_cov=5, d_times=10, j_events=2, pid_col='pid', seed=0, 25 | real_coef_dict=self.real_coef_dict, censoring_prob=.8) 26 | self.m = TwoStagesFitter() 27 | self.fitted_model = TwoStagesFitter() 28 | self.fitted_model.fit(df=self.df.drop(['C', 'T'], axis=1)) 29 | 30 | def test_fit_case_event_col_not_in_df(self): 31 | # Event column (here 'J') must be passed in df to .fit() 32 | with self.assertRaises(AssertionError): 33 | self.m.fit(df=self.df.drop(['C', 'J', 'T'], axis=1)) 34 | 35 | def test_fit_case_duration_col_not_in_df(self): 36 | # Duration column (here 'X') must be passed in df to .fit() 37 | with self.assertRaises(AssertionError): 38 | self.m.fit(df=self.df.drop(['C', 'X', 'T'], axis=1)) 39 | 40 | def test_fit_case_pid_col_not_in_df(self): 41 | # Duration column (here 'pid') must be passed in df to .fit() 42 | with self.assertRaises(AssertionError): 43 | self.m.fit(df=self.df.drop(['C', 'pid', 'T'], axis=1)) 44 | 45 | def test_fit_case_cov_col_not_in_df(self): 46 | # Covariates columns (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .fit() 47 | with self.assertRaises(ValueError): 48 | self.m.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=['Z6']) 49 | 50 | def test_fit_case_missing_jt(self): 51 | # df to .fit() should contain observations to all (event, times) 52 | 53 | # drop (events[1], times[2]) 54 | tmp_df = self.df[ 55 | ~((self.df[self.fitted_model.duration_col] == self.fitted_model.times[2]) & 56 | (self.df[self.fitted_model.event_type_col] == self.fitted_model.events[1])) 57 | ] 58 | 59 | with self.assertRaises(RuntimeError): 60 | self.m.fit(df=tmp_df.drop(['C', 'T'], axis=1)) 61 | 62 | def test_fit_case_correct_fit(self): 63 | # Fit should be successful 64 | m = TwoStagesFitter() 65 | m.fit(df=self.df.drop(['C', 'T'], axis=1)) 66 | 67 | def test_print_summary(self): 68 | self.fitted_model.print_summary() 69 | 70 | def test_plot_event_alpha_case_correct_event(self): 71 | self.fitted_model.plot_event_alpha(event=self.fitted_model.events[0], show=False) 72 | 73 | def test_plot_event_alpha_case_correct_event_and_show(self): 74 | with patch('matplotlib.pyplot.show') as p: 75 | self.fitted_model.plot_event_alpha(event=self.fitted_model.events[0], show=True) 76 | 77 | def test_plot_event_alpha_case_incorrect_event(self): 78 | with 
self.assertRaises(AssertionError): 79 | self.fitted_model.plot_event_alpha(event=100, show=False) 80 | 81 | def test_plot_all_events_alpha(self): 82 | self.fitted_model.plot_all_events_alpha(show=False) 83 | 84 | def test_plot_all_events_alpha_case_show(self): 85 | with patch('matplotlib.pyplot.show') as p: 86 | self.fitted_model.plot_all_events_alpha(show=True) 87 | 88 | def test_plot_all_events_alpha_case_axis_exits(self): 89 | fig, ax = plt.subplots() 90 | self.fitted_model.plot_all_events_alpha(show=False, ax=ax) 91 | 92 | def test_get_beta_SE(self): 93 | self.fitted_model.get_beta_SE() 94 | 95 | def test_get_alpha_df(self): 96 | self.fitted_model.get_alpha_df() 97 | 98 | def test_plot_all_events_beta(self): 99 | self.fitted_model.plot_all_events_beta(show=False) 100 | 101 | def test_plot_all_events_beta_case_show(self): 102 | with patch('matplotlib.pyplot.show') as p: 103 | self.fitted_model.plot_all_events_beta(show=True) 104 | 105 | def test_plot_all_events_beta_case_ax_exists(self): 106 | fig, ax = plt.subplots() 107 | self.fitted_model.plot_all_events_beta(show=False, ax=ax) 108 | 109 | def test_predict_hazard_jt_case_covariate_not_in_df(self): 110 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict() 111 | with self.assertRaises(AssertionError): 112 | self.fitted_model.predict_hazard_jt( 113 | df=self.df.drop(['C', 'T', 'Z1'], axis=1), 114 | event=self.fitted_model.events[0], 115 | t=self.fitted_model.times[0]) 116 | 117 | def test_predict_hazard_jt_case_hazard_already_on_df(self): 118 | # print(self.df.drop(['C', 'T', 'X', 'J'], axis=1).set_index('pid').copy()) 119 | df_temp = get_real_hazard(self.df.drop(['C', 'T', 'X', 'J'], axis=1).set_index('pid').copy(), 120 | real_coef_dict=self.real_coef_dict, 121 | times=self.fitted_model.times, 122 | events=self.fitted_model.events) 123 | assert (df_temp == self.fitted_model.predict_hazard_jt(df=df_temp, 124 | event=self.fitted_model.events[0], 125 | t=self.fitted_model.times 126 | ) 127 | ).all().all() 128 | 129 | def test_hazard_transformation_result(self): 130 | from scipy.special import logit 131 | num = np.array([0.5]) 132 | a = logit(num) 133 | print(a) 134 | assert (a == self.fitted_model._hazard_transformation(num)).all() 135 | 136 | def test_predict_hazard_jt_case_event_not_in_events(self): 137 | # Event passed to .predict() must be in fitted events 138 | with self.assertRaises(AssertionError): 139 | self.fitted_model.predict_hazard_jt( 140 | df=self.df.drop(['C', 'T'], axis=1), event=100, t=self.fitted_model.times[0]) 141 | 142 | def test_predict_hazard_jt_case_time_not_in_times(self): 143 | # Event passed to .predict() must be in fitted events 144 | with self.assertRaises(AssertionError): 145 | self.fitted_model.predict_hazard_jt( 146 | df=self.df.drop(['C', 'T'], axis=1), event=self.fitted_model.events[0], t=1000) 147 | 148 | def test_predict_hazard_jt_case_successful_predict(self): 149 | self.fitted_model.predict_hazard_jt( 150 | df=self.df.drop(['C', 'T'], axis=1), 151 | event=self.fitted_model.events[0], t=self.fitted_model.times[0]) 152 | 153 | def test_predict_hazard_t_case_successful_predict(self): 154 | self.fitted_model.predict_hazard_t(df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[:3]) 155 | 156 | def test_predict_hazard_all_case_successful_predict(self): 157 | self.fitted_model.predict_hazard_all(df=self.df.drop(['C', 'T'], axis=1)) 158 | 159 | def test_predict_overall_survival_case_successful_predict(self): 160 | 
self.fitted_model.predict_overall_survival( 161 | df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[5], return_hazards=True) 162 | 163 | def test_predict_prob_event_j_at_t_case_successful_predict(self): 164 | self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1), 165 | event=self.fitted_model.events[0], 166 | t=self.fitted_model.times[3]) 167 | 168 | def test_predict_prob_event_j_all_case_successful_predict(self): 169 | self.fitted_model.predict_prob_event_j_all(df=self.df.drop(['C', 'T'], axis=1), 170 | event=self.fitted_model.events[0]) 171 | 172 | def test_predict_prob_events_case_successful_predict(self): 173 | self.fitted_model.predict_prob_events(df=self.df.drop(['C', 'T'], axis=1)) 174 | 175 | def test_predict_event_cumulative_incident_function_case_successful_predict(self): 176 | self.fitted_model.predict_event_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1), 177 | event=self.fitted_model.events[0]) 178 | 179 | def test_predict_cumulative_incident_function_case_successful_predict(self): 180 | self.fitted_model.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1)) 181 | 182 | def test_predict_marginal_prob_function_case_successful(self): 183 | self.fitted_model.predict_marginal_prob_event_j(df=self.df.drop(columns=['C', 'T']), 184 | event=1) 185 | 186 | def test_predict_marginal_prob_function_case_event_not_exists(self): 187 | # Event passed to .predict_margnial() must be in fitted events 188 | with self.assertRaises(AssertionError): 189 | self.fitted_model.predict_marginal_prob_event_j( 190 | df=self.df.drop(['C', 'T'], axis=1), event=100) 191 | 192 | def test_predict_marginal_prob_function_cov_not_exists(self): 193 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict() 194 | with self.assertRaises(AssertionError): 195 | self.fitted_model.predict_marginal_prob_event_j( 196 | df=self.df.drop(columns=['C', 'T', 'Z1']), 197 | event=self.fitted_model.events[0]) 198 | 199 | def test_predict_marginal_prob_all_events_function_successful(self): 200 | self.fitted_model.predict_marginal_prob_all_events(df=self.df.drop(columns=['C', 'T'])) 201 | 202 | def test_predict_marginal_prob_all_events_cov_not_exits(self): 203 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict() 204 | with self.assertRaises(AssertionError): 205 | self.fitted_model.predict_marginal_prob_all_events( 206 | df=self.df.drop(columns=['C', 'T', 'Z1'])) 207 | 208 | def test_alpha_jt_function_value(self): 209 | t = 1 210 | j = 1 211 | row = self.fitted_model.alpha_df.query("X == @t and J == @j") 212 | x = row['alpha_jt'].item() 213 | y_t = (self.df["X"] 214 | .value_counts() 215 | .sort_index(ascending=False) # each event count for its occurring time and the times before 216 | .cumsum() 217 | .sort_index() 218 | ) 219 | rel_y_t = y_t.loc[t] 220 | rel_beta = self.fitted_model.beta_models[j].params_ 221 | n_jt = row['n_jt'] 222 | df = self.df.drop(columns=['C', 'T']) 223 | partial_df = df[df["X"] >= t] 224 | expit_add = np.dot(partial_df[self.fitted_model.covariates], rel_beta) 225 | from scipy.special import expit 226 | a_jt = ((1 / rel_y_t) * np.sum(expit(x + expit_add)) - (n_jt / rel_y_t)) ** 2 227 | a_jt_from_func = self.fitted_model._alpha_jt(x=x, df=df, 228 | y_t=rel_y_t, beta_j=rel_beta, 229 | n_jt=n_jt, t=t, event=j) 230 | self.assertEqual(a_jt.item(), a_jt_from_func.item()) 231 | 232 | def test_predict_event_jt_case_t1_not_hazard(self): 233 | 
self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1), 234 | event=self.fitted_model.events[0], 235 | t=self.fitted_model.times[0]) 236 | 237 | def test_predict_event_jt_case_t3_not_hazard(self): 238 | temp_df = self.df.drop(['C', 'T'], axis=1) 239 | temp_df = self.fitted_model.predict_overall_survival(df=temp_df, 240 | t=self.fitted_model.times[3], 241 | return_hazards=False) 242 | self.fitted_model.predict_prob_event_j_at_t(df=temp_df, 243 | event=self.fitted_model.events[0], 244 | t=self.fitted_model.times[3]) 245 | 246 | def test_regularization_same_for_all_beta_models(self): 247 | L1_regularized_fitter = TwoStagesFitter() 248 | 249 | fit_beta_kwargs = { 250 | 'model_kwargs': { 251 | 'penalizer': 0.003, 252 | 'l1_ratio': 1 253 | } 254 | } 255 | 256 | L1_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs) 257 | 258 | def test_regularization_different_to_each_beta_model(self): 259 | L1_regularized_fitter = TwoStagesFitter() 260 | 261 | fit_beta_kwargs = { 262 | 'model_kwargs': { 263 | 1: { 264 | 'penalizer': 0.003, 265 | 'l1_ratio': 1 266 | }, 267 | 2: { 268 | 'penalizer': 0.005, 269 | 'l1_ratio': 1 270 | } 271 | } 272 | } 273 | 274 | L1_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs) 275 | 276 | def test_different_covariates_to_each_beta_model(self): 277 | twostages_fitter = TwoStagesFitter() 278 | covariates = { 279 | 1: ['Z1', 'Z2', 'Z3'], 280 | 2: ['Z2', 'Z3', 'Z4', 'Z5'] 281 | } 282 | twostages_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=covariates) 283 | 284 | def test_different_covariates_to_each_beta_model_prediction(self): 285 | twostages_fitter = TwoStagesFitter() 286 | covariates = { 287 | 1: ['Z1', 'Z2', 'Z3'], 288 | 2: ['Z2', 'Z3', 'Z4', 'Z5'] 289 | } 290 | twostages_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=covariates) 291 | twostages_fitter.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1)) 292 | 293 | -------------------------------------------------------------------------------- /tests/test_TwoStagesFitterExact.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import matplotlib.pyplot as plt 4 | from unittest.mock import patch 5 | from src.pydts.examples_utils.generate_simulations_data import generate_quick_start_df 6 | from src.pydts.fitters import TwoStagesFitterExact 7 | from src.pydts.data_generation import EventTimesSampler 8 | from src.pydts.utils import get_real_hazard 9 | import numpy as np 10 | import pandas as pd 11 | 12 | 13 | class TestTwoStagesFitterExact(unittest.TestCase): 14 | 15 | def setUp(self): 16 | self.real_coef_dict = { 17 | "alpha": { 18 | 1: lambda t: -1. + 0.4 * np.log(t), 19 | 2: lambda t: -1. 
+ 0.4 * np.log(t), 20 | }, 21 | "beta": { 22 | 1: -0.4*np.log([0.8, 3, 3, 2.5, 2]), 23 | 2: -0.3*np.log([1, 3, 4, 3, 2]), 24 | } 25 | } 26 | self.df = generate_quick_start_df(n_patients=300, n_cov=5, d_times=4, j_events=2, pid_col='pid', seed=0, 27 | real_coef_dict=self.real_coef_dict, censoring_prob=0.1) 28 | 29 | self.m = TwoStagesFitterExact() 30 | self.fitted_model = TwoStagesFitterExact() 31 | self.fitted_model.fit(df=self.df.drop(['C', 'T'], axis=1)) 32 | 33 | def test_print_beta_SE(self): 34 | # get_beta_SE() of the fitted model should be printable 35 | print(self.fitted_model.get_beta_SE()) 36 | 37 | 38 | def test_fit_case_event_col_not_in_df(self): 39 | # Event column (here 'J') must be passed in df to .fit() 40 | with self.assertRaises(AssertionError): 41 | self.m.fit(df=self.df.drop(['C', 'J', 'T'], axis=1)) 42 | 43 | def test_fit_case_duration_col_not_in_df(self): 44 | # Duration column (here 'X') must be passed in df to .fit() 45 | with self.assertRaises(AssertionError): 46 | self.m.fit(df=self.df.drop(['C', 'X', 'T'], axis=1)) 47 | 48 | def test_fit_case_pid_col_not_in_df(self): 49 | # Patient id column (here 'pid') must be passed in df to .fit() 50 | with self.assertRaises(AssertionError): 51 | self.m.fit(df=self.df.drop(['C', 'pid', 'T'], axis=1)) 52 | 53 | def test_fit_case_cov_col_not_in_df(self): 54 | # Covariates columns (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .fit() 55 | with self.assertRaises(ValueError): 56 | self.m.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=['Z6']) 57 | 58 | def test_fit_case_missing_jt(self): 59 | # df to .fit() should contain observations for all (event, time) pairs 60 | 61 | # drop (events[1], times[2]) 62 | tmp_df = self.df[ 63 | ~((self.df[self.fitted_model.duration_col] == self.fitted_model.times[2]) & 64 | (self.df[self.fitted_model.event_type_col] == self.fitted_model.events[1])) 65 | ] 66 | 67 | with self.assertRaises(RuntimeError): 68 | self.m.fit(df=tmp_df.drop(['C', 'T'], axis=1)) 69 | 70 | def test_fit_case_correct_fit(self): 71 | # Fit should be successful 72 | m = TwoStagesFitterExact() 73 | m.fit(df=self.df.drop(['C', 'T'], axis=1)) 74 | 75 | def test_print_summary(self): 76 | self.fitted_model.print_summary() 77 | 78 | def test_plot_event_alpha_case_correct_event(self): 79 | self.fitted_model.plot_event_alpha(event=self.fitted_model.events[0], show=False) 80 | 81 | def test_plot_event_alpha_case_correct_event_and_show(self): 82 | with patch('matplotlib.pyplot.show') as p: 83 | self.fitted_model.plot_event_alpha(event=self.fitted_model.events[0], show=True) 84 | 85 | def test_plot_event_alpha_case_incorrect_event(self): 86 | with self.assertRaises(AssertionError): 87 | self.fitted_model.plot_event_alpha(event=100, show=False) 88 | 89 | def test_plot_all_events_alpha(self): 90 | self.fitted_model.plot_all_events_alpha(show=False) 91 | 92 | def test_plot_all_events_alpha_case_show(self): 93 | with patch('matplotlib.pyplot.show') as p: 94 | self.fitted_model.plot_all_events_alpha(show=True) 95 | 96 | def test_plot_all_events_alpha_case_axis_exits(self): 97 | fig, ax = plt.subplots() 98 | self.fitted_model.plot_all_events_alpha(show=False, ax=ax) 99 | 100 | def test_get_beta_SE(self): 101 | self.fitted_model.get_beta_SE() 102 | 103 | def test_get_alpha_df(self): 104 | self.fitted_model.get_alpha_df() 105 | 106 | # def test_plot_all_events_beta(self): 107 | # self.fitted_model.plot_all_events_beta(show=False) 108 | # 109 | # def test_plot_all_events_beta_case_show(self): 110 | # with
patch('matplotlib.pyplot.show') as p: 111 | # self.fitted_model.plot_all_events_beta(show=True) 112 | # 113 | # def test_plot_all_events_beta_case_ax_exists(self): 114 | # fig, ax = plt.subplots() 115 | # self.fitted_model.plot_all_events_beta(show=False, ax=ax) 116 | 117 | def test_predict_hazard_jt_case_covariate_not_in_df(self): 118 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict() 119 | with self.assertRaises(AssertionError): 120 | self.fitted_model.predict_hazard_jt( 121 | df=self.df.drop(['C', 'T', 'Z1'], axis=1), 122 | event=self.fitted_model.events[0], 123 | t=self.fitted_model.times[0]) 124 | 125 | def test_predict_hazard_jt_case_hazard_already_on_df(self): 126 | # print(self.df.drop(['C', 'T', 'X', 'J'], axis=1).set_index('pid').copy()) 127 | df_temp = get_real_hazard(self.df.drop(['C', 'T', 'X', 'J'], axis=1).set_index('pid').copy(), 128 | real_coef_dict=self.real_coef_dict, 129 | times=self.fitted_model.times, 130 | events=self.fitted_model.events) 131 | assert (df_temp == self.fitted_model.predict_hazard_jt(df=df_temp, 132 | event=self.fitted_model.events[0], 133 | t=self.fitted_model.times 134 | ) 135 | ).all().all() 136 | 137 | def test_hazard_transformation_result(self): 138 | from scipy.special import logit 139 | num = np.array([0.5]) 140 | a = logit(num) 141 | print(a) 142 | assert (a == self.fitted_model._hazard_transformation(num)).all() 143 | 144 | def test_predict_hazard_jt_case_event_not_in_events(self): 145 | # Event passed to .predict() must be in fitted events 146 | with self.assertRaises(AssertionError): 147 | self.fitted_model.predict_hazard_jt( 148 | df=self.df.drop(['C', 'T'], axis=1), event=100, t=self.fitted_model.times[0]) 149 | 150 | def test_predict_hazard_jt_case_time_not_in_times(self): 151 | # Time passed to .predict() must be in fitted times 152 | with self.assertRaises(AssertionError): 153 | self.fitted_model.predict_hazard_jt( 154 | df=self.df.drop(['C', 'T'], axis=1), event=self.fitted_model.events[0], t=1000) 155 | 156 | def test_predict_hazard_jt_case_successful_predict(self): 157 | self.fitted_model.predict_hazard_jt( 158 | df=self.df.drop(['C', 'T'], axis=1), 159 | event=self.fitted_model.events[0], t=self.fitted_model.times[0]) 160 | 161 | def test_predict_hazard_t_case_successful_predict(self): 162 | self.fitted_model.predict_hazard_t(df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[:3]) 163 | 164 | def test_predict_hazard_all_case_successful_predict(self): 165 | self.fitted_model.predict_hazard_all(df=self.df.drop(['C', 'T'], axis=1)) 166 | 167 | def test_predict_overall_survival_case_successful_predict(self): 168 | self.fitted_model.predict_overall_survival( 169 | df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[3], return_hazards=True) 170 | 171 | def test_predict_prob_event_j_at_t_case_successful_predict(self): 172 | self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1), 173 | event=self.fitted_model.events[0], 174 | t=self.fitted_model.times[3]) 175 | 176 | def test_predict_prob_event_j_all_case_successful_predict(self): 177 | self.fitted_model.predict_prob_event_j_all(df=self.df.drop(['C', 'T'], axis=1), 178 | event=self.fitted_model.events[0]) 179 | 180 | def test_predict_prob_events_case_successful_predict(self): 181 | self.fitted_model.predict_prob_events(df=self.df.drop(['C', 'T'], axis=1)) 182 | 183 | def test_predict_event_cumulative_incident_function_case_successful_predict(self): 184 | 
self.fitted_model.predict_event_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1), 185 | event=self.fitted_model.events[0]) 186 | 187 | def test_predict_cumulative_incident_function_case_successful_predict(self): 188 | self.fitted_model.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1)) 189 | 190 | def test_predict_marginal_prob_function_case_successful(self): 191 | self.fitted_model.predict_marginal_prob_event_j(df=self.df.drop(columns=['C', 'T']), 192 | event=1) 193 | 194 | def test_predict_marginal_prob_function_case_event_not_exists(self): 195 | # Event passed to .predict_marginal_prob_event_j() must be in fitted events 196 | with self.assertRaises(AssertionError): 197 | self.fitted_model.predict_marginal_prob_event_j( 198 | df=self.df.drop(['C', 'T'], axis=1), event=100) 199 | 200 | def test_predict_marginal_prob_function_cov_not_exists(self): 201 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict() 202 | with self.assertRaises(AssertionError): 203 | self.fitted_model.predict_marginal_prob_event_j( 204 | df=self.df.drop(columns=['C', 'T', 'Z1']), 205 | event=self.fitted_model.events[0]) 206 | 207 | def test_predict_marginal_prob_all_events_function_successful(self): 208 | self.fitted_model.predict_marginal_prob_all_events(df=self.df.drop(columns=['C', 'T'])) 209 | 210 | def test_predict_marginal_prob_all_events_cov_not_exists(self): 211 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict() 212 | with self.assertRaises(AssertionError): 213 | self.fitted_model.predict_marginal_prob_all_events( 214 | df=self.df.drop(columns=['C', 'T', 'Z1'])) 215 | 216 | def test_alpha_jt_function_value(self): 217 | t = 1 218 | j = 1 219 | row = self.fitted_model.alpha_df.query("X == @t and J == @j") 220 | x = row['alpha_jt'].item() 221 | y_t = (self.df["X"] 222 | .value_counts() 223 | .sort_index(ascending=False) # reverse-cumulative count: y_t = number of subjects still at risk at time t (X >= t) 224 | .cumsum() 225 | .sort_index() 226 | ) 227 | rel_y_t = y_t.loc[t] 228 | #rel_beta = self.fitted_model.beta_models[j].params_ 229 | rel_beta = getattr(self.fitted_model.beta_models[j], self.fitted_model.beta_models_params_attr) 230 | n_jt = row['n_jt'] 231 | df = self.df.drop(columns=['C', 'T']) 232 | partial_df = df[df["X"] >= t] 233 | expit_add = np.dot(partial_df[self.fitted_model.covariates], rel_beta) 234 | from scipy.special import expit 235 | a_jt = ((1 / rel_y_t) * np.sum(expit(x + expit_add)) - (n_jt / rel_y_t)) ** 2 # squared gap between mean predicted hazard over the risk set and the observed proportion n_jt / y_t 236 | a_jt_from_func = self.fitted_model._alpha_jt(x=x, df=df, 237 | y_t=rel_y_t, beta_j=rel_beta, 238 | n_jt=n_jt, t=t, event=j) 239 | self.assertEqual(a_jt.item(), a_jt_from_func.item()) 240 | 241 | def test_predict_event_jt_case_t1_not_hazard(self): 242 | self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1), 243 | event=self.fitted_model.events[0], 244 | t=self.fitted_model.times[0]) 245 | 246 | def test_predict_event_jt_case_t3_not_hazard(self): 247 | temp_df = self.df.drop(['C', 'T'], axis=1) 248 | temp_df = self.fitted_model.predict_overall_survival(df=temp_df, 249 | t=self.fitted_model.times[3], 250 | return_hazards=False) 251 | self.fitted_model.predict_prob_event_j_at_t(df=temp_df, 252 | event=self.fitted_model.events[0], 253 | t=self.fitted_model.times[3]) 254 | 255 | def test_regularization_same_for_all_beta_models(self): 256 | L1_regularized_fitter = TwoStagesFitterExact() 257 | 258 | fit_beta_kwargs = { 259 | 'model_kwargs': { 260 | 'alpha': 
0.03, 261 | 'L1_wt': 1 262 | } 263 | } 264 | 265 | L1_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs) 266 | print(L1_regularized_fitter.get_beta_SE()) 267 | 268 | def test_regularization_different_to_each_beta_model(self): 269 | L1_regularized_fitter = TwoStagesFitterExact() 270 | 271 | fit_beta_kwargs = { 272 | 'model_fit_kwargs': { 273 | 1: { 274 | 'alpha': 0.003, 275 | 'L1_wt': 1 276 | }, 277 | 2: { 278 | 'alpha': 0.005, 279 | 'L1_wt': 1 280 | } 281 | } 282 | } 283 | 284 | L1_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs) 285 | # print(L1_regularized_fitter.get_beta_SE()) 286 | 287 | L2_regularized_fitter = TwoStagesFitterExact() 288 | 289 | fit_beta_kwargs = { 290 | 'model_fit_kwargs': { 291 | 1: { 292 | 'alpha': 0.003, 293 | 'L1_wt': 0 294 | }, 295 | 2: { 296 | 'alpha': 0.005, 297 | 'L1_wt': 0 298 | } 299 | } 300 | } 301 | 302 | L2_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs) 303 | # print(L2_regularized_fitter.get_beta_SE()) 304 | 305 | def test_different_covariates_to_each_beta_model(self): 306 | twostages_fitter = TwoStagesFitterExact() 307 | covariates = { 308 | 1: ['Z1', 'Z2', 'Z3'], 309 | 2: ['Z2', 'Z3', 'Z4', 'Z5'] 310 | } 311 | twostages_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=covariates) 312 | 313 | def test_different_covariates_to_each_beta_model_prediction(self): 314 | twostages_fitter = TwoStagesFitterExact() 315 | covariates = { 316 | 1: ['Z1', 'Z2', 'Z3'], 317 | 2: ['Z2', 'Z3', 'Z4', 'Z5'] 318 | } 319 | twostages_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=covariates) 320 | twostages_fitter.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1)) 321 | -------------------------------------------------------------------------------- /tests/test_basefitter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from src.pydts.base_fitters import BaseFitter, ExpansionBasedFitter 6 | from src.pydts.examples_utils.generate_simulations_data import generate_quick_start_df 7 | 8 | 9 | class TestBaseFitter(unittest.TestCase): 10 | def setUp(self): 11 | self.real_coef_dict = { 12 | "alpha": { 13 | 1: lambda t: -1 - 0.3 * np.log(t), 14 | 2: lambda t: -1.75 - 0.15 * np.log(t) 15 | }, 16 | "beta": { 17 | 1: -np.log([0.8, 3, 3, 2.5, 2]), 18 | 2: -np.log([1, 3, 4, 3, 2]) 19 | } 20 | } 21 | self.df = generate_quick_start_df(n_patients=1000, n_cov=5, d_times=10, j_events=2, pid_col='pid', seed=0, 22 | real_coef_dict=self.real_coef_dict, censoring_prob=0.8) 23 | 24 | self.base_fitter = BaseFitter() 25 | self.expansion_fitter = ExpansionBasedFitter() 26 | 27 | def test_base_fit(self): 28 | with self.assertRaises(NotImplementedError): 29 | self.base_fitter.fit(self.df) 30 | 31 | def test_base_predict(self): 32 | with self.assertRaises(NotImplementedError): 33 | self.base_fitter.predict(self.df) 34 | 35 | def test_base_evaluate(self): 36 | with self.assertRaises(NotImplementedError): 37 | self.base_fitter.evaluate(self.df) 38 | 39 | def test_base_print_summary(self): 40 | with self.assertRaises(NotImplementedError): 41 | self.base_fitter.print_summary() 42 | 43 | def test_expansion_predict_hazard_not_implemented(self): 44 | with self.assertRaises(NotImplementedError): 45 | self.expansion_fitter.predict_hazard_jt(self.df, event=1, t=100) -------------------------------------------------------------------------------- 
/tests/test_cross_validation.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import pandas as pd 4 | from src.pydts.data_generation import EventTimesSampler 5 | from src.pydts.cross_validation import TwoStagesCV, PenaltyGridSearchCV, TwoStagesCVExact, PenaltyGridSearchCVExact 6 | 7 | 8 | class TestCrossValidation(unittest.TestCase): 9 | 10 | def setUp(self): 11 | n_cov = 6 12 | beta1 = np.zeros(n_cov) 13 | beta1[:5] = (-0.5 * np.log([0.8, 3, 3, 2.5, 2])) 14 | beta2 = np.zeros(n_cov) 15 | beta2[:5] = (-0.5 * np.log([1, 3, 4, 3, 2])) 16 | 17 | real_coef_dict = { 18 | "alpha": { 19 | 1: lambda t: -2.0 - 0.2 * np.log(t), 20 | 2: lambda t: -2.2 - 0.2 * np.log(t) 21 | }, 22 | "beta": { 23 | 1: beta1, 24 | 2: beta2 25 | } 26 | } 27 | n_patients = 1000 28 | d_times = 5 29 | j_events = 2 30 | 31 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 32 | 33 | seed = 0 34 | means_vector = np.zeros(n_cov) 35 | covariance_matrix = 0.5 * np.identity(n_cov) 36 | clip_value = 1 37 | 38 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 39 | 40 | patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix, 41 | size=n_patients), 42 | columns=covariates) 43 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True) 44 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 45 | patients_df.index.name = 'pid' 46 | patients_df = patients_df.reset_index() 47 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times)) 48 | self.patients_df = ets.update_event_or_lof(patients_df) 49 | self.tscv = TwoStagesCV() 50 | 51 | def test_cross_validation_bs(self): 52 | self.tscv.cross_validate(self.patients_df, metrics='BS', n_splits=2) 53 | 54 | def test_cross_validation_auc(self): 55 | self.tscv.cross_validate(self.patients_df, metrics='AUC', n_splits=2) 56 | 57 | def test_cross_validation_iauc(self): 58 | self.tscv.cross_validate(self.patients_df, metrics='IAUC', n_splits=2) 59 | 60 | def test_cross_validation_gauc(self): 61 | self.tscv.cross_validate(self.patients_df, metrics='GAUC', n_splits=2) 62 | 63 | def test_cross_validation_ibs(self): 64 | self.tscv.cross_validate(self.patients_df, metrics='IBS', n_splits=2) 65 | 66 | def test_cross_validation_gbs(self): 67 | self.tscv.cross_validate(self.patients_df, metrics='GBS', n_splits=3) 68 | 69 | 70 | class TestCrossValidationExact(TestCrossValidation): 71 | 72 | def setUp(self): 73 | n_cov = 6 74 | beta1 = np.zeros(n_cov) 75 | beta1[:5] = (-0.25 * np.log([0.8, 3, 3, 2.5, 2])) 76 | beta2 = np.zeros(n_cov) 77 | beta2[:5] = (-0.25 * np.log([1, 3, 4, 3, 2])) 78 | 79 | real_coef_dict = { 80 | "alpha": { 81 | 1: lambda t: -1.7 - 0.2 * np.log(t), 82 | 2: lambda t: -1.8 - 0.2 * np.log(t) 83 | }, 84 | "beta": { 85 | 1: beta1, 86 | 2: beta2 87 | } 88 | } 89 | n_patients = 500 90 | d_times = 4 91 | j_events = 2 92 | 93 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 94 | 95 | seed = 0 96 | means_vector = np.zeros(n_cov) 97 | covariance_matrix = 0.5 * np.identity(n_cov) 98 | clip_value = 1 99 | 100 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 101 | 102 | patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix, 103 | size=n_patients), 104 | columns=covariates) 105 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True) 106 | patients_df = 
ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 107 | patients_df.index.name = 'pid' 108 | patients_df = patients_df.reset_index() 109 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times)) 110 | self.patients_df = ets.update_event_or_lof(patients_df) 111 | self.tscv = TwoStagesCVExact() 112 | 113 | 114 | class TestPenaltyGridSearchCV(unittest.TestCase): 115 | 116 | def setUp(self): 117 | n_cov = 6 118 | beta1 = np.zeros(n_cov) 119 | beta1[:5] = (-0.5 * np.log([0.8, 3, 3, 2.5, 2])) 120 | beta2 = np.zeros(n_cov) 121 | beta2[:5] = (-0.5 * np.log([1, 3, 4, 3, 2])) 122 | 123 | real_coef_dict = { 124 | "alpha": { 125 | 1: lambda t: -2.0 - 0.2 * np.log(t), 126 | 2: lambda t: -2.2 - 0.2 * np.log(t) 127 | }, 128 | "beta": { 129 | 1: beta1, 130 | 2: beta2 131 | } 132 | } 133 | n_patients = 2000 134 | d_times = 5 135 | j_events = 2 136 | 137 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 138 | 139 | seed = 0 140 | means_vector = np.zeros(n_cov) 141 | covariance_matrix = 0.5 * np.identity(n_cov) 142 | clip_value = 1 143 | 144 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 145 | 146 | patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix, 147 | size=n_patients), 148 | columns=covariates) 149 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True) 150 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 151 | patients_df.index.name = 'pid' 152 | patients_df = patients_df.reset_index() 153 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times)) 154 | self.patients_df = ets.update_event_or_lof(patients_df) 155 | self.pgscv = PenaltyGridSearchCV() 156 | 157 | def test_penalty_grid_search_cross_validate(self): 158 | self.pgscv.cross_validate(full_df=self.patients_df, 159 | l1_ratio=1, 160 | n_splits=2, 161 | penalizers=[0.0001, 0.02], 162 | seed=0) 163 | 164 | 165 | class TestPenaltyGridSearchCVExact(TestPenaltyGridSearchCV): 166 | 167 | def setUp(self): 168 | n_cov = 6 169 | beta1 = np.zeros(n_cov) 170 | beta1[:5] = (-0.3 * np.log([0.8, 3, 3, 2.5, 2])) 171 | beta2 = np.zeros(n_cov) 172 | beta2[:5] = (-0.3 * np.log([1, 3, 4, 3, 2])) 173 | 174 | real_coef_dict = { 175 | "alpha": { 176 | 1: lambda t: -1.9 + 0.2 * np.log(t), 177 | 2: lambda t: -1.9 + 0.2 * np.log(t) 178 | }, 179 | "beta": { 180 | 1: beta1, 181 | 2: beta2 182 | } 183 | } 184 | n_patients = 400 185 | d_times = 4 186 | j_events = 2 187 | 188 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 189 | 190 | seed = 0 191 | means_vector = np.zeros(n_cov) 192 | covariance_matrix = 0.5 * np.identity(n_cov) 193 | clip_value = 1 194 | 195 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 196 | 197 | patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix, 198 | size=n_patients), 199 | columns=covariates) 200 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True) 201 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 202 | patients_df.index.name = 'pid' 203 | patients_df = patients_df.reset_index() 204 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times)) 205 | self.patients_df = ets.update_event_or_lof(patients_df) 206 | self.pgscv = PenaltyGridSearchCVExact() 207 | 
-------------------------------------------------------------------------------- /tests/test_model_selection.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import pandas as pd 4 | from src.pydts.data_generation import EventTimesSampler 5 | from src.pydts.model_selection import PenaltyGridSearch, PenaltyGridSearchExact 6 | from sklearn.model_selection import train_test_split 7 | from src.pydts.fitters import TwoStagesFitter, TwoStagesFitterExact 8 | 9 | 10 | class TestPenaltyGridSearch(unittest.TestCase): 11 | 12 | def setUp(self): 13 | n_cov = 6 14 | beta1 = np.zeros(n_cov) 15 | beta1[:5] = (-0.5 * np.log([0.8, 3, 3, 2.5, 2])) 16 | beta2 = np.zeros(n_cov) 17 | beta2[:5] = (-0.5 * np.log([1, 3, 4, 3, 2])) 18 | 19 | real_coef_dict = { 20 | "alpha": { 21 | 1: lambda t: -2.0 - 0.2 * np.log(t), 22 | 2: lambda t: -2.2 - 0.2 * np.log(t) 23 | }, 24 | "beta": { 25 | 1: beta1, 26 | 2: beta2 27 | } 28 | } 29 | n_patients = 1000 30 | d_times = 5 31 | j_events = 2 32 | 33 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 34 | 35 | seed = 0 36 | means_vector = np.zeros(n_cov) 37 | covariance_matrix = 0.5 * np.identity(n_cov) 38 | clip_value = 1 39 | 40 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 41 | 42 | patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix, 43 | size=n_patients), 44 | columns=covariates) 45 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True) 46 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 47 | patients_df.index.name = 'pid' 48 | patients_df = patients_df.reset_index() 49 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times)) 50 | self.patients_df = ets.update_event_or_lof(patients_df) 51 | self.pgs = PenaltyGridSearch() 52 | 53 | def test_get_mixed_two_stages_fitter(self): 54 | train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0) 55 | self.pgs.evaluate(train_df=train_df, 56 | test_df=test_df, 57 | l1_ratio=1, 58 | metrics=[], 59 | penalizers=[0.005, 0.02], 60 | seed=0) 61 | mixed_two_stages = self.pgs.get_mixed_two_stages_fitter([0.005, 0.02]) 62 | 63 | mixed_params = pd.concat([mixed_two_stages.beta_models[1].params_, 64 | mixed_two_stages.beta_models[2].params_], axis=1) 65 | 66 | fit_beta_kwargs = { 67 | 'model_kwargs': { 68 | 1: {'penalizer': 0.005, 'l1_ratio': 1}, 69 | 2: {'penalizer': 0.02, 'l1_ratio': 1}, 70 | } 71 | } 72 | 73 | two_stages_fitter = TwoStagesFitter() 74 | two_stages_fitter.fit(df=train_df, fit_beta_kwargs=fit_beta_kwargs) 75 | 76 | two_stages_params = pd.concat([two_stages_fitter.beta_models[1].params_, 77 | two_stages_fitter.beta_models[2].params_], axis=1) 78 | 79 | pd.testing.assert_frame_equal(mixed_params, two_stages_params) 80 | pd.testing.assert_frame_equal(mixed_two_stages.alpha_df, two_stages_fitter.alpha_df) 81 | 82 | def test_assertion_get_mixed_two_stages_fitter_not_included(self): 83 | with self.assertRaises(AssertionError): 84 | 85 | train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0) 86 | self.pgs.evaluate(train_df=train_df, 87 | test_df=test_df, 88 | l1_ratio=1, 89 | metrics=[], 90 | penalizers=[0.005, 0.02], 91 | seed=0) 92 | 93 | mixed_two_stages = self.pgs.get_mixed_two_stages_fitter([0.1, 0.02]) 94 | 95 | def test_assertion_get_mixed_two_stages_fitter_empty(self): 96 | with 
self.assertRaises(AssertionError): 97 | mixed_two_stages = self.pgs.get_mixed_two_stages_fitter([0.001, 0.02]) 98 | 99 | def test_evaluate(self): 100 | train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0) 101 | idx_max = self.pgs.evaluate(train_df=train_df, test_df=test_df, l1_ratio=1, 102 | penalizers=[0.0001, 0.005, 0.02], 103 | seed=0) 104 | 105 | def test_convert_results_dict_to_df(self): 106 | train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0) 107 | idx_max = self.pgs.evaluate(train_df=train_df, test_df=test_df, l1_ratio=1, 108 | penalizers=[0.0001, 0.005], 109 | seed=0) 110 | self.pgs.convert_results_dict_to_df(self.pgs.global_bs) 111 | 112 | 113 | class TestPenaltyGridSearchExact(TestPenaltyGridSearch): 114 | 115 | def setUp(self): 116 | n_cov = 6 117 | beta1 = np.zeros(n_cov) 118 | beta1[:5] = (-0.5 * np.log([0.8, 3, 3, 2.5, 2])) 119 | beta2 = np.zeros(n_cov) 120 | beta2[:5] = (-0.5 * np.log([1, 3, 4, 3, 2])) 121 | 122 | real_coef_dict = { 123 | "alpha": { 124 | 1: lambda t: -2.0 - 0.2 * np.log(t), 125 | 2: lambda t: -2.2 - 0.2 * np.log(t) 126 | }, 127 | "beta": { 128 | 1: beta1, 129 | 2: beta2 130 | } 131 | } 132 | n_patients = 300 133 | d_times = 4 134 | j_events = 2 135 | 136 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 137 | 138 | seed = 0 139 | means_vector = np.zeros(n_cov) 140 | covariance_matrix = 0.5 * np.identity(n_cov) 141 | clip_value = 1 142 | 143 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 144 | 145 | patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix, 146 | size=n_patients), 147 | columns=covariates) 148 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True) 149 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 150 | patients_df.index.name = 'pid' 151 | patients_df = patients_df.reset_index() 152 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times)) 153 | self.patients_df = ets.update_event_or_lof(patients_df) 154 | self.pgs = PenaltyGridSearchExact() 155 | 156 | def test_get_mixed_two_stages_fitter(self): 157 | train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0) 158 | self.pgs.evaluate(train_df=train_df, 159 | test_df=test_df, 160 | l1_ratio=1, 161 | metrics=[], 162 | penalizers=[0.005, 0.02], 163 | seed=0) 164 | mixed_two_stages = self.pgs.get_mixed_two_stages_fitter([0.005, 0.02]) 165 | 166 | mixed_params = pd.concat([mixed_two_stages.beta_models[1].params, 167 | mixed_two_stages.beta_models[2].params], axis=1) 168 | 169 | 170 | fit_beta_kwargs = { 171 | 'model_fit_kwargs': { 172 | 1: { 173 | 'alpha': 0.005, 174 | 'L1_wt': 1 175 | }, 176 | 2: { 177 | 'alpha': 0.02, 178 | 'L1_wt': 1 179 | } 180 | } 181 | } 182 | 183 | two_stages_fitter = TwoStagesFitterExact() 184 | two_stages_fitter.fit(df=train_df, fit_beta_kwargs=fit_beta_kwargs) 185 | 186 | two_stages_params = pd.concat([two_stages_fitter.beta_models[1].params, 187 | two_stages_fitter.beta_models[2].params], axis=1) 188 | 189 | pd.testing.assert_frame_equal(mixed_params, two_stages_params) 190 | pd.testing.assert_frame_equal(mixed_two_stages.alpha_df, two_stages_fitter.alpha_df) 191 | -------------------------------------------------------------------------------- /tests/test_repetative_fitter.py: -------------------------------------------------------------------------------- 1 | 
import unittest 2 | # 3 | # import numpy as np 4 | # 5 | # from src.pydts.fitters import repetitive_fitters 6 | # 7 | # 8 | class TestRepFitters(unittest.TestCase): 9 | def setUp(self): 10 | pass 11 | # self.real_coef_dict = { 12 | # "alpha": { 13 | # 1: lambda t: -1 - 0.3 * np.log(t), 14 | # 2: lambda t: -1.75 - 0.15 * np.log(t) 15 | # }, 16 | # "beta": { 17 | # 1: -np.log([0.8, 3, 3, 2.5, 2]), 18 | # 2: -np.log([1, 3, 4, 3, 2]) 19 | # } 20 | # } 21 | # 22 | # self.n_patients = 10000 23 | # self.n_cov = 5 24 | # self.d = 15 25 | # 26 | # # def test_fit_function_case_successful(self): 27 | # # _ = repetitive_fitters(rep=5, n_patients=self.n_patients, n_cov=self.n_cov, 28 | # # d_times=self.d, j_events=2, pid_col='pid', verbose=0, 29 | # # allow_fails=20, real_coef_dict=self.real_coef_dict, 30 | # # censoring_prob=.8) 31 | # # 32 | # # def test_fit_not_sending_coef(self): 33 | # # # case where fit is called without the real coefficient dict 34 | # # with self.assertRaises(AssertionError): 35 | # # _ = repetitive_fitters(rep=5, n_patients=self.n_patients, n_cov=self.n_cov, 36 | # # d_times=self.d, j_events=2, pid_col='pid', verbose=0, 37 | # # allow_fails=20, censoring_prob=.8) 38 | # # 39 | # # def test_fit_repetitive_function_case_j_event_not_equal_to_real_coef(self): 40 | # # # case where fit is called with a wrong j_events value; the error is printed 41 | # # # but the ValueError still propagates in the end 42 | # # with self.assertRaises(ValueError): 43 | # # _ = repetitive_fitters(rep=2, n_patients=self.n_patients, n_cov=self.n_cov, 44 | # # d_times=self.d, j_events=3, pid_col='pid', verbose=0, 45 | # # allow_fails=0, real_coef_dict=self.real_coef_dict, 46 | # # censoring_prob=.8) 47 | # # 48 | # # def test_fit_function_case_second_model_is_not_twoStages(self): 49 | # # from src.pydts.fitters import DataExpansionFitter 50 | # # _ = repetitive_fitters(rep=2, n_patients=self.n_patients, n_cov=self.n_cov, 51 | # # d_times=self.d, j_events=2, pid_col='pid', 52 | # # model2=DataExpansionFitter, verbose=0, 53 | # # allow_fails=20, real_coef_dict=self.real_coef_dict, 54 | # # censoring_prob=.8) -------------------------------------------------------------------------------- /tests/test_screening.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from src.pydts.data_generation import EventTimesSampler 3 | from src.pydts.screening import SISTwoStagesFitterExact, SISTwoStagesFitter, get_expanded_df 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | class TestScreening(unittest.TestCase): 9 | 10 | def setUp(self): 11 | n_cov = 50 12 | beta1 = np.zeros(n_cov) 13 | beta1[:5] = np.array([-0.6, 0.5, -0.5, 0.6, -0.6]) 14 | beta2 = np.zeros(n_cov) 15 | beta2[:5] = np.array([0.5, -0.7, 0.7, -0.5, -0.7]) 16 | 17 | real_coef_dict = { 18 | "alpha": { 19 | 1: lambda t: -2.6 + 0.1 * np.log(t), 20 | 2: lambda t: -2.7 + 0.2 * np.log(t) 21 | }, 22 | "beta": { 23 | 1: beta1, 24 | 2: beta2 25 | } 26 | } 27 | 28 | n_patients = 400 29 | d_times = 7 30 | j_events = 2 31 | 32 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 33 | 34 | seed = 2 35 | means_vector = np.zeros(n_cov) 36 | covariance_matrix = np.identity(n_cov) 37 | 38 | clip_value = 3 39 | 40 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 41 | 42 | patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix, 43 | size=n_patients), 44 | columns=covariates) 45 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, 
inplace=True) 46 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 47 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times), 48 | seed=seed + 1) 49 | patients_df = ets.update_event_or_lof(patients_df) 50 | patients_df.index.name = 'pid' 51 | self.patients_df = patients_df.reset_index() 52 | self.covariates = covariates 53 | self.fitter = SISTwoStagesFitter() 54 | 55 | def test_psis_permute_df(self): 56 | self.fitter.permute_df(df=self.patients_df) 57 | 58 | def test_psis_fit_marginal_model(self): 59 | expanded_df = get_expanded_df(self.patients_df.drop(['C', 'T'], axis=1)) 60 | self.fitter.fit_marginal_model(expanded_df, covariate='Z1') 61 | 62 | def test_psis_get_marginal_estimates(self): 63 | expanded_df = get_expanded_df(self.patients_df.drop(['C', 'T'], axis=1)) 64 | self.fitter.get_marginal_estimates(expanded_df) 65 | 66 | def test_psis_get_data_driven_threshold(self): 67 | self.fitter.get_data_driven_threshold(df=self.patients_df.drop(['C', 'T'], axis=1)) 68 | 69 | def test_psis_fit_data_driven_threshold(self): 70 | self.fitter.fit(df=self.patients_df.drop(['C', 'T'], axis=1), quantile=0.95) 71 | 72 | def test_psis_fit_user_defined_threshold(self): 73 | self.fitter.fit(df=self.patients_df.drop(['C', 'T'], axis=1), threshold=0.15) 74 | 75 | def test_psis_covs_dict(self): 76 | with self.assertRaises(ValueError): 77 | self.fitter.fit(df=self.patients_df.drop(['C', 'T'], axis=1), 78 | covariates={1: self.covariates[:-3], 2: self.covariates[:-8]}) 79 | 80 | 81 | class TestScreeningExact(TestScreening): 82 | 83 | def setUp(self): 84 | n_cov = 30 85 | beta1 = np.zeros(n_cov) 86 | beta1[:5] = np.array([-0.6, 0.5, -0.5, 0.6, -0.6]) 87 | beta2 = np.zeros(n_cov) 88 | beta2[:5] = np.array([0.5, -0.7, 0.7, -0.5, -0.7]) 89 | 90 | real_coef_dict = { 91 | "alpha": { 92 | 1: lambda t: -3.1 + 0.1 * np.log(t), 93 | 2: lambda t: -3.2 + 0.2 * np.log(t) 94 | }, 95 | "beta": { 96 | 1: beta1, 97 | 2: beta2 98 | } 99 | } 100 | 101 | n_patients = 400 102 | d_times = 7 103 | j_events = 2 104 | 105 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events) 106 | 107 | seed = 2 108 | means_vector = np.zeros(n_cov) 109 | covariance_matrix = np.identity(n_cov) 110 | 111 | clip_value = 3 112 | 113 | covariates = [f'Z{i + 1}' for i in range(n_cov)] 114 | 115 | patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix, 116 | size=n_patients), 117 | columns=covariates) 118 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True) 119 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed) 120 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times), 121 | seed=seed + 1) 122 | patients_df = ets.update_event_or_lof(patients_df) 123 | patients_df.index.name = 'pid' 124 | self.patients_df = patients_df.reset_index() 125 | self.covariates = covariates 126 | self.fitter = SISTwoStagesFitterExact() 127 | --------------------------------------------------------------------------------