├── .codecov.yml
├── .github
│   └── workflows
│       ├── ci.yml
│       ├── codecov.yml
│       └── tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── docs
│   ├── EventTimesSampler.ipynb
│   ├── ModelsComparison.ipynb
│   ├── PaperCodeFigure.ipynb
│   ├── PerformanceMeasures.ipynb
│   ├── Regularization.ipynb
│   ├── Simple Simulation.ipynb
│   ├── SimulatedDataset.ipynb
│   ├── UsageExample-DataPreparation.ipynb
│   ├── UsageExample-FittingDataExpansionFitter-FULL.ipynb
│   ├── UsageExample-FittingDataExpansionFitter.ipynb
│   ├── UsageExample-FittingTwoStagesFitter-FULL.ipynb
│   ├── UsageExample-FittingTwoStagesFitter.ipynb
│   ├── UsageExample-FittingTwoStagesFitterExact-FULL.ipynb
│   ├── UsageExample-Intro.ipynb
│   ├── UsageExample-RegroupingData.ipynb
│   ├── UsageExample-SIS-SIS-L.ipynb
│   ├── User Story.ipynb
│   ├── api
│   │   ├── cross_validation.md
│   │   ├── data_expansion_fitter.md
│   │   ├── evaluation.md
│   │   ├── event_times_sampler.md
│   │   ├── model_selection.md
│   │   ├── screening.md
│   │   ├── two_stages_fitter.md
│   │   ├── two_stages_fitter_exact.md
│   │   └── utils.md
│   ├── dtsicon.svg
│   ├── icon.png
│   ├── index.md
│   ├── intro.md
│   ├── jss_replication.py
│   ├── methods.md
│   ├── methodsevaluation.md
│   ├── methodsintro.md
│   ├── models_params.png
│   └── requirements.txt
├── mkdocs.yml
├── poetry.lock
├── pyproject.toml
├── src
│   └── pydts
│       ├── __init__.py
│       ├── base_fitters.py
│       ├── config.py
│       ├── cross_validation.py
│       ├── data_generation.py
│       ├── datasets
│       │   └── LOS_simulated_data.csv
│       ├── evaluation.py
│       ├── examples_utils
│       │   ├── __init__.py
│       │   ├── datasets.py
│       │   ├── generate_simulations_data.py
│       │   ├── mimic_consts.py
│       │   ├── plots.py
│       │   └── simulations_data_config.py
│       ├── fitters.py
│       ├── model_selection.py
│       ├── screening.py
│       └── utils.py
└── tests
    ├── __init__.py
    ├── test_DataExpansionFitter.py
    ├── test_EventTimesSampler.py
    ├── test_TwoStagesFitter.py
    ├── test_TwoStagesFitterExact.py
    ├── test_basefitter.py
    ├── test_cross_validation.py
    ├── test_evaluation.py
    ├── test_model_selection.py
    ├── test_repetative_fitter.py
    └── test_screening.py
/.codecov.yml:
--------------------------------------------------------------------------------
1 | #
2 | # This codecov.yml is the default configuration for
3 | # all repositories on Codecov. You may adjust the settings
4 | # below in your own codecov.yml in your repository.
5 | #
6 | coverage:
7 |   precision: 2
8 |   round: down
9 |   range: 70...100
10 |
11 |   status:
12 |     # Learn more at https://docs.codecov.io/docs/commit-status
13 |     project:
14 |       default:
15 |         threshold: 1% # allow this much decrease on project
16 |       app:
17 |         target: 70%
18 |         flags:
19 |           - app
20 |       modules:
21 |         target: 70%
22 |         flags:
23 |           - modules
24 |       client:
25 |         flags:
26 |           - client
27 |     changes: false
28 |
29 | comment:
30 |   layout: "reach, diff, files"
31 |   behavior: default # update if exists else create new
32 |   require_changes: true
33 |
34 | flags:
35 |   app:
36 |     paths:
37 |       - "app/"
38 |       - "baseapp/"
39 |   modules:
40 |     paths:
41 |       - "x/"
42 |       - "!x/**/client/" # ignore client package
43 |   client:
44 |     paths:
45 |       - "client/"
46 |       - "x/**/client/"
47 |
48 | ignore:
49 |   - "docs"
50 |   - "*.md"
51 |   - "*.rst"
52 |   - "**/*.pb.go"
53 |   - "types/*.pb.go"
54 |   - "tests/*"
55 |   - "tests/**/*"
56 |   - "x/**/*.pb.go"
57 |   - "x/**/test_common.go"
58 |   - "scripts/"
59 |   - "contrib"
60 |   - "src/pydts/examples_utils"
61 |   - "src/pydts/utils.py"
62 |   - "src/pydts/config.py"
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: ci
2 | on:
3 |   push:
4 |     branches:
5 |       - master
6 |       - main
7 | jobs:
8 |   deploy:
9 |     runs-on: ubuntu-latest
10 |     steps:
11 |       - uses: actions/checkout@v4
12 |       - uses: actions/setup-node@v4
13 |         with:
14 |           node-version: 16
15 |       - uses: actions/setup-python@v4
16 |         with:
17 |           python-version: '3.10'
18 |       - run: sudo apt-get update && sudo apt-get install -y gettext
19 |       - run: pip install poetry
20 |       - run: poetry install
21 |       - run: poetry run mkdocs gh-deploy --force
22 |
23 |
24 |
--------------------------------------------------------------------------------
/.github/workflows/codecov.yml:
--------------------------------------------------------------------------------
1 | name: Code cov
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 |   build:
7 |     runs-on: ubuntu-latest
8 |     name: Code cov
9 |     steps:
10 |       - uses: actions/checkout@v4
11 |       - uses: actions/setup-python@v2
12 |         with:
13 |           python-version: '3.10'
14 |       - name: Install requirements
15 |         run: |
16 |           sudo apt-get update && sudo apt-get install -y gettext
17 |           pip install poetry
18 |           poetry install
19 |       - name: Run tests and collect coverage
20 |         run: poetry run pytest tests/ --cov=./ --cov-report=xml
21 |       - name: Upload coverage reports to Codecov with GitHub Action
22 |         uses: codecov/codecov-action@v5
23 |         env:
24 |           CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
25 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 | on: [push, pull_request]
3 | jobs:
4 |   test-DataExpansionFitter:
5 |     runs-on: ubuntu-latest
6 |     timeout-minutes: 60
7 |     strategy:
8 |       matrix:
9 |         python-version: ['3.10']
10 |     steps:
11 |       - uses: actions/checkout@v4
12 |       - name: Set up Python ${{ matrix.python-version }}
13 |         uses: actions/setup-python@v4
14 |         with:
15 |           python-version: ${{ matrix.python-version }}
16 |       - name: Setup
17 |         run: |
18 |           sudo apt-get update && sudo apt-get install -y gettext
19 |           pip install poetry
20 |           poetry install
21 |       - name: Run tests
22 |         run: poetry run pytest tests/test_DataExpansionFitter.py -v
23 |
24 |   test-EventTimesSampler:
25 |     runs-on: ubuntu-latest
26 |     timeout-minutes: 60
27 |     strategy:
28 |       matrix:
29 |         python-version: ['3.10']
30 |     steps:
31 |       - uses: actions/checkout@v4
32 |       - name: Set up Python ${{ matrix.python-version }}
33 |         uses: actions/setup-python@v4
34 |         with:
35 |           python-version: ${{ matrix.python-version }}
36 |       - name: Setup
37 |         run: |
38 |           sudo apt-get update && sudo apt-get install -y gettext
39 |           pip install poetry
40 |           poetry install
41 |       - name: Run tests
42 |         run: poetry run pytest tests/test_EventTimesSampler.py -v
43 |
44 |   test-TwoStagesFitter:
45 |     runs-on: ubuntu-latest
46 |     timeout-minutes: 60
47 |     strategy:
48 |       matrix:
49 |         python-version: ['3.10']
50 |     steps:
51 |       - uses: actions/checkout@v4
52 |       - name: Set up Python ${{ matrix.python-version }}
53 |         uses: actions/setup-python@v4
54 |         with:
55 |           python-version: ${{ matrix.python-version }}
56 |       - name: Setup
57 |         run: |
58 |           sudo apt-get update && sudo apt-get install -y gettext
59 |           pip install poetry
60 |           poetry install
61 |       - name: Run tests
62 |         run: poetry run pytest tests/test_TwoStagesFitter.py -v
63 |
64 |   test-TwoStagesFitterExact:
65 |     runs-on: ubuntu-latest
66 |     timeout-minutes: 120
67 |     strategy:
68 |       matrix:
69 |         python-version: ['3.10']
70 |     steps:
71 |       - uses: actions/checkout@v4
72 |       - name: Set up Python ${{ matrix.python-version }}
73 |         uses: actions/setup-python@v4
74 |         with:
75 |           python-version: ${{ matrix.python-version }}
76 |       - name: Setup
77 |         run: |
78 |           sudo apt-get update && sudo apt-get install -y gettext
79 |           pip install poetry
80 |           poetry install
81 |       - name: Run tests
82 |         run: poetry run pytest tests/test_TwoStagesFitterExact.py -v
83 |
84 |   test-screening:
85 |     runs-on: ubuntu-latest
86 |     timeout-minutes: 60
87 |     strategy:
88 |       matrix:
89 |         python-version: ['3.10']
90 |     steps:
91 |       - uses: actions/checkout@v4
92 |       - name: Set up Python ${{ matrix.python-version }}
93 |         uses: actions/setup-python@v4
94 |         with:
95 |           python-version: ${{ matrix.python-version }}
96 |       - name: Setup
97 |         run: |
98 |           sudo apt-get update && sudo apt-get install -y gettext
99 |           pip install poetry
100 |           poetry install
101 |       - name: Run tests
102 |         run: poetry run pytest tests/test_screening.py -v
103 |
104 |   test-model-selection:
105 |     runs-on: ubuntu-latest
106 |     timeout-minutes: 120
107 |     strategy:
108 |       matrix:
109 |         python-version: ['3.10']
110 |     steps:
111 |       - uses: actions/checkout@v4
112 |       - name: Set up Python ${{ matrix.python-version }}
113 |         uses: actions/setup-python@v4
114 |         with:
115 |           python-version: ${{ matrix.python-version }}
116 |       - name: Setup
117 |         run: |
118 |           sudo apt-get update && sudo apt-get install -y gettext
119 |           pip install poetry
120 |           poetry install
121 |       - name: Run tests
122 |         run: poetry run pytest tests/test_model_selection.py -v
123 |
124 |   test-remaining:
125 |     runs-on: ubuntu-latest
126 |     timeout-minutes: 120
127 |     strategy:
128 |       matrix:
129 |         python-version: ['3.10']
130 |     steps:
131 |       - uses: actions/checkout@v4
132 |       - name: Set up Python ${{ matrix.python-version }}
133 |         uses: actions/setup-python@v4
134 |         with:
135 |           python-version: ${{ matrix.python-version }}
136 |       - name: Setup
137 |         run: |
138 |           sudo apt-get update && sudo apt-get install -y gettext
139 |           pip install poetry
140 |           poetry install
141 |       - name: Run remaining tests
142 |         run: |
143 |           poetry run pytest tests/ \
144 |             --ignore=tests/test_DataExpansionFitter.py \
145 |             --ignore=tests/test_EventTimesSampler.py \
146 |             --ignore=tests/test_TwoStagesFitter.py \
147 |             --ignore=tests/test_TwoStagesFitterExact.py \
148 |             --ignore=tests/test_screening.py \
149 |             --ignore=tests/test_model_selection.py \
150 |             -v
151 |
152 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | output*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | PyDTS - A Python package for discrete-time survival analysis with competing risks
2 | Copyright (C) 2022 Tomer Meir, Rom Gutman, and Malka Gorfine
3 |
4 | This program is free software: you can redistribute it and/or modify
5 | it under the terms of the GNU General Public License as published by
6 | the Free Software Foundation, either version 3 of the License, or
7 | (at your option) any later version.
8 |
9 | This program is distributed in the hope that it will be useful,
10 | but WITHOUT ANY WARRANTY; without even the implied warranty of
11 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 | GNU General Public License for more details.
13 |
14 | You should have received a copy of the GNU General Public License
15 | along with this program. If not, see <https://www.gnu.org/licenses/>.
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://pypi.org/project/pydts/)
2 | [](https://github.com/tomer1812/pydts/actions?workflow=Tests)
3 | [](https://tomer1812.github.io/pydts)
4 | [](https://codecov.io/gh/tomer1812/pydts)
5 | [](https://doi.org/10.5281/zenodo.15296343)
6 |
7 | # Discrete Time Survival Analysis
8 | A Python package for discrete-time survival data analysis with competing risks.
9 |
10 | 
11 |
12 | [Tomer Meir](https://tomer1812.github.io/), [Rom Gutman](https://github.com/RomGutman), [Malka Gorfine](https://www.tau.ac.il/~gorfinem/) 2022
13 |
14 | [Documentation](https://tomer1812.github.io/pydts/)
15 |
16 | ## Installation
17 | ```console
18 | pip install pydts
19 | ```
20 |
21 | ## Quick Start
22 |
23 | ```python
24 | from pydts.fitters import TwoStagesFitter
25 | from pydts.examples_utils.generate_simulations_data import generate_quick_start_df
26 |
27 | patients_df = generate_quick_start_df(n_patients=10000, n_cov=5, d_times=14, j_events=2, pid_col='pid', seed=0)
28 |
29 | fitter = TwoStagesFitter()
30 | fitter.fit(df=patients_df.drop(['C', 'T'], axis=1))
31 | fitter.print_summary()
32 | ```
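
Predictions for held-out observations follow the same pattern. A short follow-up sketch, assuming a `test_df` split off from `patients_df` and using the `predict_cumulative_incident_function` method shown in the documentation notebooks:

```python
pred_df = fitter.predict_cumulative_incident_function(test_df.drop(['J', 'T', 'C', 'X'], axis=1))
```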
33 |
34 | ## Examples
35 | 1. [Usage Example](https://tomer1812.github.io/pydts/UsageExample-Intro/)
36 | 2. [Hospital Length of Stay Simulation Example](https://tomer1812.github.io/pydts/SimulatedDataset/)
37 |
38 | ## Citations
39 | If you find the PyDTS software useful in your research, please cite the following papers:
40 |
41 | ```bibtex
42 | @article{Meir_PyDTS_2022,
43 |   author = {Meir, Tomer and Gutman, Rom and Gorfine, Malka},
44 |   doi = {10.48550/arXiv.2204.05731},
45 |   title = {{PyDTS: A Python Package for Discrete Time Survival Analysis with Competing Risks}},
46 |   url = {https://arxiv.org/abs/2204.05731},
47 |   year = {2022}
48 | }
49 |
50 | @article{Meir_Gorfine_DTSP_2025,
51 |   author = {Meir, Tomer and Gorfine, Malka},
52 |   doi = {10.1093/biomtc/ujaf040},
53 |   title = {{Discrete-Time Competing-Risks Regression with or without Penalization}},
54 |   year = {2025},
55 |   journal = {Biometrics},
56 |   volume = {81},
57 |   number = {2},
58 |   url = {https://academic.oup.com/biometrics/article/81/2/ujaf040/8120014},
59 | }
60 | ```
61 |
62 | and please consider starring the project [on GitHub](https://github.com/tomer1812/pydts)
63 |
64 | ## How to Contribute
65 | 1. Open GitHub issues to suggest new features or to report bugs/errors
66 | 2. Contact Tomer or Rom if you want to add a usage example to the documentation
67 | 3. If you want to become a developer (thank you, we appreciate it!) - please contact Tomer or Rom for developer onboarding
68 |
69 | Tomer Meir: tomer1812@gmail.com, Rom Gutman: rom.gutman1@gmail.com
--------------------------------------------------------------------------------
/docs/PaperCodeFigure.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "ee1b61c5",
7 | "metadata": {
8 | "ExecuteTime": {
9 | "end_time": "2022-05-18T19:07:03.674581Z",
10 | "start_time": "2022-05-18T19:03:12.578523Z"
11 | },
12 | "scrolled": false
13 | },
14 | "outputs": [],
15 | "source": [
16 | "import pandas as pd\n",
17 | "import numpy as np\n",
18 | "from sklearn.model_selection import train_test_split\n",
19 | "from pydts.examples_utils.generate_simulations_data import generate_quick_start_df\n",
20 | "\n",
21 | "# Data Generation\n",
22 | "real_coef_dict = {\n",
23 | " \"alpha\": {\n",
24 | " 1: lambda t: -1 - 0.3 * np.log(t),\n",
25 | " 2: lambda t: -1.75 - 0.15 * np.log(t)},\n",
26 | " \"beta\": {\n",
27 | " 1: -np.log([0.8, 3, 3, 2.5, 2]),\n",
28 | " 2: -np.log([1, 3, 4, 3, 2])}}\n",
29 | "\n",
30 | "patients_df = generate_quick_start_df(n_patients=50000, n_cov=5, d_times=30, j_events=2, pid_col='pid', \n",
31 | " seed=0, censoring_prob=0.8, real_coef_dict=real_coef_dict)\n",
32 | "\n",
33 | "train_df, test_df = train_test_split(patients_df, test_size=0.2)\n",
34 | "\n",
35 | "# DataExpansionFitter Usage\n",
36 | "from pydts.fitters import DataExpansionFitter\n",
37 | "fitter = DataExpansionFitter()\n",
38 | "fitter.fit(df=train_df.drop(['C', 'T'], axis=1))\n",
39 | "\n",
40 | "pred_df = fitter.predict_cumulative_incident_function(test_df.drop(['J', 'T', 'C', 'X'], axis=1))\n",
41 | "\n",
42 | "# TwoStagesFitter Usage\n",
43 | "from pydts.fitters import TwoStagesFitter\n",
44 | "new_fitter = TwoStagesFitter()\n",
45 | "new_fitter.fit(df=train_df.drop(['C', 'T'], axis=1))\n",
46 | "\n",
47 | "pred_df = fitter.predict_cumulative_incident_function(test_df.drop(['J', 'T', 'C', 'X'], axis=1))\n",
48 | "\n",
49 | "# Training with Regularization\n",
50 | "L1_regularized_fitter = TwoStagesFitter()\n",
51 | "fit_beta_kwargs = {'model_kwargs': {'penalizer': 0.003, 'l1_ratio': 1}}\n",
52 | "L1_regularized_fitter.fit(df=train_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)\n",
53 | "\n",
54 | "L2_regularized_fitter = TwoStagesFitter()\n",
55 | "fit_beta_kwargs = {'model_kwargs': {'penalizer': 0.003, 'l1_ratio': 0}}\n",
56 | "L2_regularized_fitter.fit(df=train_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)\n",
57 | "\n",
58 | "EN_regularized_fitter = TwoStagesFitter()\n",
59 | "fit_beta_kwargs = {'model_kwargs': {'penalizer': 0.003, 'l1_ratio': 0.5}}\n",
60 | "EN_regularized_fitter.fit(df=train_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "id": "b05f4b71",
67 | "metadata": {},
68 | "outputs": [],
69 | "source": []
70 | }
71 | ],
72 | "metadata": {
73 | "kernelspec": {
74 | "display_name": "Python 3 (ipykernel)",
75 | "language": "python",
76 | "name": "python3"
77 | },
78 | "language_info": {
79 | "codemirror_mode": {
80 | "name": "ipython",
81 | "version": 3
82 | },
83 | "file_extension": ".py",
84 | "mimetype": "text/x-python",
85 | "name": "python",
86 | "nbconvert_exporter": "python",
87 | "pygments_lexer": "ipython3",
88 | "version": "3.8.2"
89 | }
90 | },
91 | "nbformat": 4,
92 | "nbformat_minor": 5
93 | }
94 |
--------------------------------------------------------------------------------
/docs/api/cross_validation.md:
--------------------------------------------------------------------------------
1 | ::: pydts.cross_validation.TwoStagesCV
2 | ::: pydts.cross_validation.TwoStagesCVExact
3 | ::: pydts.cross_validation.PenaltyGridSearchCV
4 | ::: pydts.cross_validation.PenaltyGridSearchCVExact
5 |
--------------------------------------------------------------------------------
/docs/api/data_expansion_fitter.md:
--------------------------------------------------------------------------------
1 | ::: pydts.fitters.DataExpansionFitter
--------------------------------------------------------------------------------
/docs/api/evaluation.md:
--------------------------------------------------------------------------------
1 | ::: pydts.evaluation
--------------------------------------------------------------------------------
/docs/api/event_times_sampler.md:
--------------------------------------------------------------------------------
1 | ::: pydts.data_generation.EventTimesSampler
--------------------------------------------------------------------------------
/docs/api/model_selection.md:
--------------------------------------------------------------------------------
1 | ::: pydts.model_selection.PenaltyGridSearch
2 | ::: pydts.model_selection.PenaltyGridSearchExact
--------------------------------------------------------------------------------
/docs/api/screening.md:
--------------------------------------------------------------------------------
1 | ::: pydts.screening.SISTwoStagesFitter
2 | ::: pydts.screening.SISTwoStagesFitterExact
--------------------------------------------------------------------------------
/docs/api/two_stages_fitter.md:
--------------------------------------------------------------------------------
1 | ::: pydts.fitters.TwoStagesFitter
--------------------------------------------------------------------------------
/docs/api/two_stages_fitter_exact.md:
--------------------------------------------------------------------------------
1 | ::: pydts.fitters.TwoStagesFitterExact
--------------------------------------------------------------------------------
/docs/api/utils.md:
--------------------------------------------------------------------------------
1 | ::: pydts.utils.get_expanded_df
--------------------------------------------------------------------------------
/docs/icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomer1812/pydts/578929a457c111efe009d3461aab531b793b33d0/docs/icon.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | [](https://pypi.org/project/pydts/)
2 | [](https://github.com/tomer1812/pydts/actions?workflow=Tests)
3 | [](https://tomer1812.github.io/pydts)
4 | [](https://codecov.io/gh/tomer1812/pydts)
5 | [](https://doi.org/10.5281/zenodo.15296343)
6 |
7 | # Discrete Time Survival Analysis
8 | A Python package for discrete-time survival data analysis with competing risks.
9 |
10 | 
11 |
12 | [Tomer Meir](https://tomer1812.github.io/), [Rom Gutman](https://github.com/RomGutman), [Malka Gorfine](https://www.tau.ac.il/~gorfinem/) 2022
13 |
14 | ## Installation
15 | ```console
16 | pip install pydts
17 | ```
18 |
19 | ## Quick Start
20 |
21 | ```python
22 | from pydts.fitters import TwoStagesFitter
23 | from pydts.examples_utils.generate_simulations_data import generate_quick_start_df
24 |
25 | patients_df = generate_quick_start_df(n_patients=10000, n_cov=5, d_times=14, j_events=2, pid_col='pid', seed=0)
26 |
27 | fitter = TwoStagesFitter()
28 | fitter.fit(df=patients_df.drop(['C', 'T'], axis=1))
29 | fitter.print_summary()
30 | ```
31 |
32 | ## Examples
33 | 1. [Usage Example](https://tomer1812.github.io/pydts/UsageExample-Intro/)
34 | 2. [Hospital Length of Stay Simulation Example](https://tomer1812.github.io/pydts/SimulatedDataset/)
35 |
36 | ## Citation
37 | If you find PyDTS useful, please cite:
38 |
39 | ```bibtex
40 | @article{Meir_PyDTS_2022,
41 |   author = {Meir, Tomer and Gutman, Rom and Gorfine, Malka},
42 |   doi = {10.48550/arXiv.2204.05731},
43 |   title = {{PyDTS: A Python Package for Discrete Time Survival Analysis with Competing Risks}},
44 |   url = {https://arxiv.org/abs/2204.05731},
45 |   year = {2022}
46 | }
47 |
48 | @article{Meir_Gorfine_DTSP_2025,
49 |   author = {Meir, Tomer and Gorfine, Malka},
50 |   doi = {10.1093/biomtc/ujaf040},
51 |   title = {{Discrete-Time Competing-Risks Regression with or without Penalization}},
52 |   year = {2025},
53 |   journal = {Biometrics},
54 |   volume = {81},
55 |   number = {2},
56 |   url = {https://academic.oup.com/biometrics/article/81/2/ujaf040/8120014},
57 | }
58 | ```
59 |
60 | and please consider starring the project [on GitHub](https://github.com/tomer1812/pydts)
61 |
62 | ## How to Contribute
63 | 1. Open GitHub issues to suggest new features or to report bugs/errors
64 | 2. Contact Tomer or Rom if you want to add a usage example to the documentation
65 | 3. If you want to become a developer (thank you, we appreciate it!) - please contact Tomer or Rom for developer onboarding
66 |
67 | Tomer Meir: tomer1812@gmail.com, Rom Gutman: rom.gutman1@gmail.com
68 |
--------------------------------------------------------------------------------
/docs/intro.md:
--------------------------------------------------------------------------------
1 | # Introduction
2 |
3 | Based on
4 |
5 | "PyDTS: A Python Package for Discrete-Time Survival Analysis with Competing Risks"
6 |
7 | Tomer Meir\*, Rom Gutman\*, and Malka Gorfine (2022) [[1]](#1).
8 |
9 | and
10 |
11 | "Discrete-time Competing-Risks Regression with or without Penalization"
12 |
13 | Tomer Meir and Malka Gorfine (2025) [[2]](#2).
14 |
15 | ## Discrete-data survival analysis
16 | Discrete-data survival analysis refers to the case where data can only take values over a discrete grid. Sometimes, events can only occur at regular, discrete points in time. For example, in the United States a change in the party controlling the presidency can only occur quadrennially, in the month of January [[3]](#3). In other situations, events may occur at any point in time, but the available data record only the particular interval of time in which each event occurs. For example, death from cancer measured in months since the time of diagnosis [[4]](#4), or length of hospital stay recorded on a daily basis. It is well known that naively using standard continuous-time models (even after correcting for ties) with discrete-time data may result in biased estimators of the discrete-time model parameters.
17 |
18 | ## Competing events
19 |
20 | Competing events arise when individuals are susceptible to several types of events but can experience at most one event. For example, competing risks for hospital length of stay are discharge and in-hospital death; the occurrence of one of these events precludes observing the other event for the same patient. Another classical example of competing risks is cause-specific mortality, such as death from heart disease, death from cancer, and death from other causes [[5]](#5), [[6]](#6).
21 |
22 |
23 | PyDTS is an open source Python package which implements tools for discrete-time survival analysis with competing risks.
24 |
25 |
26 | ## References
27 | [1]
28 | Meir, Tomer\*, Gutman, Rom\*, and Gorfine, Malka,
29 | "PyDTS: A Python Package for Discrete-Time Survival Analysis with Competing Risks"
30 | (2022)
31 |
32 | [2]
33 | Meir, Tomer and Gorfine, Malka,
34 | "Discrete-time Competing-Risks Regression with or without Penalization", Biometrics (2025), doi: 10.1093/biomtc/ujaf040
35 |
36 | [3]
37 | Allison, Paul D.
38 | "Discrete-Time Methods for the Analysis of Event Histories"
39 | Sociological Methodology (1982),
40 | doi: 10.2307/270718
41 |
42 | [4]
43 | Lee, Minjung and Feuer, Eric J. and Fine, Jason P.
44 | "On the analysis of discrete time competing risks data"
45 | Biometrics (2018)
46 | doi: 10.1111/biom.12881
47 |
48 | [5]
49 | Kalbfleisch, John D. and Prentice, Ross L.
50 | "The Statistical Analysis of Failure Time Data" 2nd Ed.,
51 | Wiley (2011)
52 | ISBN: 978-1-118-03123-0
53 |
54 | [6]
55 | Klein, John P. and Moeschberger, Melvin L.
56 | "Survival Analysis",
57 | Springer (2003)
58 | ISBN: 978-0-387-95399-1
59 |
--------------------------------------------------------------------------------
/docs/methods.md:
--------------------------------------------------------------------------------
1 | # Methods
2 |
3 |
4 | ## Definitions
5 |
6 | We let $T$ denote a discrete event time that can take on only the values $\{1,2,...,d\}$ and $J$ denote the type of event, $J \in \{1,\ldots,M\}$. Consider a $p \times 1$ vector of baseline covariates $Z$. A general discrete cause-specific hazard function is of the form
7 | $$
8 | \lambda_j(t|Z) = \Pr(T=t,J=j|T\geq t, Z) \hspace{0.3cm} t \in \{1,2,...,d\} \hspace{0.3cm} j=1,\ldots,M \, .
9 | $$
10 | A popular semi-parametric model of the above hazard function based on a transformation regression model is of the form
11 | $$
12 | h(\lambda_{j}(t|Z)) = \alpha_{jt} +Z^T \beta_j \hspace{0.3cm} t \in \{1,2,...,d\} \hspace{0.3cm} j=1, \ldots,M
13 | $$
14 | such that $h$ is a known function [[2, and references therein]](#2). The total number of parameters in the model is $M(d+p)$. The logit function $h(a)=\log \{ a/(1-a) \}$ yields
15 | \begin{equation}\label{eq:logis}
16 | \lambda_j(t|Z)=\frac{\exp(\alpha_{jt}+Z^T\beta_j)}{1+\exp(\alpha_{jt}+Z^T\beta_j)} \, .
17 | \end{equation}
18 | It should be noted that leaving $\alpha_{jt}$ unspecified is analogous to having an unspecified baseline hazard function in the Cox proportional hazards model [[3]](#3), and thus we consider the above as a semi-parametric model.
19 |
20 |
21 | Let $S(t|Z) = \Pr(T>t|Z)$ be the overall survival function given $Z$. Then, the probability of experiencing an event of type $j$ at time $t$ equals
22 | $$
23 | \Pr(T=t,J=j|Z)=\lambda_j(t|Z) \prod_{k=1}^{t-1} \left\lbrace 1-
24 | \sum_{j'=1}^M\lambda_{j'}(k|Z) \right\rbrace
25 | $$
26 | and the cumulative incidence function (CIF) of cause $j$ is given by
27 | $$
28 | F_j(t|Z) = \Pr(T \leq t, J=j|Z) = \sum_{m=1}^{t} \lambda_j(m|Z) S(m-1|Z) = \sum_{m=1}^{t}\lambda_j(m|Z) \prod_{k=1}^{m-1} \left\lbrace 1-\sum_{j'=1}^M\lambda_{j'}(k|Z) \right\rbrace \, .
29 | $$
30 | Finally, the marginal probability of event type $j$ (marginally with respect to the time of event), given $Z$, equals
31 | $$
32 | \Pr(J=j|Z) = \sum_{m=1}^{d} \lambda_j(m|Z) \prod_{k=1}^{m-1} \left\lbrace 1-\sum_{j'=1}^M\lambda_{j'}(k|Z) \right\rbrace \, .
33 | $$
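
To make the mapping from hazards to CIFs concrete, here is a minimal numerical sketch (the $\alpha_{jt}$ and $\beta_j$ values are made up for illustration; this is plain `numpy`, not part of PyDTS):

```python
import numpy as np

# Illustrative parameters: M = 2 event types, d = 7 time points, p = 5 covariates.
d = 7
t_grid = np.arange(1, d + 1)
alpha = {1: -1.0 - 0.3 * np.log(t_grid), 2: -1.75 - 0.15 * np.log(t_grid)}
beta = {1: np.array([-0.2, 0.1, 0.3, -0.1, 0.2]),
        2: np.array([0.1, -0.3, 0.2, 0.1, -0.2])}
Z = np.array([0.5, 1.0, 0.0, 0.2, 0.8])  # a single covariate vector

# lambda_j(t|Z) with the logit link
lam = {j: 1.0 / (1.0 + np.exp(-(alpha[j] + Z @ beta[j]))) for j in (1, 2)}

# Overall survival S(t|Z) = prod_{k<=t} {1 - lambda_1(k|Z) - lambda_2(k|Z)}, with S(0|Z) = 1
S = np.concatenate([[1.0], np.cumprod(1.0 - lam[1] - lam[2])])

# CIFs: F_j(t|Z) = sum_{m<=t} lambda_j(m|Z) * S(m-1|Z)
F = {j: np.cumsum(lam[j] * S[:-1]) for j in (1, 2)}
```

As a sanity check for any implementation, $F_1(d|Z) + F_2(d|Z) + S(d|Z) = 1$ holds by construction.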
34 | In the next section we provide a fast estimation technique of the parameters $\{\alpha_{j1},\ldots,\alpha_{jd},\beta_j^T \, ; \, j=1,\ldots,M\}$.
35 |
36 |
37 | ## The Collapsed Log-Likelihood Approach and the Proposed Estimators
38 | For simplicity of presentation, we assume two competing events, i.e., $M=2$, and our goal is to estimate $\{\alpha_{11},\ldots,\alpha_{1d},\beta_1^T,\alpha_{21},\ldots,\alpha_{2d},\beta_2^T\}$ along with the standard errors of the estimators. The data at hand consist of $n$ independent observations, each with $(X_i,\delta_i,J_i,Z_i)$ where $X_i=\min(C_i,T_i)$, $C_i$ is a right-censoring time,
39 | $\delta_i=I(X_i=T_i)$ is the event indicator and $J_i\in\{0,1,2\}$, where $J_i=0$ if and only if $\delta_i=0$. Assume that given the covariates, the censoring and failure time are independent and non-informative. Then, the likelihood function is proportional to
40 | $$
41 | L = \prod_{i=1}^n \left\lbrace\frac{\lambda_1(X_i|Z_i)}{1-\lambda_1(X_i|Z_i)-\lambda_2(X_i|Z_i)}\right\rbrace^{I(\delta_{1i}=1)} \left\lbrace\frac{\lambda_2(X_i|Z_i)}{1-\lambda_1(X_i|Z_i)-\lambda_2(X_i|Z_i)}\right\rbrace^{I(\delta_{2i}=1)} \prod_{k=1}^{X_i}\lbrace 1-\lambda_1(k|Z_i)-\lambda_2(k|Z_i)\rbrace
42 | $$
43 | or, equivalently,
44 | $$
45 | L = \prod_{i=1}^n \left[ \prod_{j=1}^2 \prod_{m=1}^{X_i} \left\lbrace \frac{\lambda_j(m|Z_i)}{1-\lambda_1(m|Z_i)-\lambda_2(m|Z_i)}\right\rbrace^{\delta_{jim}}\right] \prod_{k=1}^{X_i}\lbrace 1-\lambda_1(k|Z_i)-\lambda_2(k|Z_i)\rbrace
46 | $$
47 | where $\delta_{jim}$ equals one if subject $i$ experienced event type $j$ at time $m$; and 0 otherwise. Clearly $L$ cannot be decomposed into separate likelihoods for each cause-specific
48 | hazard function $\lambda_j$.
49 | The log-likelihood becomes
50 | $$
51 | \log L = \sum_{i=1}^n \left[ \sum_{j=1}^2 \sum_{m=1}^{X_i} \delta_{jim} \left[ \log \lambda_j(m|Z_i) - \log \lbrace 1-\lambda_1(m|Z_i)-\lambda_2(m|Z_i)\rbrace \right] \right.
52 | \left. +\sum_{k=1}^{X_i}\log \lbrace 1-\lambda_1(k|Z_i)-\lambda_2(k|Z_i)\rbrace \right]
53 | $$
54 | $$
55 | = \sum_{i=1}^n \sum_{m=1}^{X_i} \left[ \delta_{1im} \log \lambda_1(m|Z_i)+\delta_{2im} \log \lambda_2(m|Z_i) + \lbrace 1-\delta_{1im}-\delta_{2im}\rbrace \log\lbrace 1-\lambda_1(m|Z_i)-\lambda_2(m|Z_i)\rbrace \right] \, .
56 | $$
57 |
58 | Instead of maximizing the $M(d+p)$ parameters simultaneously based on the above log-likelihood, the collapsed log-likelihood of Lee et al. [[4]](#4) can be adopted. Specifically, the data are expanded such that for each observation $i$ the expanded dataset includes $X_i$ rows, one row for each time $t$, $t \leq X_i$. At each time point $t$ the expanded data are conditionally multinomial with one of three possible outcomes $\{\delta_{1it},\delta_{2it},1-\delta_{1it}-\delta_{2it}\}$, as in [Table 1](#tbl:expanded).
59 |
60 | Table 1: Original and expanded datasets with $M = 2$ competing events [[Lee et al. (2018)]](#4).
61 |
62 |
63 | | $i$ | $X_i$ | $J_i$ | $Z_i$ | $i$ | $\tilde{X}_i$ | $\delta_{1it}$ | $\delta_{2it}$ | $1 - \delta_{1it} - \delta_{2it}$ | $Z_i$ |
64 | |-----|--------|--------------|-------|-----|----------------|----------------|----------------|-------------------------------|-------|
65 | | 1 | 2 | 1 | $Z_1$ | 1 | 1 | 0 | 0 | 1 | $Z_1$ |
66 | | | | | | 1 | 2 | 1 | 0 | 0 | $Z_1$ |
67 | | 2 | 3 | 2 | $Z_2$ | 2 | 1 | 0 | 0 | 1 | $Z_2$ |
68 | | | | | | 2 | 2 | 0 | 0 | 1 | $Z_2$ |
69 | | | | | | 2 | 3 | 0 | 1 | 0 | $Z_2$ |
70 | | 3 | 3 | 0 | $Z_3$ | 3 | 1 | 0 | 0 | 1 | $Z_3$ |
71 | | | | | | 3 | 2 | 0 | 0 | 1 | $Z_3$ |
72 | | | | | | 3 | 3 | 0 | 0 | 1 | $Z_3$ |
73 |
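
PyDTS implements this expansion in `pydts.utils.get_expanded_df` (see the API reference). A small sketch of its use; the keyword names below follow the package defaults that appear in `base_fitters.py`, while the exact names of the generated outcome columns may differ:

```python
import pandas as pd
from pydts.utils import get_expanded_df

df = pd.DataFrame({'pid': [1, 2, 3],        # observation id
                   'X':   [2, 3, 3],        # observed discrete time
                   'J':   [1, 2, 0],        # event type, 0 = censored
                   'Z1':  [0.1, 0.5, 0.9]}) # a baseline covariate

# One row per (observation, time t <= X_i), with binary outcome columns
# playing the roles of delta_{1it} and delta_{2it} of Table 1.
expanded_df = get_expanded_df(df=df, event_type_col='J', duration_col='X', pid_col='pid')
```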
74 |
75 |
76 |
77 | Then, for estimating $\{\alpha_{11},\ldots,\alpha_{1d},\beta_1^T\}$, we combine $\delta_{2it}$ and $1-\delta_{1it}-\delta_{2it}$, and the collapsed log-likelihood for cause $J=1$ based on a binary regression model with $\delta_{1it}$ as the outcome is given by
78 | $$
79 | \log L_1 = \sum_{i=1}^n \sum_{m=1}^{X_i}\left[ \delta_{1im} \log \lambda_1(m|Z_i)+(1-\delta_{1im})\log \lbrace 1-\lambda_1(m|Z_i)\rbrace \right] \, .
80 | $$
81 | Similarly, the collapsed log-likelihood for cause $J=2$ based on a binary regression model with $\delta_{2it}$ as the outcome becomes
82 | $$
83 | \log L_2 = \sum_{i=1}^n \sum_{m=1}^{X_i}\left[ \delta_{2im} \log \lambda_2(m|Z_i)+(1-\delta_{2im})\log \lbrace 1-\lambda_2(m|Z_i)\rbrace \right]
84 | $$
85 | and one can fit the two models separately.
86 |
87 | In general, for $M$ competing events,
88 | the estimators of $\{\alpha_{j1},\ldots,\alpha_{jd},\beta_j^T\}$, $j=1,\ldots,M$, are the respective values that maximize
89 | $$
90 | \log L_j = \sum_{i=1}^n \sum_{m=1}^{X_i}\left[ \delta_{jim} \log \lambda_j(m|Z_i)+(1-\delta_{jim})\log \{1-\lambda_j(m|Z_i)\} \right] \, .
91 | $$
92 | Namely, each maximization $j$, $j=1,\ldots,M$, consists of simultaneously maximizing over $d + p$ parameters.
93 |
94 | ### Proposed Estimators
95 | Alternatively, we propose the following simpler and faster estimation procedure, with a negligible efficiency loss, if any. Our idea exploits the close relationship between conditional logistic regression analysis and stratified Cox regression analysis [[5]](#5). We propose to estimate each $\beta_j$ separately, and given $\beta_j$, $\alpha_{jt}$, $t=1\ldots,d$, are separately estimated. In particular, the proposed estimating procedure consists of the following two speedy steps:
96 | #### Step 1.
97 | Use the expanded dataset and estimate each vector $\beta_j$, $j \in \{1,\ldots, M\}$, by a simple conditional logistic regression, conditioning on the event time $X$, using a stratified Cox analysis.
98 |
99 | #### Step 2.
100 | Given the estimators $\widehat{\beta}_j$ , $j \in \{1,\ldots, M\}$, of Step 1, use the original (un-expanded) data and estimate each $\alpha_{jt}$, $j \in \{1,\ldots,M\}$, $t=1,\ldots,d$, separately, by
101 |
102 | $$\widehat{\alpha}_{jt} = \operatorname{argmin}_{a} \left\lbrace \frac{1}{y_t} \sum_{i=1}^n I(X_i \geq t) \frac{ \exp(a+Z_i^T \widehat{\beta}_j)}{1 + \exp(a + Z_i^T \widehat{\beta}_j)} - \frac{n_{tj}}{y_t} \right\rbrace ^2 $$
103 |
104 | where $y_t=\sum_{i=1}^n I(X_i \geq t)$ and $n_{tj}=\sum_{i=1}^n I(X_i = t, J_i=j)$.
105 |
106 | The above equation consists of minimizing the squared distance between the observed proportion of failures of type $j$ at time $t$
107 | ($n_{tj}/y_t$) and the expected proportion of failures under the model defined above for $\lambda_j$, evaluated at $\widehat{\beta}_j$.
108 | The simulation results of the Simple Simulation section reveal that the above two-step procedure performs well in terms of bias, provides standard errors similar to those of [[4]](#4), and improves the computation time by a factor of 1.5-3.5, depending on $d$. Standard errors of $\widehat{\beta}_j$, $j \in \{1,\ldots,M\}$, can be derived directly from the stratified Cox analysis.
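
In code, Step 2 reduces to a collection of independent one-dimensional minimizations. A minimal sketch of a single one (illustrative only, not the PyDTS internals):

```python
import numpy as np
from scipy.optimize import minimize_scalar

def estimate_alpha_jt(X, J, Z, beta_j_hat, j, t):
    """Estimate alpha_{jt} for one (j, t), given beta_j_hat from Step 1.

    X: (n,) observed discrete times; J: (n,) event types (0 = censored);
    Z: (n, p) baseline covariates.
    """
    at_risk = X >= t                      # I(X_i >= t)
    y_t = at_risk.sum()                   # y_t
    n_tj = ((X == t) & (J == j)).sum()    # n_{tj}

    def objective(a):
        lin = a + Z[at_risk] @ beta_j_hat
        expected = np.mean(np.exp(lin) / (1.0 + np.exp(lin)))  # expected proportion
        return (expected - n_tj / y_t) ** 2                    # squared distance

    return minimize_scalar(objective).x
```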
109 |
110 | ### Time-dependent covariates
111 |
112 | Similarly to the continuous-time Cox model, the simplest way to code time-dependent covariates uses intervals of time [[Therneau et al. (2000)]](#6). Then, the data are encoded by breaking each individual's time into multiple intervals, with one row of data per interval. Hence, combining this step with the expansion demonstrated in [Table 1](#tbl:expanded) is straightforward.
113 |
114 | ### Regularized regression models
115 |
116 | Penalized regression methods, such as lasso, adaptive lasso, and elastic net [[Hastie et al. 2009]](#7), place a constraint on the size of the regression coefficients. The estimation procedure of [[Meir and Gorfine (2025)]](#8) that separates the estimation of $\beta_j$ and $\alpha_{jt}$ can easily incorporate such constraints in Lagrangian form by minimizing
117 |
118 | $$
119 | -\log L_j^c(\beta_j) + \eta_j P(\beta_j) \, , \quad j=1,\ldots,M \, ,
120 | $$
121 |
122 | where $P$ is a penalty function and $\eta_j \geq 0$ are shrinkage tuning parameters. The parameters $\alpha_{jt}$ are estimated once the regularization step is completed and the estimates $\widehat{\beta}_j$ are obtained.
123 |
124 | Clearly, any regularized Cox regression model routines can be used for estimating $\beta_j$, $j=1,\ldots,M$, based on the above equation, for example, the `CoxPHFitter` of the `lifelines` Python package [[Davidson-Pilon (2019)]](#9) with penalization.
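
In PyDTS, such penalties are passed through the `fit_beta_kwargs` argument of `TwoStagesFitter.fit`, which forwards `penalizer` ($\eta_j$) and `l1_ratio` (selecting $P$: 1 = lasso, 0 = ridge, in between = elastic net) to the `lifelines` backend, as in the `PaperCodeFigure.ipynb` notebook of this repository (assuming a `train_df` prepared as in the Quick Start):

```python
from pydts.fitters import TwoStagesFitter

L1_regularized_fitter = TwoStagesFitter()
fit_beta_kwargs = {'model_kwargs': {'penalizer': 0.003, 'l1_ratio': 1}}  # lasso
L1_regularized_fitter.fit(df=train_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)
```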
125 |
126 |
127 | ### Sure Independence Screening
128 |
129 | When the number of available covariates greatly exceeds the number of observations (as is common in genetic datasets, for example), i.e., the ultrahigh-dimensional setting, most regularized methods suffer from the curse of dimensionality, high variance, and overfitting [[Hastie et al. (2009)]](#7).
130 | Sure Independence Screening (SIS) is a marginal screening technique designed to filter out uninformative covariates.
131 | Penalized regression methods can then be applied to the covariates that remain after the marginal screening process.
132 |
133 | We start the SIS procedure by ranking all the covariates using a utility measure between the response and each covariate, and then retain only covariates whose estimated coefficients exceed a threshold value.
134 | We focus on SIS and SIS followed by lasso (SIS-L) [[Fan et al. (2010)]](#10); [[Saldana and Feng (2018)]](#11) within the proposed two-step procedure.
135 |
136 | We start by fitting a marginal regression for each covariate by maximizing:
137 |
138 | $$
139 | L_j^c(\beta_{jr}) \quad \text{for } j=1,\ldots,M, \quad r=1,\ldots,p
140 | $$
141 |
142 | where $\boldsymbol{\beta}_j = (\beta_{j1},\ldots,\beta_{jp})^T$.
143 | Then we rank the features based on the magnitude of their marginal regression coefficients.
144 | The selected sets of variables are given by:
145 |
146 | $$
147 | \widehat{\mathcal{M}}_{j,w_n} = \left\{1 \leq k \leq p \, : \, |\widehat{\beta}_{jk}| \geq w_n \right\}, \quad j=1,\ldots,M,
148 | $$
149 |
150 | where $w_n$ is a threshold value.
151 | We adopt the data-driven threshold of [[Saldana and Feng (2018)]](#11).
152 | Given data of the form $\{X_i, \delta_i, J_i, \mathbf{Z}_i \, ; \, i = 1, \ldots, n\}$, a random permutation $\pi$ of $\{1,\ldots,n\}$ is used to decouple $\mathbf{Z}_i$ and $(X_i, \delta_i, J_i)$, so that the permuted data $\{X_i, \delta_i, J_i, \mathbf{Z}_{\pi(i)}\}$ follow a model where the covariates have no predictive power over the survival time of any event type.
153 |
154 | For the permuted data, we re-estimate individual regression coefficients and obtain $\widehat{\beta}^*_{jr}$. The data-driven threshold is defined as:
155 |
156 | $$
157 | w_n = \max_{1 \leq j \leq M, \, 1 \leq k \leq p} |\widehat{\beta}^*_{jk}|.
158 | $$
159 |
160 | For the SIS-L procedure, lasso regularization is then applied in the first step of the two-step procedure to the set of covariates selected by SIS.
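
PyDTS implements the full procedure in `pydts.screening.SISTwoStagesFitter` (see the API reference). The threshold logic itself is simple; here is a schematic sketch, with placeholder coefficient matrices standing in for the marginal collapsed-likelihood estimates described above:

```python
import numpy as np

M, p = 2, 1000
rng = np.random.default_rng(0)

# Placeholders for illustration: beta_hat[j-1, r] estimated from the original
# data, beta_star[j-1, r] from the permuted data {X_i, delta_i, J_i, Z_pi(i)}.
beta_hat = rng.normal(scale=0.5, size=(M, p))
beta_star = rng.normal(scale=0.05, size=(M, p))

w_n = np.abs(beta_star).max()  # w_n = max_{j,k} |beta*_{jk}|
selected = {j: np.flatnonzero(np.abs(beta_hat[j - 1]) >= w_n) for j in range(1, M + 1)}
```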
161 |
162 |
163 |
164 | ## References
165 | [1]
166 | Meir, Tomer\*, Gutman, Rom\*, and Gorfine, Malka
167 | "PyDTS: A Python Package for Discrete Time Survival-analysis with Competing Risks"
168 | (2022)
169 |
170 | [2]
171 | Allison, Paul D.
172 | "Discrete-Time Methods for the Analysis of Event Histories"
173 | Sociological Methodology (1982),
174 | doi: 10.2307/270718
175 |
176 | [3]
177 | Cox, D. R.
178 | "Regression Models and Life-Tables"
179 | Journal of the Royal Statistical Society: Series B (Methodological) (1972)
180 | doi: 10.1111/j.2517-6161.1972.tb00899.x
181 |
182 | [4]
183 | Lee, Minjung and Feuer, Eric J. and Fine, Jason P.
184 | "On the analysis of discrete time competing risks data"
185 | Biometrics (2018)
186 | doi: 10.1111/biom.12881
187 |
188 | [5]
189 | Prentice, Ross L and Breslow, Norman E
190 | "Retrospective studies and failure time models"
191 | Biometrika (1978)
192 | doi: 10.1093/biomet/65.1.153
193 |
194 | [6]
195 | Therneau, Terry M and Grambsch, Patricia M,
196 | "Modeling Survival Data: Extending the Cox Model", Springer-Verlag,
197 | (2000)
198 |
199 | [7]
200 | Hastie, Trevor and Tibshirani, Robert and Friedman, Jerome H,
201 | "The Elements of Statistical Learning: Data Mining, Inference, and Prediction.", Springer-Verlag,
202 | (2009)
203 |
204 | [8]
205 | Meir, Tomer and Gorfine, Malka,
206 | "Discrete-time Competing-Risks Regression with or without Penalization", Biometrics (2025), doi: 10.1093/biomtc/ujaf040
207 |
208 |
209 | [9]
210 | Davidson-Pilon, Cameron,
211 | "lifelines: Survival Analysis in Python", Journal of Open Source Software,
212 | (2019)
213 |
214 | [10]
215 | Fan, J and Feng, Y and Wu, Y,
216 | "High-dimensional variable selection for Cox’s proportional hazards model",
217 | Institute of Mathematical Statistics,
218 | (2010)
219 |
220 | [11]
221 | Saldana, DF and Feng, Y,
222 | "SIS: An R package for sure independence screening in ultrahigh-dimensional statistical models",
223 | Journal of Statistical Software,
224 | (2018)
225 |
--------------------------------------------------------------------------------
/docs/methodsevaluation.md:
--------------------------------------------------------------------------------
1 | # Evaluation Measures
2 | Let
3 |
4 | $$
5 | \pi_{ij}(t) = \widehat{\Pr}(T_i=t, J_i=j \mid Z_i) = \widehat{\lambda}_j (t \mid Z_i) \widehat{S}(t-1 \mid Z_i)
6 | $$
7 |
8 | and
9 |
10 | $$
11 | D_{ij} (t) = I(T_i = t, J_i = j)
12 | $$
13 |
14 | The cause-specific incidence/dynamic area under the receiver operating characteristics curve (AUC) is defined and estimated in the spirit of Heagerty and Zheng (2005) and Blanche et al. (2015) as the probability of a random observation with observed event $j$ at time $t$ having a higher risk prediction for cause $j$ than a randomly selected observation $m$, at risk at time $t$, without the observed event $j$ at time $t$. Namely,
15 |
16 | $$
17 | \mbox{AUC}_j(t) = \Pr (\pi_{ij}(t) > \pi_{mj}(t) \mid D_{ij} (t) = 1, D_{mj} (t) = 0, T_m \geq t)
18 | $$
19 |
20 | In the presence of censored data and under the assumption that the censoring is independent of the failure time and observed covariates, an inverse probability censoring weighting (IPCW) estimator of $\mbox{AUC}_j(t)$ becomes
21 |
22 | $$
23 | \widehat{\mbox{AUC}}_j (t) = \frac{\sum_{i=1}^{n}\sum_{m=1}^{n} D_{ij}(t)(1-D_{mj}(t))I(X_m \geq t) W_{ij}(t) W_{mj}(t) \{I(\pi_{ij}(t) > \pi_{mj}(t))+0.5I(\pi_{ij}(t)=\pi_{mj}(t))\}}{\sum_{i=1}^{n}\sum_{m=1}^{n} D_{ij}(t)(1-D_{mj}(t))I(X_m \geq t) W_{ij}(t) W_{mj}(t)}
24 | $$
25 |
26 | which can be simplified to:
27 |
28 | $$
29 | \widehat{\mbox{AUC}}_j (t) = \frac{\sum_{i=1}^{n}\sum_{m=1}^{n} D_{ij}(t)(1-D_{mj}(t))I(X_m \geq t) \{I(\pi_{ij}(t) > \pi_{mj}(t))+0.5I(\pi_{ij}(t)=\pi_{mj}(t))\}}{\sum_{i=1}^{n}\sum_{m=1}^{n} D_{ij}(t)(1-D_{mj}(t))I(X_m \geq t)}
30 | $$
31 |
32 | where
33 |
34 | $$
35 | W_{ij}(t) = \frac{D_{ij}(t)}{\widehat{G}_C(T_i)} + I(X_i \geq t)\frac{1-D_{ij}(t)}{\widehat{G}_C(t)} = \frac{D_{ij}(t)}{\widehat{G}_C(t)} + I(X_i \geq t)\frac{1-D_{ij}(t)}{\widehat{G}_C(t)} = I(X_i \geq t) / \widehat{G}_C(t)
36 | $$
37 |
38 | and $\widehat{G}_C(\cdot)$ is the estimated survival function of the censoring distribution (e.g., the Kaplan-Meier estimator); the second equality uses the fact that $D_{ij}(t)=1$ implies $T_i=t$. Interestingly, the IPCWs cancel out, so they have no effect on $\widehat{\mbox{AUC}}_j (t)$.
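
A `numpy` sketch of the simplified estimator above (illustrative only; PyDTS provides these measures in `pydts.evaluation`):

```python
import numpy as np

def auc_jt(pi_t, X, J, j, t):
    """Estimate AUC_j(t); pi_t[i] = pi_{ij}(t), X = observed times, J = event types."""
    cases = (X == t) & (J == j)       # D_{ij}(t) = 1
    controls = (X >= t) & ~cases      # at risk at t, without a type-j event at t
    if not cases.any() or not controls.any():
        return np.nan
    diff = pi_t[cases][:, None] - pi_t[controls][None, :]  # all case-control pairs
    return (diff > 0).mean() + 0.5 * (diff == 0).mean()
```

The integrated and global versions are weighted averages of these quantities, with the weights $w_j(t)$ and $v_j$ defined below.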
39 |
40 | An integrated cause-specific AUC can be estimated as a weighted sum by
41 |
42 | $$
43 | \widehat{\mbox{AUC}}_j = \sum_t \widehat{\mbox{AUC}}_j (t) w_j (t)
44 | $$
45 |
46 | and we adopt a simple data-driven weight function of the form
47 |
48 | $$
49 | w_j(t) = \frac{N_j(t)}{\sum_t N_j(t)}
50 | $$
51 | where $N_j(t)$ denotes the number of observed events of type $j$ at time $t$.

52 | A global AUC can be defined as
53 |
54 | $$
55 | \widehat{\mbox{AUC}} = \sum_j \widehat{\mbox{AUC}}_j v_j
56 | $$
57 |
58 | where
59 |
60 | $$
61 | v_j = \frac{\sum_{t} N_j(t)}{ \sum_{j=1}^M \sum_{t} N_j(t) }
62 | $$
63 |
64 | Another well-known performance measure is the Brier Score (BS). In the spirit of Blanche et al. (2015) we define
65 |
66 | $$
67 | \widehat{\mbox{BS}}_{j}(t) = \frac{1}{Y_{\cdot}(t)} {\sum_{i=1}^n W_{ij}(t) \left( D_{ij}(t) - \pi_{ij}(t)\right)^2} \, .
68 | $$
69 | where $Y_{\cdot}(t) = \sum_{i=1}^{n} I(X_i \geq t)$ is the number of observations at risk at time $t$.

70 | An integrated cause-specific BS can be estimated by the weighted sum
71 |
72 | $$
73 | \widehat{\mbox{BS}}_{j} = \sum_t \widehat{\mbox{BS}}_{j}(t) w_j(t)
74 | $$
75 |
76 | and an estimated global BS is given by
77 |
78 | $$
79 | \widehat{\mbox{BS}} = \sum_j \widehat{\mbox{BS}}_{j} v_j \, .
80 | $$
81 |
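
A matching sketch for $\widehat{\mbox{BS}}_j(t)$, taking $Y_{\cdot}(t)$ as the at-risk count and a scalar Kaplan-Meier estimate of $\widehat{G}_C(t)$ as given (again illustrative, not the `pydts.evaluation` implementation):

```python
import numpy as np

def brier_jt(pi_t, X, J, G_c_t, j, t):
    """Estimate BS_j(t); G_c_t is the estimated censoring survival at time t."""
    at_risk = X >= t
    D = ((X == t) & (J == j)).astype(float)  # D_{ij}(t)
    W = at_risk / G_c_t                      # W_{ij}(t) = I(X_i >= t) / G_C(t)
    return float((W * (D - pi_t) ** 2).sum() / at_risk.sum())
```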
--------------------------------------------------------------------------------
/docs/methodsintro.md:
--------------------------------------------------------------------------------
1 | # Methods
2 |
3 | In this section, we outline the statistical background for the tools implemented in PyDTS. We begin with definitions, present the collapsed log-likelihood approach and the estimation procedure of Lee et al. (2018) [[4]](#4), introduce our estimation method [[1]](#1)-[[2]](#2), and conclude with evaluation metrics. For additional details, see the references.
4 |
5 |
6 | ## References
7 | [1]
8 | Meir, Tomer\*, Gutman, Rom\*, and Gorfine, Malka,
9 | "PyDTS: A Python Package for Discrete-Time Survival Analysis with Competing Risks"
10 | (2022)
11 |
12 | [2]
13 | Meir, Tomer and Gorfine, Malka,
14 | "Discrete-time Competing-Risks Regression with or without Penalization", Biometrics (2025), doi: 10.1093/biomtc/ujaf040
15 |
16 |
17 | [3]
18 | Allison, Paul D.
19 | "Discrete-Time Methods for the Analysis of Event Histories"
20 | Sociological Methodology (1982),
21 | doi: 10.2307/270718
22 |
23 | [4]
24 | Lee, Minjung and Feuer, Eric J. and Fine, Jason P.
25 | "On the analysis of discrete time competing risks data"
26 | Biometrics (2018)
27 | doi: 10.1111/biom.12881
28 |
29 | [5]
30 | Kalbfleisch, John D. and Prentice, Ross L.
31 | "The Statistical Analysis of Failure Time Data" 2nd Ed.,
32 | Wiley (2011)
33 | ISBN: 978-1-118-03123-0
34 |
35 | [6]
36 | Klein, John P. and Moeschberger, Melvin L.
37 | "Survival Analysis",
38 | Springer-Verlag (2003)
39 | ISBN: 978-0-387-95399-1
40 |
--------------------------------------------------------------------------------
/docs/models_params.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomer1812/pydts/578929a457c111efe009d3461aab531b793b33d0/docs/models_params.png
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | lifelines==0.26.4
2 | matplotlib==3.5.1
3 | numpy==1.23.4
4 | pandarallel==1.5.7
5 | pandas==1.4.1
6 | patsy==0.5.2
7 | pillow==9.0.1
8 | psutil==5.9.4
9 | scikit-learn==1.0.2
10 | scipy==1.8.0
11 | statsmodels==0.13.2
12 | threadpoolctl==3.1.0
13 | tqdm==4.63.0
14 | seaborn==0.11.2
15 | tableone==0.7.10
16 | pydts==0.7.8
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | theme:
2 |   name: material
3 |   features:
4 |     - navigation.sections # Sections are included in the navigation on the left.
5 |     - toc.integrate # Table of contents is integrated on the left; does not appear separately on the right.
6 |     - header.autohide # header disappears as you scroll
7 |
8 | site_name: PyDTS
9 | site_description: The documentation for the PyDTS software library.
10 | site_author: Tomer Meir, Rom Gutman
11 |
12 | repo_url: https://github.com/tomer1812/pydts/
13 | repo_name: tomer1812/pydts
14 | edit_uri: "" # No edit button, as some of our pages are in /docs and some in /examples_utils via symlink, so it's impossible for them all to be accurate
15 |
16 | strict: false # Allow warnings during the build process
17 |
39 | markdown_extensions:
40 |   - pymdownx.highlight:
41 |       anchor_linenums: true
42 |   - pymdownx.inlinehilite
43 |   - pymdownx.arithmatex: # Render LaTeX via MathJax
44 |       generic: true
45 |   - pymdownx.superfences # Enables syntax highlighting when used with the Material theme.
46 |   - pymdownx.details # Allows hidden expandable regions denoted by ???
47 |   - pymdownx.snippets: # Include one Markdown file into another
48 |       base_path: docs
49 |   - admonition
50 |   - toc:
51 |       permalink: "¤" # Adds a clickable permalink to each section heading
52 |       toc_depth: 4 # Prevents h5, h6 from showing up in the TOC.
53 |
54 | extra_javascript:
55 |   - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
56 |
57 | nav:
58 |   - Home: 'index.md'
59 |   - Introduction: intro.md
60 |   - Methods:
61 |       - Introduction: methodsintro.md
62 |       - Definitions and Estimation: methods.md
63 |       - Evaluation Metrics: methodsevaluation.md
64 |   - Examples:
65 |       - Event Times Sampler: EventTimesSampler.ipynb
66 |       - Estimation Example:
67 |           - Introduction: UsageExample-Intro.ipynb
68 |           - Data Preparation: UsageExample-DataPreparation.ipynb
69 |           - Estimation with TwoStagesFitter: UsageExample-FittingTwoStagesFitter.ipynb
70 |           - Estimation with DataExpansionFitter: UsageExample-FittingDataExpansionFitter.ipynb
71 |       - Data Regrouping Example: UsageExample-RegroupingData.ipynb
72 |       - Comparing the Estimation Methods: ModelsComparison.ipynb
73 |       - Evaluation: PerformanceMeasures.ipynb
74 |       - Regularization: Regularization.ipynb
75 |       - Small Sample Size Example: UsageExample-FittingTwoStagesFitterExact-FULL.ipynb
76 |       - Screening Example: UsageExample-SIS-SIS-L.ipynb
77 |       - Hospitalization LOS Simulation: SimulatedDataset.ipynb
78 |   - API:
79 |       - The Two Stages Procedure of Meir and Gorfine (2023) - Efron: 'api/two_stages_fitter.md'
80 |       - The Two Stages Procedure of Meir and Gorfine (2023) - Exact: 'api/two_stages_fitter_exact.md'
81 |       - Data Expansion Procedure of Lee et al. (2018): 'api/data_expansion_fitter.md'
82 |       - Event Times Sampler: 'api/event_times_sampler.md'
83 |       - Evaluation: 'api/evaluation.md'
84 |       - Cross Validation: 'api/cross_validation.md'
85 |       - Model Selection: 'api/model_selection.md'
86 |       - Sure Independence Screening: 'api/screening.md'
87 |       - Utils: 'api/utils.md'
88 |
89 | plugins:
90 |   - mknotebooks
91 |   - search
92 |   # - mkdocs-jupyter
93 |   - mkdocstrings:
94 |       handlers:
95 |         python:
96 |           options:
97 |             members: true
98 |             show_inheritance: true
99 |             inherited_members: true
100 |             show_root_heading: true
101 |             show_source: true
102 |             merge_init_into_class: true
103 |             docstring_style: google
104 |
105 | extra:
106 |   copyright: Copyright © 2022 Tomer Meir, Rom Gutman, Malka Gorfine
107 |   analytics:
108 |     provider: google
109 |     property: G-Z0XYP3868P
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "pydts"
3 | version = "0.9.7"
4 | description = "Discrete time survival analysis with competing risks"
5 | authors = ["Tomer Meir ", "Rom Gutman ", "Malka Gorfine "]
6 | license = "GNU GPLv3"
7 | readme = "README.md"
8 | homepage = "https://github.com/tomer1812/pydts"
9 | repository = "https://github.com/tomer1812/pydts"
10 | keywords = ["Discrete Time", "Time to Event" ,"Survival Analysis", "Competing Events"]
11 | documentation = "https://tomer1812.github.io/pydts"
12 |
13 | [tool.poetry.dependencies]
14 | python = ">=3.9,<3.11"
15 | pandas = "^1.4.1"
16 | mkdocs = "^1.4.3"
17 | mkdocs-material = "^9.0.0"
18 | #mknotebooks = "^0.7.1"
19 | lifelines = "^0.26.4"
20 | scipy = "^1.8.0"
21 | scikit-learn = "^1.0.2"
22 | tqdm = "^4.63.0"
23 | statsmodels = "^0.13.2"
24 | pandarallel = "^1.5.7"
25 | ipython = "^8.2.0"
26 | numpy = ">=1.23.4, <2.0.0"
27 | psutil = "^5.9.4"
28 | setuptools = "^68.0.0"
29 | seaborn = "^0.12.2"
30 | mkdocstrings = "^0.28"
31 | mkdocstrings-python = "^1.16"
32 | mknotebooks = "^0.8.0"
33 |
34 | [tool.poetry.dev-dependencies]
35 | pytest = "^7.0.1"
36 | coverage = {extras = ["toml"], version = "^6.3.2"}
37 | pytest-cov = "^3.0.0"
38 | mkdocs = "^1.2.3"
39 | mkdocs-material = "^9.0.0"
40 | #mknotebooks = "^0.7.1"
41 | jupyter = "^1.0.0"
42 | scikit-survival = "^0.17.1"
43 | tableone = "^0.7.10"
44 |
45 | [build-system]
46 | requires = ["poetry-core>=1.0.0"]
47 | build-backend = "poetry.core.masonry.api"
48 |
49 | [tool.coverage.paths]
50 | source = ["src", "*/site-packages"]
51 |
52 | [tool.coverage.run]
53 | branch = true
54 | source = ["pydts"]
55 |
56 | [tool.coverage.report]
57 | show_missing = true
58 |
--------------------------------------------------------------------------------
/src/pydts/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.0"
--------------------------------------------------------------------------------
/src/pydts/base_fitters.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, Tuple, Union
2 |
3 | import pandas as pd
4 | import numpy as np
5 | from .utils import get_expanded_df
6 |
7 |
8 |
9 | class BaseFitter:
10 | """
11 |     This class implements the basic fitter methods and attributes API
12 | """
13 |
14 | def __init__(self):
15 | self.event_models = {}
16 | self.expanded_df = pd.DataFrame()
17 | self.event_type_col = None
18 | self.duration_col = None
19 | self.pid_col = None
20 | self.events = None
21 | self.covariates = None
22 | self.formula = None
23 | self.times = None
24 |
25 | def fit(self, df: pd.DataFrame, event_type_col: str = 'J', duration_col: str = 'X', pid_col: str = 'pid',
26 | **kwargs) -> dict:
27 | raise NotImplementedError
28 |
29 | def predict(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame:
30 | raise NotImplementedError
31 |
32 | def evaluate(self, test_df: pd.DataFrame, oracle_col: str = 'T', **kwargs) -> float:
33 | raise NotImplementedError
34 |
35 | def print_summary(self, **kwargs) -> None:
36 | raise NotImplementedError
37 |
38 | def _validate_t(self, t, return_iter=True):
39 | _t = np.array([t]) if not isinstance(t, Iterable) else t
40 | t_i_not_fitted = [t_i for t_i in _t if (t_i not in self.times)]
41 | assert len(t_i_not_fitted) == 0, \
42 | f"Cannot predict for times which were not included during .fit(): {t_i_not_fitted}"
43 | if return_iter:
44 | return _t
45 | return t
46 |
47 | def _validate_covariates_in_df(self, df):
48 | cov_not_fitted = []
49 | if isinstance(self.covariates, list):
50 | cov_not_fitted = [cov for cov in self.covariates if cov not in df.columns]
51 | elif isinstance(self.covariates, dict):
52 | for event in self.events:
53 | event_cov_not_fitted = [cov for cov in self.covariates[event] if cov not in df.columns]
54 | cov_not_fitted.extend(event_cov_not_fitted)
55 | assert len(cov_not_fitted) == 0, \
56 | f"Cannot predict - required covariates are missing from df: {cov_not_fitted}"
57 |
58 | def _validate_cols(self, df, event_type_col, duration_col, pid_col):
59 | assert event_type_col in df.columns, f'Event type column is missing from df: {event_type_col}'
60 | assert duration_col in df.columns, f'Duration column is missing from df: {duration_col}'
61 | assert pid_col in df.columns, f'Observation ID column is missing from df: {pid_col}'
62 |
63 |
64 | class ExpansionBasedFitter(BaseFitter):
65 | """
66 |     This class implements the data expansion method which is common to the existing fitters
67 | """
68 |
69 | def _expand_data(self,
70 | df: pd.DataFrame,
71 | event_type_col: str,
72 | duration_col: str,
73 | pid_col: str) -> pd.DataFrame:
74 | """
75 | This method expands the raw data as explained in Lee et al. 2018
76 |
77 | Args:
78 | df (pandas.DataFrame): Dataframe to expand.
79 | event_type_col (str): The event type column name (must be a column in df),
80 | Right censored sample (i) is indicated by event value 0, df.loc[i, event_type_col] = 0.
81 | duration_col (str): Last follow up time column name (must be a column in df).
82 | pid_col (str): Sample ID column name (must be a column in df).
83 |
84 | Returns:
85 | Expanded df (pandas.DataFrame): the expanded dataframe.
86 | """
87 | self._validate_cols(df, event_type_col, duration_col, pid_col)
88 | return get_expanded_df(df=df, event_type_col=event_type_col, duration_col=duration_col, pid_col=pid_col)
89 |
90 | def predict_hazard_jt(self, df: pd.DataFrame, event: Union[str, int], t: Union[Iterable, int]) -> pd.DataFrame:
91 | """
92 | This method calculates the hazard for the given event at the given time values if they were included in
93 | the training set of the event.
94 |
95 | Args:
96 | df (pd.DataFrame): samples to predict for
97 | event (Union[str, int]): event name
98 | t (Union[Iterable, int]): times to calculate the hazard for
99 |
100 | Returns:
101 | df (pd.DataFrame): samples with the prediction columns
102 | """
103 | raise NotImplementedError
104 |
105 |     def predict_hazard_t(self, df: pd.DataFrame, t: Union[int, np.ndarray]) -> pd.DataFrame:
106 | """
107 | This function calculates the hazard for all the events at the requested time values if they were included in
108 | the training set of each event.
109 |
110 | Args:
111 | df (pd.DataFrame): samples to predict for
112 |             t (int, np.ndarray): times to calculate the hazard for
113 |
114 | Returns:
115 | df (pd.DataFrame): samples with the prediction columns
116 | """
117 | t = self._validate_t(t)
118 | self._validate_covariates_in_df(df.head())
119 |
120 | for event, model in self.event_models.items():
121 | df = self.predict_hazard_jt(df=df, event=event, t=t)
122 | return df
123 |
124 | def predict_hazard_all(self, df: pd.DataFrame) -> pd.DataFrame:
125 | """
126 | This function calculates the hazard for all the events at all time values included in the training set for each
127 | event.
128 |
129 | Args:
130 | df (pd.DataFrame): samples to predict for
131 |
132 | Returns:
133 | df (pd.DataFrame): samples with the prediction columns
134 |
135 | """
136 | self._validate_covariates_in_df(df.head())
137 | df = self.predict_hazard_t(df, t=self.times[:-1])
138 | return df
139 |
140 | def predict_overall_survival(self,
141 | df: pd.DataFrame,
142 | t: int = None,
143 | return_hazards: bool = False) -> pd.DataFrame:
144 | """
145 | This function adds columns of the overall survival until time t.
146 | Args:
147 | df (pandas.DataFrame): dataframe with covariates columns
148 | t (int): time
149 |             return_hazards (bool): whether to keep the hazard columns
150 |
151 | Returns:
152 | df (pandas.DataFrame): dataframe with the additional overall survival columns
153 |
154 | """
155 | if t is not None:
156 | self._validate_t(t, return_iter=False)
157 | self._validate_covariates_in_df(df.head())
158 |
159 | all_hazards = self.predict_hazard_all(df)
160 | _times = self.times[:-1] if t is None else [_t for _t in self.times[:-1] if _t <= t]
161 | overall = pd.DataFrame()
162 | for t_i in _times:
163 | cols = [f'hazard_j{e}_t{t_i}' for e in self.events]
164 | t_i_hazard = 1 - all_hazards[cols].sum(axis=1)
165 | t_i_hazard.name = f'overall_survival_t{t_i}'
166 | overall = pd.concat([overall, t_i_hazard], axis=1)
167 | overall = pd.concat([df, overall.cumprod(axis=1)], axis=1)
168 |
169 | if return_hazards:
170 | cols = all_hazards.columns[all_hazards.columns.str.startswith("hazard_")]
171 | cols = cols.difference(overall.columns)
172 | if len(cols) > 0:
173 | overall = pd.concat([overall, all_hazards[cols]], axis=1)
174 | return overall
175 |
176 | def predict_prob_event_j_at_t(self, df: pd.DataFrame, event: Union[str, int], t: int) -> pd.DataFrame:
177 | """
178 |         This function adds a column with the probability of occurrence of a specific event at a specific time.
179 |
180 | Args:
181 | df (pandas.DataFrame): dataframe with covariates columns
182 | event (Union[str, int]): event name
183 | t (int): time
184 |
185 | Returns:
186 |             df (pandas.DataFrame): dataframe with an additional probability column
187 |
188 | """
189 | assert event in self.events, \
190 | f"Cannot predict for event {event} - it was not included during .fit()"
191 | self._validate_t(t, return_iter=False)
192 | self._validate_covariates_in_df(df.head())
193 |
194 | if f'prob_j{event}_at_t{t}' not in df.columns:
195 | if t == 1:
196 | if f'hazard_j{event}_t{t}' not in df.columns:
197 | df = self.predict_hazard_jt(df=df, event=event, t=t)
198 | df[f'prob_j{event}_at_t{t}'] = df[f'hazard_j{event}_t{t}']
199 | return df
200 |             elif f'overall_survival_t{t - 1}' not in df.columns:
201 |                 df = self.predict_overall_survival(df, t=t, return_hazards=True)
202 |             elif f'hazard_j{event}_t{t}' not in df.columns:
203 | df = self.predict_hazard_t(df, t=np.array([_t for _t in self.times[:-1] if _t <= t]))
204 | df[f'prob_j{event}_at_t{t}'] = df[f'overall_survival_t{t - 1}'] * df[f'hazard_j{event}_t{t}']
205 | return df
206 |
207 | def predict_prob_event_j_all(self, df: pd.DataFrame, event: Union[str, int]) -> pd.DataFrame:
208 | """
209 |         This function adds columns of a specific event's occurrence probabilities.
210 |
211 | Args:
212 | df (pandas.DataFrame): dataframe with covariates columns
213 | event (Union[str, int]): event name
214 |
215 | Returns:
216 | df (pandas.DataFrame): dataframe with probabilities columns
217 |
218 | """
219 | assert event in self.events, \
220 | f"Cannot predict for event {event} - it was not included during .fit()"
221 | self._validate_covariates_in_df(df.head())
222 |
223 | if f'overall_survival_t{self.times[-2]}' not in df.columns:
224 | df = self.predict_overall_survival(df, return_hazards=True)
225 | for t in self.times[:-1]:
226 | if f'prob_j{event}_at_t{t}' not in df.columns:
227 | df = self.predict_prob_event_j_at_t(df=df, event=event, t=t)
228 | return df
229 |
230 | def predict_prob_events(self, df: pd.DataFrame) -> pd.DataFrame:
231 | """
232 |         This function adds columns of all the events' occurrence probabilities.
233 | Args:
234 | df (pandas.DataFrame): dataframe with covariates columns
235 |
236 | Returns:
237 | df (pandas.DataFrame): dataframe with probabilities columns
238 |
239 | """
240 | self._validate_covariates_in_df(df.head())
241 |
242 | for event in self.events:
243 | df = self.predict_prob_event_j_all(df=df, event=event)
244 | return df
245 |
246 | def predict_event_cumulative_incident_function(self, df: pd.DataFrame, event: Union[str, int]) -> pd.DataFrame:
247 | """
248 |         This function adds, for a specific event, columns of the predicted hazard function, overall survival,
249 |         probabilities of event occurrence and the cumulative incidence function (CIF) to the given dataframe.
250 |
251 | Args:
252 | df (pandas.DataFrame): dataframe with covariates columns included
253 | event (Union[str, int]): event name
254 |
255 | Returns:
256 | df (pandas.DataFrame): dataframe with additional prediction columns
257 |
258 | """
259 | assert event in self.events, \
260 | f"Cannot predict for event {event} - it was not included during .fit()"
261 | self._validate_covariates_in_df(df.head())
262 |
263 | if f'prob_j{event}_at_t{self.times[-2]}' not in df.columns:
264 | df = self.predict_prob_events(df=df)
265 | cols = [f'prob_j{event}_at_t{t}' for t in self.times[:-1]]
266 | cif_df = df[cols].cumsum(axis=1)
267 | cif_df.columns = [f'cif_j{event}_at_t{t}' for t in self.times[:-1]]
268 | df = pd.concat([df, cif_df], axis=1)
269 | return df
270 |
271 | def predict_cumulative_incident_function(self, df: pd.DataFrame) -> pd.DataFrame:
272 | """
273 |         This function adds columns of the predicted hazard function, overall survival, probabilities of event
274 |         occurrence and the cumulative incidence function (CIF) to the given dataframe.
275 |
276 | Args:
277 | df (pandas.DataFrame): dataframe with covariates columns included
278 |
279 | Returns:
280 | df (pandas.DataFrame): dataframe with additional prediction columns
281 |
282 | """
283 | self._validate_covariates_in_df(df.head())
284 |
285 | for event in self.events:
286 | if f'cif_j{event}_at_t{self.times[-2]}' not in df.columns:
287 | df = self.predict_event_cumulative_incident_function(df=df, event=event)
288 | return df
289 |
290 | def predict_marginal_prob_event_j(self, df: pd.DataFrame, event: Union[str, int]) -> pd.DataFrame:
291 | """
292 | This function calculates the marginal probability of an event given the covariates.
293 |
294 | Args:
295 | df (pandas.DataFrame): dataframe with covariates columns included
296 | event (Union[str, int]): event name
297 |
298 | Returns:
299 | df (pandas.DataFrame): dataframe with additional prediction columns
300 | """
301 |
302 | assert event in self.events, \
303 | f"Cannot predict for event {event} - it was not included during .fit()"
304 | self._validate_covariates_in_df(df.head())
305 |
306 | if f'prob_j{event}_at_t{self.times[-2]}' not in df.columns:
307 | df = self.predict_prob_event_j_all(df=df, event=event)
308 | cols = [f'prob_j{event}_at_t{_t}' for _t in self.times[:-1]]
309 | marginal_prob = df[cols].sum(axis=1)
310 | marginal_prob.name = f'marginal_prob_j{event}'
311 | return pd.concat([df, marginal_prob], axis=1)
312 |
313 | def predict_marginal_prob_all_events(self, df: pd.DataFrame) -> pd.DataFrame:
314 | """
315 | This function calculates the marginal probability per event given the covariates for all the events.
316 |
317 | Args:
318 | df (pandas.DataFrame): dataframe with covariates columns included
319 |
320 | Returns:
321 | df (pandas.DataFrame): dataframe with additional prediction columns
322 | """
323 | self._validate_covariates_in_df(df.head())
324 | for event in self.events:
325 | df = self.predict_marginal_prob_event_j(df=df, event=event)
326 | return df
327 |
--------------------------------------------------------------------------------
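A short end-to-end usage sketch of the prediction chain defined above (hazard → overall survival → event probabilities → CIF). It assumes the concrete TwoStagesFitter from src/pydts/fitters.py and the simulated quick-start data from src/pydts/examples_utils/generate_simulations_data.py; the argument values below are illustrative, not the only supported ones.

```python
# A minimal sketch of the ExpansionBasedFitter prediction API, assuming the
# concrete TwoStagesFitter and the package's simulated quick-start data.
from pydts.fitters import TwoStagesFitter
from pydts.examples_utils.generate_simulations_data import generate_quick_start_df

patients_df = generate_quick_start_df(n_patients=1000, n_cov=5, d_times=30,
                                      j_events=2, pid_col='pid', seed=0)

fitter = TwoStagesFitter()
# 'C' and 'T' are dropped: fitting uses only the observed time 'X' and event type 'J'.
fitter.fit(df=patients_df.drop(['C', 'T'], axis=1))

# Adds hazard, overall survival, event probability, and CIF columns
# for every event type and every time point seen during fitting.
pred_df = fitter.predict_cumulative_incident_function(
    patients_df.drop(['C', 'T', 'J', 'X'], axis=1).head())
```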
/src/pydts/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | PROJECT_DIR = os.path.join(os.path.dirname(__file__))
4 | OUTPUT_DIR = os.path.join(PROJECT_DIR, '../../output')
5 | # makedirs with exist_ok avoids failing if the directory already exists or is created concurrently
6 | os.makedirs(OUTPUT_DIR, exist_ok=True)
7 |
--------------------------------------------------------------------------------
/src/pydts/cross_validation.py:
--------------------------------------------------------------------------------
1 | __all__ = ["TwoStagesCV", "TwoStagesCVExact", "PenaltyGridSearchCV", "PenaltyGridSearchCVExact"]
2 |
3 |
4 | import pandas as pd
5 | import numpy as np
6 | from .fitters import TwoStagesFitter, TwoStagesFitterExact
7 | import warnings
8 | from copy import deepcopy
9 | from sklearn.model_selection import KFold
10 | pd.set_option("display.max_rows", 500)
11 | warnings.filterwarnings('ignore')
12 | slicer = pd.IndexSlice
13 | from typing import Optional, List, Union
14 | import psutil
15 | from .evaluation import events_brier_score_at_t, events_integrated_brier_score, global_brier_score, \
16 | events_integrated_auc, global_auc, events_auc_at_t
17 | from .model_selection import PenaltyGridSearch, PenaltyGridSearchExact
18 | from time import time
19 |
20 | WORKERS = psutil.cpu_count(logical=False)
21 |
22 |
23 | class BaseTwoStagesCV(object):
24 | """
25 |     This class implements K-fold cross-validation using TwoStagesFitter or TwoStagesFitterExact
26 | """
27 |
28 | def __init__(self):
29 | self.models = {}
30 | self.test_pids = {}
31 | self.results = pd.DataFrame()
32 | self.global_auc = {}
33 | self.integrated_auc = {}
34 | self.global_bs = {}
35 | self.integrated_bs = {}
36 | self.TwoStagesFitter_type = 'CoxPHFitter'
37 |
38 | def cross_validate(self,
39 | full_df: pd.DataFrame,
40 | n_splits: int = 5,
41 | shuffle: bool = True,
42 | seed: Union[int, None] = None,
43 | fit_beta_kwargs: dict = {},
44 | covariates=None,
45 | event_type_col: str = 'J',
46 | duration_col: str = 'X',
47 | pid_col: str = 'pid',
48 |                        x0: Union[np.ndarray, int] = 0,
49 | verbose: int = 2,
50 | nb_workers: int = WORKERS,
51 | metrics=['BS', 'IBS', 'GBS', 'AUC', 'IAUC', 'GAUC']):
52 |
53 | """
54 | This method implements K-fold cross-validation using TwoStagesFitters and full_df data.
55 |
56 | Args:
57 | full_df (pd.DataFrame): Data to cross validate.
58 | n_splits (int): Number of folds, defaults to 5.
59 | shuffle (bool): Shuffle samples before splitting to folds. Defaults to True.
60 |             seed: Pseudo-random seed for the KFold instance. Defaults to None.
61 | fit_beta_kwargs (dict, Optional): Keyword arguments to pass on to the estimation procedure. If different model for beta is desired, it can be defined here.
62 | covariates (list): list of covariates to be used in estimating the regression coefficients.
63 | event_type_col (str): The event type column name (must be a column in df), Right-censored sample (i) is indicated by event value 0, df.loc[i, event_type_col] = 0.
64 | duration_col (str): Last follow up time column name (must be a column in full_df).
65 | pid_col (str): Sample ID column name (must be a column in full_df).
66 |             x0 (Union[numpy.ndarray, int], Optional): initial guess to pass to the scipy.optimize.minimize function
67 |             verbose (int, Optional): The verbosity level of pandarallel
68 |             nb_workers (int, Optional): The number of workers for pandarallel. If not specified, defaults to the number of available workers.
69 | metrics (str, list): Evaluation metrics.
70 |
71 | Returns:
72 | Results (pd.DataFrame): Cross validation metrics results
73 | """
74 |
75 | if isinstance(metrics, str):
76 | metrics = [metrics]
77 |
78 | self.models = {}
79 | self.kfold_cv = KFold(n_splits=n_splits, shuffle=shuffle, random_state=seed)
80 |
81 | if 'C' in full_df.columns:
82 | full_df = full_df.drop(['C'], axis=1)
83 | if 'T' in full_df.columns:
84 | full_df = full_df.drop(['T'], axis=1)
85 |
86 | for i_fold, (train_index, test_index) in enumerate(self.kfold_cv.split(full_df)):
87 | self.test_pids[i_fold] = full_df.iloc[test_index][pid_col].values
88 | train_df, test_df = full_df.iloc[train_index], full_df.iloc[test_index]
89 | if self.TwoStagesFitter_type == 'Exact':
90 | fold_fitter = TwoStagesFitterExact()
91 | else:
92 | fold_fitter = TwoStagesFitter()
93 | print(f'Fitting fold {i_fold+1}/{n_splits}')
94 | fold_fitter.fit(df=train_df,
95 | covariates=covariates,
96 | event_type_col=event_type_col,
97 | duration_col=duration_col,
98 | pid_col=pid_col,
99 | x0=x0,
100 | fit_beta_kwargs=fit_beta_kwargs,
101 | verbose=verbose,
102 | nb_workers=nb_workers)
103 |
104 | #self.models[i_fold] = deepcopy(fold_fitter)
105 | self.models[i_fold] = fold_fitter
106 |
107 | pred_df = self.models[i_fold].predict_prob_events(test_df)
108 |
109 | for metric in metrics:
110 | if metric == 'IAUC':
111 | self.integrated_auc[i_fold] = events_integrated_auc(pred_df, event_type_col=event_type_col,
112 | duration_col=duration_col)
113 | elif metric == 'GAUC':
114 | self.global_auc[i_fold] = global_auc(pred_df, event_type_col=event_type_col,
115 | duration_col=duration_col)
116 | elif metric == 'IBS':
117 | self.integrated_bs[i_fold] = events_integrated_brier_score(pred_df, event_type_col=event_type_col,
118 | duration_col=duration_col)
119 | elif metric == 'GBS':
120 | self.global_bs[i_fold] = global_brier_score(pred_df, event_type_col=event_type_col,
121 | duration_col=duration_col)
122 | elif metric == 'AUC':
123 | tmp_res = events_auc_at_t(pred_df, event_type_col=event_type_col,
124 | duration_col=duration_col)
125 | tmp_res = pd.concat([tmp_res], keys=[i_fold], names=['fold'])
126 | tmp_res = pd.concat([tmp_res], keys=[metric], names=['metric'])
127 | self.results = pd.concat([self.results, tmp_res], axis=0)
128 | elif metric == 'BS':
129 | tmp_res = events_brier_score_at_t(pred_df, event_type_col=event_type_col,
130 | duration_col=duration_col)
131 | tmp_res = pd.concat([tmp_res], keys=[i_fold], names=['fold'])
132 | tmp_res = pd.concat([tmp_res], keys=[metric], names=['metric'])
133 | self.results = pd.concat([self.results, tmp_res], axis=0)
134 |
135 | return self.results
136 |
137 |
138 | class BasePenaltyGridSearchCV(object):
139 | """
140 | This class implements K-fold cross-validation of the PenaltyGridSearch
141 | """
142 |
143 | def __init__(self):
144 | self.folds_grids = {}
145 | self.test_pids = {}
146 | self.global_auc = {}
147 | self.integrated_auc = {}
148 | self.global_bs = {}
149 | self.integrated_bs = {}
150 | self.TwoStagesFitter_type = 'CoxPHFitter'
151 |
152 | def cross_validate(self,
153 | full_df: pd.DataFrame,
154 | l1_ratio: float,
155 | penalizers: list,
156 | n_splits: int = 5,
157 | shuffle: bool = True,
158 | seed: Union[int, None] = None,
159 | event_type_col: str = 'J',
160 | duration_col: str = 'X',
161 | pid_col: str = 'pid',
162 | twostages_fit_kwargs: dict = {'nb_workers': WORKERS},
163 | metrics=['IBS', 'GBS', 'IAUC', 'GAUC']) -> pd.DataFrame:
164 |
165 | """
166 | This method implements K-fold cross-validation using PenaltyGridSearch and full_df data.
167 |
168 | Args:
169 | full_df (pd.DataFrame): Data to cross validate.
170 | l1_ratio (float): regularization ratio for the CoxPHFitter (see lifelines.fitters.coxph_fitter.CoxPHFitter documentation).
171 | penalizers (list): penalizer options for each event (see lifelines.fitters.coxph_fitter.CoxPHFitter documentation).
172 | n_splits (int): Number of folds, defaults to 5.
173 |             shuffle (bool): Shuffle samples before splitting to folds. Defaults to True.
174 |             seed: Pseudo-random seed for the KFold instance. Defaults to None.
175 | event_type_col (str): The event type column name (must be a column in df), Right-censored sample (i) is indicated by event value 0, df.loc[i, event_type_col] = 0.
176 | duration_col (str): Last follow up time column name (must be a column in full_df).
177 | pid_col (str): Sample ID column name (must be a column in full_df).
178 | twostages_fit_kwargs (dict): keyword arguments to pass to each TwoStagesFitter.
179 | metrics (str, list): Evaluation metrics.
180 |
181 | Returns:
182 |             gauc_output_df (pd.DataFrame): Global AUC k-fold mean and standard error for all possible combinations of the penalizers.
183 | """
184 |
185 | if isinstance(metrics, str):
186 | metrics = [metrics]
187 |
188 | self.folds_grids = {}
189 | self.kfold_cv = KFold(n_splits=n_splits, shuffle=shuffle, random_state=seed)
190 |
191 | if 'C' in full_df.columns:
192 | full_df = full_df.drop(['C'], axis=1)
193 | if 'T' in full_df.columns:
194 | full_df = full_df.drop(['T'], axis=1)
195 |
196 | for i_fold, (train_index, test_index) in enumerate(self.kfold_cv.split(full_df)):
197 | print(f'Starting fold {i_fold+1}/{n_splits}')
198 | start = time()
199 | self.test_pids[i_fold] = full_df.iloc[test_index][pid_col].values
200 | train_df, test_df = full_df.iloc[train_index], full_df.iloc[test_index]
201 | if self.TwoStagesFitter_type == 'Exact':
202 | fold_pgs = PenaltyGridSearchExact()
203 | else:
204 | fold_pgs = PenaltyGridSearch()
205 |
206 | fold_pgs.evaluate(train_df=train_df,
207 | test_df=test_df,
208 | l1_ratio=l1_ratio,
209 | penalizers=penalizers,
210 | metrics=metrics,
211 | seed=seed,
212 | event_type_col=event_type_col,
213 | duration_col=duration_col,
214 | pid_col=pid_col,
215 | twostages_fit_kwargs=twostages_fit_kwargs)
216 |
217 | self.folds_grids[i_fold] = fold_pgs
218 |
219 | for metric in metrics:
220 | if metric == 'GAUC':
221 | self.global_auc[i_fold] = fold_pgs.convert_results_dict_to_df(fold_pgs.global_auc)
222 | elif metric == 'IAUC':
223 | self.integrated_auc[i_fold] = fold_pgs.convert_results_dict_to_df(fold_pgs.integrated_auc)
224 | elif metric == 'GBS':
225 | self.global_bs[i_fold] = fold_pgs.convert_results_dict_to_df(fold_pgs.global_bs)
226 | elif metric == 'IBS':
227 | self.integrated_bs[i_fold] = fold_pgs.convert_results_dict_to_df(fold_pgs.integrated_bs)
228 |
229 | end = time()
230 | print(f'Finished fold {i_fold+1}/{n_splits}, {int(end-start)} seconds')
231 |
232 | if 'GAUC' in metrics:
233 | res = [v for k, v in self.global_auc.items()]
234 | gauc_output_df = pd.concat([pd.concat(res, axis=1).mean(axis=1),
235 | pd.concat(res, axis=1).std(axis=1)],
236 | keys=['Mean', 'SE'], axis=1)
237 | else:
238 | gauc_output_df = pd.DataFrame()
239 | return gauc_output_df
240 |
241 |
242 | class PenaltyGridSearchCV(BasePenaltyGridSearchCV):
243 |
244 | def __init__(self):
245 | super().__init__()
246 | self.TwoStagesFitter_type = 'CoxPHFitter'
247 |
248 |
249 | class PenaltyGridSearchCVExact(BasePenaltyGridSearchCV):
250 |
251 | def __init__(self):
252 | super().__init__()
253 | self.TwoStagesFitter_type = 'Exact'
254 |
255 |
256 | class TwoStagesCV(BaseTwoStagesCV):
257 |
258 | def __init__(self):
259 | super().__init__()
260 | self.TwoStagesFitter_type = 'CoxPHFitter'
261 |
262 |
263 | class TwoStagesCVExact(BaseTwoStagesCV):
264 |
265 | def __init__(self):
266 | super().__init__()
267 | self.TwoStagesFitter_type = 'Exact'
268 |
--------------------------------------------------------------------------------
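A usage sketch of TwoStagesCV on simulated data; PenaltyGridSearchCV follows the same pattern with the additional l1_ratio and penalizers arguments. The argument values are illustrative.

```python
# A usage sketch of TwoStagesCV, assuming simulated quick-start data.
from pydts.cross_validation import TwoStagesCV
from pydts.examples_utils.generate_simulations_data import generate_quick_start_df

patients_df = generate_quick_start_df(n_patients=1000, n_cov=5, d_times=30,
                                      j_events=2, pid_col='pid', seed=0)

cv = TwoStagesCV()
# 'C' and 'T' are dropped internally by cross_validate before fitting.
results = cv.cross_validate(full_df=patients_df, n_splits=5, seed=0,
                            metrics=['BS', 'GAUC'])
print(results)        # per-fold Brier score at each time point ('BS')
print(cv.global_auc)  # per-fold global AUC ('GAUC')
```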
/src/pydts/data_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import expit
3 | from typing import Union
4 | import pandas as pd
5 |
6 |
7 | class EventTimesSampler(object):
8 |
9 | def __init__(self, d_times: int, j_event_types: int):
10 | """
11 |         This class implements a sampling procedure for discrete event times and censoring times for given observations.
12 |
13 | Args:
14 | d_times (int): number of possible event times
15 | j_event_types (int): number of possible event types
16 | """
17 |
18 | self.d_times = d_times
19 | self.times = range(1, self.d_times + 2) # d + 1 for administrative censoring
20 | self.j_event_types = j_event_types
21 | self.events = range(1, self.j_event_types + 1)
22 |
23 | def _validate_prob_dfs_list(self, dfs_list: list, numerical_error_tolerance: float = 0.001) -> list:
24 | for df in dfs_list:
25 | if (((df < (0-numerical_error_tolerance)).any().any()) or ((df > (1+numerical_error_tolerance)).any().any())):
26 | raise ValueError("The chosen sampling parameters result in invalid probabilities for event j at time t")
27 | # Only fixes numerical errors smaller than the tolerance size
28 | df.clip(0, 1, inplace=True)
29 | return dfs_list
30 |
31 | def calculate_hazards(self, observations_df: pd.DataFrame, hazard_coefs: dict, events: list = None) -> list:
32 | """
33 | Calculates the hazard function for the observations given the hazard coefficients.
34 |
35 | Args:
36 | observations_df (pd.DataFrame): Dataframe with observations covariates.
37 |             hazard_coefs (dict): time coefficients and covariates coefficients for each event type.
38 |
39 | Returns:
40 |             hazards_dfs (list): A list of dataframes, one for each event type, with the hazard function at time t for each of the observations.
41 | """
42 | events = events if events is not None else self.events
43 | a_t = {}
44 | for event in events:
45 | if callable(hazard_coefs['alpha'][event]):
46 | a_t[event] = {t: hazard_coefs['alpha'][event](t) for t in range(1, self.d_times + 1)}
47 | else:
48 | a_t[event] = {t: hazard_coefs['alpha'][event][t-1] for t in range(1, self.d_times + 1)}
49 | b = pd.concat([observations_df.dot(hazard_coefs['beta'][j]) for j in events], axis=1, keys=events)
50 | hazards_dfs = [pd.concat([expit((a_t[j][t] + b[j]).astype(float)) for t in range(1, self.d_times + 1)],
51 | axis=1, keys=(range(1, self.d_times + 1))) for j in events]
52 | return hazards_dfs
53 |
54 | def calculate_overall_survival(self, hazards: list, numerical_error_tolerance: float = 0.001) -> pd.DataFrame:
55 | """
56 | Calculates the overall survival function given the hazards.
57 |
58 | Args:
59 | hazards (list): A list of hazards dataframes for each event type (as returned from EventTimesSampler.calculate_hazards function).
60 | numerical_error_tolerance (float): Tolerate numerical errors of probabilities up to this value.
61 |
62 | Returns:
63 |             overall_survival (pd.DataFrame): The overall survival functions.
64 | """
65 | if (((sum(hazards)) > (1 + numerical_error_tolerance)).sum().sum() > 0):
66 | raise ValueError("The chosen sampling parameters result in negative values of the overall survival function")
67 | sum_hazards = sum(hazards).clip(0, 1)
68 | overall_survival = pd.concat([pd.Series(1, index=hazards[0].index),
69 | (1 - sum_hazards).cumprod(axis=1).iloc[:, :-1]], axis=1)
70 | overall_survival.columns += 1
71 | return overall_survival
72 |
73 | def calculate_prob_event_at_t(self, hazards: list, overall_survival: pd.DataFrame,
74 | numerical_error_tolerance: float = 0.001) -> list:
75 | """
76 | Calculates the probability for event j at time t.
77 |
78 | Args:
79 | hazards (list): A list of hazards dataframes for each event type (as returned from EventTimesSampler.calculate_hazards function)
80 |             overall_survival (pd.DataFrame): The overall survival functions
81 | numerical_error_tolerance (float): Tolerate numerical errors of probabilities up to this value.
82 |
83 | Returns:
84 |             prob_event_at_t (list): A list of dataframes, one for each event type, with the probability of event occurrence at time t for each of the observations.
85 | """
86 | prob_event_at_t = [hazard * overall_survival for hazard in hazards]
87 | prob_event_at_t = self._validate_prob_dfs_list(prob_event_at_t, numerical_error_tolerance)
88 | return prob_event_at_t
89 |
90 | def calculate_prob_event_j(self, prob_j_at_t: list, numerical_error_tolerance: float = 0.001) -> list:
91 | """
92 | Calculates the total probability for event j.
93 |
94 | Args:
95 |             prob_j_at_t (list): A list of dataframes, one for each event type, with the probability of event occurrence at time t for each of the observations.
96 | numerical_error_tolerance (float): Tolerate numerical errors of probabilities up to this value.
97 |
98 | Returns:
99 |             total_prob_j (list): A list of dataframes, one for each event type, with the total probability of event occurrence for each of the observations.
100 | """
101 | total_prob_j = [prob.sum(axis=1) for prob in prob_j_at_t]
102 | total_prob_j = self._validate_prob_dfs_list(total_prob_j, numerical_error_tolerance)
103 | return total_prob_j
104 |
105 | def calc_prob_t_given_j(self, prob_j_at_t, total_prob_j, numerical_error_tolerance=0.001):
106 | """
107 |         Calculates the conditional probability for event occurrence at time t given J_i=j.
108 |
109 | Args:
110 |             prob_j_at_t (list): A list of dataframes, one for each event type, with the probability of event occurrence at time t for each of the observations.
111 |             total_prob_j (list): A list of dataframes, one for each event type, with the total probability of event occurrence for each of the observations.
112 | numerical_error_tolerance (float): Tolerate numerical errors of probabilities up to this value.
113 |
114 | Returns:
115 |             conditional_prob (list): A list of dataframes, one for each event type, with the conditional probability of event occurrence at t given event type j for each of the observations.
116 | """
117 | conditional_prob = [prob.div(sumj, axis=0) for prob, sumj in zip(prob_j_at_t, total_prob_j)]
118 | conditional_prob = self._validate_prob_dfs_list(conditional_prob, numerical_error_tolerance)
119 | return conditional_prob
120 |
121 | def sample_event_times(self, observations_df: pd.DataFrame,
122 | hazard_coefs: dict,
123 | covariates: Union[list, None] = None,
124 | events: Union[list, None] = None,
125 | seed: Union[int, None] = None) -> pd.DataFrame:
126 | """
127 |         Sample event type and event occurrence times.
128 |
129 | Args:
130 | observations_df (pd.DataFrame): Dataframe with observations covariates.
131 |             hazard_coefs (dict): time coefficients and covariates coefficients for each event type.
132 |             covariates (list, None): list of covariate names, must be a subset of observations_df.columns
133 |             seed (int, None): numpy seed number for pseudo-random sampling.
134 |
135 | Returns:
136 | observations_df (pd.DataFrame): Dataframe with additional columns for sampled event time (T) and event type (J).
137 | """
138 | np.random.seed(seed)
139 | if covariates is None:
140 | covariates = [c for c in observations_df.columns if c not in ['X', 'T', 'C', 'J']]
141 | events = events if events is not None else self.events
142 | cov_df = observations_df[covariates]
143 | hazards = self.calculate_hazards(cov_df, hazard_coefs, events=events)
144 | overall_survival = self.calculate_overall_survival(hazards)
145 | probs_j_at_t = self.calculate_prob_event_at_t(hazards, overall_survival)
146 | total_prob_j = self.calculate_prob_event_j(probs_j_at_t)
147 | probs_t_given_j = self.calc_prob_t_given_j(probs_j_at_t, total_prob_j)
148 | sampled_jt = self.sample_jt(total_prob_j, probs_t_given_j)
149 | if 'J' in observations_df.columns:
150 | observations_df.drop('J', axis=1, inplace=True)
151 | if 'T' in observations_df.columns:
152 | observations_df.drop('T', axis=1, inplace=True)
153 | observations_df = pd.concat([observations_df, sampled_jt], axis=1)
154 | return observations_df
155 |
156 | def sample_jt(self, total_prob_j: list, probs_t_given_j: list, numerical_error_tolerance: float = 0.001) -> pd.DataFrame:
157 | """
158 | Sample event type and event time for each observation.
159 |
160 | Args:
161 |             total_prob_j (list): A list of dataframes, one for each event type, with the total probability of event occurrence for each of the observations.
162 |             probs_t_given_j (list): A list of dataframes, one for each event type, with the conditional probability of event occurrence at t given event type j for each of the observations.
163 |
164 | Returns:
165 | sampled_df (pd.DataFrame): A dataframe with sampled event time and event type for each observation.
166 | """
167 |
168 | total_prob_j = self._validate_prob_dfs_list(total_prob_j, numerical_error_tolerance)
169 | probs_t_given_j = self._validate_prob_dfs_list(probs_t_given_j, numerical_error_tolerance)
170 |
171 |         # Add administrative censoring (no event occurred until Tmax) probability as J=0
172 | temp_sums = pd.concat([1 - sum(total_prob_j), *total_prob_j], axis=1, keys=[0, *self.events])
173 | if (((temp_sums < (0 - numerical_error_tolerance)).any().any()) or \
174 | ((temp_sums > (1 + numerical_error_tolerance)).any().any())):
175 | raise ValueError("The chosen sampling parameters result in invalid probabilities")
176 | # Only fixes numerical errors smaller than the tolerance size
177 | temp_sums.clip(0, 1, inplace=True)
178 |
179 | # Efficient way to sample j for each observation with different event probabilities
180 | sampled_df = (temp_sums.cumsum(1) > np.random.rand(temp_sums.shape[0])[:, None]).idxmax(axis=1).to_frame('J')
181 |
182 | temp_ts = []
183 | for j in self.events:
184 | # Get the index of the observations with J_i = j
185 | rel_j = sampled_df.query("J==@j").index
186 |
187 | # Get probs dataframe from the dfs list
188 | prob_df = probs_t_given_j[j - 1] # the prob j to sample from
189 |
190 | # Sample time of occurrence given J_i = j
191 | temp_ts.append((prob_df.loc[rel_j].cumsum(1) >= np.random.rand(rel_j.shape[0])[:, None]).idxmax(axis=1))
192 |
193 | # Add Tmax+1 for observations with J_i = 0
194 | temp_ts.append(pd.Series(self.d_times + 1, index=sampled_df.query('J==0').index))
195 | sampled_df["T"] = pd.concat(temp_ts).sort_index()
196 | return sampled_df
197 |
198 | def sample_independent_lof_censoring(self, observations_df: pd.DataFrame,
199 |                                          prob_lof_at_t: np.ndarray, seed: Union[int, None] = None) -> pd.DataFrame:
200 | """
201 | Samples loss of follow-up censoring time from probabilities independent of covariates.
202 |
203 | Args:
204 | observations_df (pd.DataFrame): Dataframe with observations covariates.
205 |             prob_lof_at_t (np.ndarray): Array of probabilities for sampling each of the possible times.
206 | seed (int): pseudo random seed number for numpy.random.seed()
207 |
208 | Returns:
209 |             observations_df (pd.DataFrame): Updated dataframe including the sampled censoring time.
210 | """
211 | np.random.seed(seed)
212 | administrative_censoring_prob = (1 - sum(prob_lof_at_t))
213 | assert (administrative_censoring_prob >= 0), "Check the sum of prob_lof_at_t argument."
214 | assert (administrative_censoring_prob <= 1), "Check the sum of prob_lof_at_t argument."
215 |
216 | prob_lof_at_t = np.append(prob_lof_at_t, administrative_censoring_prob)
217 | sampled_df = pd.DataFrame(np.random.choice(a=self.times, size=len(observations_df), p=prob_lof_at_t),
218 | index=observations_df.index, columns=['C'])
219 | # No follow-up censoring, C=d+2 such that T wins when building X column:
220 | #sampled_df.loc[sampled_df['C'] == self.times[-1], 'C'] = self.d_times + 2
221 | if 'C' in observations_df.columns:
222 | observations_df.drop('C', axis=1, inplace=True)
223 | observations_df = pd.concat([observations_df, sampled_df], axis=1)
224 | return observations_df
225 |
226 | def sample_hazard_lof_censoring(self, observations_df: pd.DataFrame, censoring_hazard_coefs: dict,
227 | seed: Union[int, None] = None,
228 | covariates: Union[list, None] = None) -> pd.DataFrame:
229 | """
230 | Samples loss of follow-up censoring time from hazard coefficients.
231 |
232 | Args:
233 | observations_df (pd.DataFrame): Dataframe with observations covariates.
234 | censoring_hazard_coefs (dict): time coefficients and covariates coefficients for the censoring hazard.
235 | seed (int): pseudo random seed number for numpy.random.seed()
236 | covariates (list): list of covariates names, must be a subset of observations_df.columns
237 |
238 | Returns:
239 |             observations_df (pd.DataFrame): Updated dataframe including the sampled censoring time.
240 | """
241 | if covariates is None:
242 | covariates = [c for c in observations_df.columns if c not in ['X', 'T', 'C', 'J']]
243 | cov_df = observations_df[covariates]
244 | tmp_ets = EventTimesSampler(d_times=self.d_times, j_event_types=1)
245 | sampled_df = tmp_ets.sample_event_times(cov_df, censoring_hazard_coefs, seed=seed, covariates=covariates,
246 | events=[0])
247 |
248 | # No follow-up censoring, C=d+2 such that T wins when building X column:
249 | #sampled_df.loc[sampled_df['J'] == 0, 'T'] = self.d_times + 2
250 | sampled_df = sampled_df[['T']]
251 | sampled_df.columns = ['C']
252 | if 'C' in observations_df.columns:
253 | observations_df.drop('C', axis=1, inplace=True)
254 | observations_df = pd.concat([observations_df, sampled_df], axis=1)
255 | return observations_df
256 |
257 | def update_event_or_lof(self, observations_df: pd.DataFrame) -> pd.DataFrame:
258 | """
259 | Updates time column 'X' to be the minimum between event time column 'T' and censoring time column 'C'.
260 |         Event type 'J' is changed to 0 for observations with 'C' < 'T'.
261 |
262 | Args:
263 | observations_df (pd.DataFrame): Dataframe with observations after sampling event times 'T' and censoring time 'C'.
264 |
265 | Returns:
266 | observations_df (pd.DataFrame): Dataframe with updated time column 'X' and event type column 'J'
267 | """
268 | assert ('T' in observations_df.columns), "Trying to update event or censoring before sampling event times"
269 | assert ('C' in observations_df.columns), "Trying to update event or censoring before sampling censoring time"
270 | observations_df['X'] = observations_df[['T', 'C']].min(axis=1)
271 | observations_df.loc[observations_df.loc[(observations_df['C'] < observations_df['T'])].index, 'J'] = 0
272 | return observations_df
273 |
--------------------------------------------------------------------------------
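A minimal sampling sketch for EventTimesSampler. The coefficient values are illustrative only (they mirror default_sampling_logic in examples_utils), and the 1% per-time-point censoring probability is an arbitrary choice.

```python
# A minimal EventTimesSampler sketch; all numeric values are illustrative.
import numpy as np
import pandas as pd
from pydts.data_generation import EventTimesSampler

ets = EventTimesSampler(d_times=30, j_event_types=2)
observations_df = pd.DataFrame(np.random.uniform(size=(1000, 5)),
                               columns=[f'Z{i + 1}' for i in range(5)])
coefs = {
    'alpha': {1: lambda t: -1.0 - 0.3 * np.log(t),
              2: lambda t: -1.75 - 0.15 * np.log(t)},
    'beta': {1: -np.log([0.8, 3, 3, 2.5, 2]),
             2: -np.log([1, 3, 4, 3, 2])},
}
observations_df = ets.sample_event_times(observations_df, hazard_coefs=coefs, seed=0)
# Loss-of-follow-up censoring with an (arbitrary) 1% probability per time point.
observations_df = ets.sample_independent_lof_censoring(
    observations_df, prob_lof_at_t=np.full(30, 0.01), seed=1)
observations_df = ets.update_event_or_lof(observations_df)  # X = min(T, C); J = 0 if censored
```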
/src/pydts/examples_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomer1812/pydts/578929a457c111efe009d3461aab531b793b33d0/src/pydts/examples_utils/__init__.py
--------------------------------------------------------------------------------
/src/pydts/examples_utils/datasets.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from pydts.config import *
3 |
4 | DATASETS_DIR = os.path.join(os.path.dirname((os.path.dirname(__file__))), 'datasets')
5 |
6 | def load_LOS_simulated_data():
7 |     return pd.read_csv(os.path.join(DATASETS_DIR, 'LOS_simulated_data.csv'))
--------------------------------------------------------------------------------
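A one-liner to load the bundled simulated length-of-stay dataset:

```python
from pydts.examples_utils.datasets import load_LOS_simulated_data

los_df = load_LOS_simulated_data()
print(los_df.shape)
```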
/src/pydts/examples_utils/generate_simulations_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pydts.examples_utils.simulations_data_config import *
3 | from pydts.config import *
4 | import pandas as pd
5 | from scipy.special import expit
6 | from pandarallel import pandarallel
7 |
8 |
9 | def sample_los(new_patient, age_mean, age_std, bmi_mean, bmi_std, coefs=COEFS, baseline_hazard_scale=8,
10 | los_bounds=[1, 150]):
11 | # Columns normalization:
12 | new_patient[AGE_COL] = (new_patient[AGE_COL] - age_mean) / age_std
13 | new_patient[GENDER_COL] = 2 * (new_patient[GENDER_COL] - 0.5)
14 | new_patient[BMI_COL] = (new_patient[BMI_COL] - bmi_mean) / bmi_std
15 | new_patient[SMOKING_COL] = new_patient[SMOKING_COL] - 1
16 | new_patient[HYPERTENSION_COL] = 2 * (new_patient[HYPERTENSION_COL] - 0.5)
17 | new_patient[DIABETES_COL] = 2 * (new_patient[DIABETES_COL] - 0.5)
18 | new_patient[ART_FIB_COL] = 2 * (new_patient[ART_FIB_COL] - 0.5)
19 | new_patient[COPD_COL] = 2 * (new_patient[COPD_COL] - 0.5)
20 | new_patient[CRF_COL] = 2 * (new_patient[CRF_COL] - 0.5)
21 | new_patient = pd.Series(new_patient)
22 |
23 | # Baseline hazard
24 | baseline_hazard = np.random.exponential(scale=baseline_hazard_scale)
25 |
26 | # Patient's correction
27 | beta_x = coefs.dot(new_patient[coefs.index])
28 |
29 | # Sample, round (for ties), and clip to bounds the patient's length of stay at the hospital
30 | los = np.clip(np.round(baseline_hazard * np.exp(beta_x)), a_min=los_bounds[0], a_max=los_bounds[1])
31 | los_death = np.nan if new_patient[IN_HOSPITAL_DEATH_COL] == 0 else los
32 | return los, los_death
33 |
34 |
35 | def hide_weight_info(row):
36 | #admyear = row[ADMISSION_YEAR_COL]
37 | p_weight = 0.3 #+ int(admyear > (min_year + 3)) * 0.8 * ((admyear - min_year) / (max_year - min_year))
38 | sample_weight = np.random.binomial(1, p=p_weight)
39 | if sample_weight == 0:
40 | row[WEIGHT_COL] = np.nan
41 | row[BMI_COL] = np.nan
42 | return row
43 |
44 |
45 | def main(seed=0, N_patients=DEFAULT_N_PATIENTS, output_dir=OUTPUT_DIR, filename=SIMULATED_DATA_FILENAME):
46 | # Set random seed for consistent sampling
47 | np.random.seed(seed)
48 |
49 | # Female - 1, Male - 0
50 | gender = np.random.binomial(n=1, p=0.5, size=N_patients)
51 |
52 | simulated_patients_df = pd.DataFrame()
53 |
54 | for p in range(N_patients):
55 | # Sample gender dependent age for each patient
56 | age = np.round(np.random.normal(loc=72 + 5 * gender[p], scale=12), decimals=1)
57 |
58 | # Random sample admission year
59 | admyear = np.random.randint(low=min_year, high=max_year)
60 |
61 | # Sample gender dependent height
62 | height = np.random.normal(loc=175 - 5 * gender[p], scale=7)
63 |
64 | # Sample height, gender and age dependent weight
65 | weight = np.random.normal(loc=(height / 175) * 80 - 5 * gender[p] + (age / 20), scale=8)
66 |
67 | # Calculate body mass index (BMI) from weight and height
68 | bmi = weight / ((height / 100) ** 2)
69 |
70 | # Random sample of previous admissions
71 | admserial = np.clip(np.round(np.random.lognormal(mean=0, sigma=0.75)), 1, 20)
72 |
73 | # Random sample of categorical smoking status: No - 0, Previously - 1, Currently - 2
74 | smoking = np.random.choice([0, 1, 2], p=[0.5, 0.3, 0.2])
75 |
76 | # Sample patient's preconditions based on gender, age, BMI, and smoking status with limits on the value of p
77 | pre_p = np.clip((bmi_coef * bmi + gender_coef * gender[p] + age_coef * age + smk_coef * smoking),
78 | a_min=0.05, a_max=max_p)
79 | hypertension = np.random.binomial(n=1, p=pre_p)
80 | diabetes = np.random.binomial(n=1, p=pre_p + bmi_coef * bmi)
81 | artfib = np.random.binomial(n=1, p=pre_p) # Arterial Fibrillation
82 | copd = np.random.binomial(n=1, p=pre_p + smk_coef * smoking) # Chronic Obstructive Pulmonary Disease
83 | crf = np.random.binomial(n=1, p=pre_p) # Chronic Renal Failure
84 |
85 | new_patient = {
86 | PATIENT_NO_COL: p,
87 | AGE_COL: age,
88 | GENDER_COL: gender[p],
89 | ADMISSION_YEAR_COL: int(admyear),
90 | FIRST_ADMISSION_COL: int(admserial == 1),
91 | ADMISSION_SERIAL_COL: int(admserial),
92 | WEIGHT_COL: weight,
93 | HEIGHT_COL: height,
94 | BMI_COL: bmi,
95 | SMOKING_COL: smoking,
96 | HYPERTENSION_COL: hypertension,
97 | DIABETES_COL: diabetes,
98 | ART_FIB_COL: artfib,
99 | COPD_COL: copd,
100 | CRF_COL: crf,
101 | }
102 |
103 |         simulated_patients_df = pd.concat([simulated_patients_df, pd.DataFrame([new_patient])], ignore_index=True)  # DataFrame.append was removed in pandas 2.0
104 |
105 | # [age, gender, ADMISSION_YEAR, FIRST_ADMISSION, ADMISSION_SERIAL, WEIGHT_COL, HEIGHT_COL, BMI, SMOKING_COL, 5*preconditions]
106 | #b1 = [-0.003, 0.05, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
107 | b2 = [0.003, -0.05, 0, 0, 0, 0.003, 0.0001, 0.005, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02]
108 |
109 | real_coef_dict = {
110 | "alpha": {
111 | 1: lambda t: -2.2 - 0.2 * np.log(t),
112 | 2: lambda t: -2.3 - 0.2 * np.log(t)
113 | },
114 | "beta": {
115 | 1: [-1*b for b in b2],
116 | 2: [b for b in b2]
117 | }
118 | }
119 |
120 | # Sample event type and relative time
121 | events_df = new_sample_logic(simulated_patients_df.set_index(PATIENT_NO_COL), j_events=2, d_times=30,
122 | real_coef_dict=real_coef_dict)
123 | simulated_patients_df = pd.concat([simulated_patients_df, events_df], axis=1)
124 |
125 | simulated_patients_df[DEATH_RELATIVE_COL] = np.nan
126 | simulated_patients_df.loc[simulated_patients_df.J == 1, DEATH_RELATIVE_COL] = simulated_patients_df.loc[
127 | simulated_patients_df.J == 1, 'T'].values
128 | simulated_patients_df = simulated_patients_df.rename(columns={'T': DISCHARGE_RELATIVE_COL})
129 | simulated_patients_df.loc[simulated_patients_df.J == 0, DISCHARGE_RELATIVE_COL] = 31
130 | simulated_patients_df = simulated_patients_df.drop('J', axis=1)
131 |
132 | # Remove weight and bmi based on admission year
133 | simulated_patients_df = simulated_patients_df.apply(hide_weight_info, axis=1)
134 |
135 | simulated_patients_df[DEATH_MISSING_COL] = simulated_patients_df[DEATH_RELATIVE_COL].isnull().astype(int)
136 | simulated_patients_df[IN_HOSPITAL_DEATH_COL] = simulated_patients_df[DEATH_RELATIVE_COL].notnull().astype(int)
137 | simulated_patients_df[RETURNING_PATIENT_COL] = pd.cut(simulated_patients_df[ADMISSION_SERIAL_COL],
138 | bins=ADMISSION_SERIAL_BINS, labels=ADMISSION_SERIAL_LABELS)
139 |
140 | simulated_patients_df.set_index(PATIENT_NO_COL).to_csv(os.path.join(output_dir, filename))
141 |
142 |
143 | def default_sampling_logic(Z, d_times):
144 | alpha1t = -1 -0.3*np.log(np.arange(start=1, stop=d_times+1))
145 | beta1 = -np.log([0.8, 3, 3, 2.5, 2])
146 | alpha2t = -1.75 -0.15*np.log(np.arange(start=1, stop=d_times+1))
147 | beta2 = -np.log([1, 3, 4, 3, 2])
148 | hazard1 = expit(alpha1t+(Z*beta1).sum())
149 | hazard2 = expit(alpha2t+(Z*beta2).sum())
150 | surv_func = np.array([1, *np.cumprod(1-hazard1-hazard2)[:-1]])
151 | proba1 = hazard1*surv_func
152 | proba2 = hazard2*surv_func
153 | sum1 = np.sum(proba1)
154 | sum2 = np.sum(proba2)
155 | probj1t = proba1 / sum1
156 | probj2t = proba2 / sum2
157 | j_i = np.random.choice(a=[0, 1, 2], p=[1-sum1-sum2, sum1, sum2])
158 | if j_i == 0:
159 | T_i = d_times
160 | elif j_i == 1:
161 | T_i = np.random.choice(a=np.arange(1, d_times+1), p=probj1t)
162 | else:
163 | T_i = np.random.choice(a=np.arange(1, d_times+1), p=probj2t)
164 | return j_i, T_i
165 |
166 |
167 | def calculate_jt(sums, probs_jt, d_times, j_events):
168 | events = range(1, j_events + 1)
169 | temp_sums = pd.concat([1 - sum(sums), *sums], axis=1, keys=[0, *events])
170 |
171 | j_df = (temp_sums.cumsum(1) > np.random.rand(temp_sums.shape[0])[:, None]).idxmax(axis=1).to_frame('J')
172 |
173 | temp_ts = []
174 | for j in events:
175 | rel_j = j_df.query("J==@j").index
176 | prob_df = probs_jt[j - 1] # the prob j to sample from
177 | # sample T
178 | temp_ts.append((prob_df.loc[rel_j].cumsum(1) >= np.random.rand(rel_j.shape[0])[:, None]).idxmax(axis=1))
179 |
180 | temp_ts.append(pd.Series(d_times+1, index=j_df.query('J==0').index))
181 |
182 | j_df["T"] = pd.concat(temp_ts).sort_index()
183 | return j_df
184 |
185 |
186 | def new_sample_logic(patients_df: pd.DataFrame, j_events: int, d_times: int, real_coef_dict: dict) -> pd.DataFrame:
187 | """
188 |     A quicker sampling logic that uses coefficients supplied by the user
189 | """
190 | events = range(1, j_events + 1)
191 | a_t = {event: {t: real_coef_dict['alpha'][event](t) for t in range(1, d_times+1)} for event in events}
192 | b = pd.concat([patients_df.dot(real_coef_dict['beta'][j]) for j in events], axis=1, keys=events)
193 |
194 | hazards = [pd.concat([expit(a_t[j][t] + b[j]) for t in range(1, d_times + 1)],
195 | axis=1, keys=(range(1, d_times + 1))) for j in events]
196 | surv_func = pd.concat([pd.Series(1, index=hazards[0].index),
197 | (1 - sum(hazards)).cumprod(axis=1).iloc[:, :-1]], axis=1)
198 | surv_func.columns += 1
199 |
200 | probs = [hazard * surv_func for hazard in hazards]
201 | sums = [prob.sum(axis=1) for prob in probs]
202 | probjt = [prob.div(sumj, axis=0) for prob, sumj in zip(probs, sums)]
203 |
204 | ret = calculate_jt(sums, probjt, d_times, j_events)
205 | return ret
206 |
207 |
208 | def generate_quick_start_df(n_patients=10000, d_times=30, j_events=2, n_cov=5, seed=0, pid_col='pid',
209 | real_coef_dict: dict = DEFAULT_REAL_COEF_DICT, sampling_logic=new_sample_logic,
210 | censoring_prob=1.):
211 | np.random.seed(seed)
212 | assert real_coef_dict is not None, "The user should supply the coefficients of the experiment"
213 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
214 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]),
215 | columns=covariates)
216 | sampled = sampling_logic(patients_df, j_events, d_times, real_coef_dict)
217 | patients_df = pd.concat([patients_df, sampled], axis=1)
218 | patients_df.index.name = pid_col
219 | patients_df['C'] = np.where(np.random.rand(n_patients) < censoring_prob,
220 | np.random.randint(low=1, high=d_times+1,
221 | size=n_patients), d_times+1)
222 | patients_df['X'] = patients_df[['T', 'C']].min(axis=1)
223 | patients_df.loc[patients_df['C'] < patients_df['T'], 'J'] = 0
224 | return patients_df.reset_index()
225 |
226 |
227 | if __name__ == "__main__":
228 | main(N_patients=50000)
229 | #generate_quick_start_df(n_patients=2)
230 |
--------------------------------------------------------------------------------
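generate_quick_start_df is the usual entry point of this module; a minimal sketch, relying on DEFAULT_REAL_COEF_DICT from simulations_data_config.py:

```python
from pydts.examples_utils.generate_simulations_data import generate_quick_start_df

df = generate_quick_start_df(n_patients=500, n_cov=5, d_times=30, j_events=2, seed=1)
# Expected columns: pid, Z1..Z5, J (event type; 0 = censored),
# T (event time), C (censoring time), X = min(T, C).
print(df.head())
```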
/src/pydts/examples_utils/mimic_consts.py:
--------------------------------------------------------------------------------
1 |
2 | ADMISSION_TIME_COL = 'admittime'
3 | DISCHARGE_TIME_COL = 'dischtime'
4 | DEATH_TIME_COL = 'deathtime'
5 | ED_REG_TIME = 'edregtime'
6 | ED_OUT_TIME = 'edouttime'
7 | AGE_COL = 'anchor_age'
8 | GENDER_COL = 'gender'
9 |
10 | AGE_BINS = list(range(0, 125, 5))
11 | AGE_LABELS = [f'{AGE_BINS[a]}' for a in range(len(AGE_BINS) - 1)]
12 |
13 | font_sz = 14
14 | title_sz = 18
15 |
16 | YEAR_GROUP_COL = 'anchor_year_group'
17 | SUBSET_YEAR_GROUP = '2017 - 2019'
18 | SUBJECT_ID_COL = 'subject_id'
19 | ADMISSION_ID_COL = 'hadm_id'
20 | ADMISSION_TYPE_COL = 'admission_type'
21 | CHART_TIME_COL = 'charttime'
22 | STORE_TIME_COL = 'storetime'
23 | LOS_EXACT_COL = 'LOS exact'
24 | LOS_DAYS_COL = 'LOS days'
25 | ADMISSION_LOCATION_COL = 'admission_location'
26 | DISCHARGE_LOCATION_COL = 'discharge_location'
27 | RACE_COL = 'race'
28 | INSURANCE_COL = 'insurance'
29 | ADMISSION_TO_RESULT_COL = 'admission_to_result_time'
30 | ADMISSION_AGE_COL = 'admission_age'
31 | ADMISSION_YEAR_COL = 'admission_year'
32 | ADMISSION_COUNT_COL = 'admissions_count'
33 | ITEM_ID_COL = 'itemid'
34 | NIGHT_ADMISSION_FLAG = 'night_admission'
35 | MARITAL_STATUS_COL = 'marital_status'
36 | STANDARDIZED_AGE_COL = 'standardized_age'
37 | COEF_COL = ' coef '
38 | STDERR_COL = ' std err '
39 | DIRECT_IND_COL = 'direct_emrgency_flag'
40 | PREV_ADMISSION_IND_COL = 'last_less_than_diff'
41 | ADMISSION_COUNT_GROUP_COL = ADMISSION_COUNT_COL + '_group'
42 |
43 | DISCHARGE_REGROUPING_DICT = {
44 | 'HOME': 'HOME',
45 | 'HOME HEALTH CARE': 'HOME',
46 | 'SKILLED NURSING FACILITY': 'FURTHER TREATMENT',
47 | 'DIED': 'DIED',
48 | 'REHAB': 'HOME',
49 | 'CHRONIC/LONG TERM ACUTE CARE': 'FURTHER TREATMENT',
50 | 'HOSPICE': 'FURTHER TREATMENT',
51 | 'AGAINST ADVICE': 'CENSORED',
52 | 'ACUTE HOSPITAL': 'FURTHER TREATMENT',
53 | 'PSYCH FACILITY': 'FURTHER TREATMENT',
54 | 'OTHER FACILITY': 'FURTHER TREATMENT',
55 | 'ASSISTED LIVING': 'HOME',
56 | 'HEALTHCARE FACILITY': 'FURTHER TREATMENT',
57 | }
58 |
59 | RACE_REGROUPING_DICT = {
60 | 'WHITE': 'WHITE',
61 | 'UNKNOWN': 'OTHER',
62 | 'BLACK/AFRICAN AMERICAN': 'BLACK',
63 | 'OTHER': 'OTHER',
64 | 'ASIAN': 'ASIAN',
65 | 'WHITE - OTHER EUROPEAN': 'WHITE',
66 | 'HISPANIC/LATINO - PUERTO RICAN': 'HISPANIC',
67 | 'HISPANIC/LATINO - DOMINICAN': 'HISPANIC',
68 | 'ASIAN - CHINESE': 'ASIAN',
69 | 'BLACK/CARIBBEAN ISLAND': 'BLACK',
70 | 'BLACK/AFRICAN': 'BLACK',
71 | 'BLACK/CAPE VERDEAN': 'BLACK',
72 | 'PATIENT DECLINED TO ANSWER': 'OTHER',
73 | 'WHITE - BRAZILIAN': 'WHITE',
74 | 'PORTUGUESE': 'HISPANIC',
75 | 'ASIAN - SOUTH EAST ASIAN': 'ASIAN',
76 | 'WHITE - RUSSIAN': 'WHITE',
77 | 'ASIAN - ASIAN INDIAN': 'ASIAN',
78 | 'WHITE - EASTERN EUROPEAN': 'WHITE',
79 | 'AMERICAN INDIAN/ALASKA NATIVE': 'OTHER',
80 | 'HISPANIC/LATINO - GUATEMALAN': 'HISPANIC',
81 | 'HISPANIC/LATINO - MEXICAN': 'HISPANIC',
82 | 'HISPANIC/LATINO - SALVADORAN': 'HISPANIC',
83 | 'SOUTH AMERICAN': 'HISPANIC',
84 | 'NATIVE HAWAIIAN OR OTHER PACIFIC ISLANDER': 'OTHER',
85 | 'HISPANIC/LATINO - COLUMBIAN': 'HISPANIC',
86 | 'HISPANIC/LATINO - CUBAN': 'HISPANIC',
87 | 'ASIAN - KOREAN': 'ASIAN',
88 | 'HISPANIC/LATINO - HONDURAN': 'HISPANIC',
89 | 'HISPANIC/LATINO - CENTRAL AMERICAN': 'HISPANIC',
90 | 'UNABLE TO OBTAIN': 'OTHER',
91 | 'HISPANIC OR LATINO': 'HISPANIC'
92 | }
93 |
94 | # 'MCH': 'Mean Cell Hemoglobin',
95 | # 'MCHC': 'Mean Cell Hemoglobin Concentration',
96 | table1_rename_columns = {
97 | 'AnionGap': 'Anion gap',
98 | 'Bicarbonate': 'Bicarbonate',
99 | 'CalciumTotal': 'Calcium total',
100 | 'Chloride': 'Chloride',
101 | 'Creatinine': 'Creatinine',
102 | 'Glucose': 'Glucose',
103 | 'Magnesium': 'Magnesium',
104 | 'Phosphate': 'Phosphate',
105 | 'Potassium': 'Potassium',
106 | 'Sodium': 'Sodium',
107 | 'UreaNitrogen': 'Urea nitrogen',
108 | 'Hematocrit': 'Hematocrit',
109 | 'Hemoglobin': 'Hemoglobin',
110 | 'MCH': 'MCH',
111 | 'MCHC': 'MCHC',
112 | 'MCV': 'MCV',
113 | 'PlateletCount': 'Platelet count',
114 | 'RDW': 'RDW',
115 | 'RedBloodCells': 'Red blood cells',
116 | 'WhiteBloodCells': 'White blood cells',
117 | NIGHT_ADMISSION_FLAG: 'Night admission',
118 | GENDER_COL: 'Sex',
119 | DIRECT_IND_COL: 'Direct emergency',
120 | PREV_ADMISSION_IND_COL: 'Previous admission this month',
121 | ADMISSION_AGE_COL: 'Admission age',
122 | INSURANCE_COL: 'Insurance',
123 | MARITAL_STATUS_COL: 'Marital status',
124 | RACE_COL: 'Race',
125 | ADMISSION_COUNT_GROUP_COL: 'Admissions number',
126 | LOS_DAYS_COL: 'LOS (days)',
127 | DISCHARGE_LOCATION_COL: 'Discharge location'
128 | }
129 |
130 | table1_rename_sex = {0: 'Male', 1: 'Female'}
131 | table1_rename_race = {'ASIAN': 'Asian', 'BLACK': 'Black', 'HISPANIC': 'Hispanic', 'OTHER': 'Other',
132 | 'WHITE': 'White'}
133 | table1_rename_marital = {'SINGLE': 'Single', 'MARRIED': 'Married', 'DIVORCED': 'Divorced', 'WIDOWED': 'Widowed'}
134 | table1_rename_yes_no = {0: 'No', 1: 'Yes'}
135 | table1_rename_normal_abnormal = {0: 'Normal', 1: 'Abnormal'}
136 | table1_rename_discharge = {1: 'Home', 2: 'Further Treatment', 3: 'Died', 0: 'Censored'}
137 |
138 |
139 |
140 |
141 | rename_beta_index = {
142 | 'AdmsCount 2': 'Admissions number 2',
143 | 'AdmsCount 3up': 'Admissions number 3+',
144 | 'AnionGap': 'Anion gap',
145 | 'Bicarbonate': 'Bicarbonate',
146 | 'CalciumTotal': 'Calcium total',
147 | 'Chloride': 'Chloride',
148 | 'Creatinine': 'Creatinine',
149 | 'Ethnicity BLACK': 'Ethnicity black',
150 | 'Ethnicity HISPANIC': 'Ethnicity hispanic',
151 | 'Ethnicity OTHER': 'Ethnicity other',
152 | 'Ethnicity WHITE': 'Ethnicity white',
153 | 'Glucose': 'Glucose',
154 | 'Hematocrit': 'Hematocrit',
155 | 'Hemoglobin': 'Hemoglobin',
156 | 'Insurance Medicare': 'Insurance medicare',
157 | 'Insurance Other': 'Insurance other',
158 | 'MCH': 'MCH',
159 | 'MCHC': 'MCHC',
160 | 'MCV': 'MCV',
161 | 'Magnesium': 'Magnesium',
162 | 'Marital MARRIED': 'Marital married',
163 | 'Marital SINGLE': 'Marital single',
164 | 'Marital WIDOWED': 'Marital widowed',
165 | 'Phosphate': 'Phosphate',
166 | 'PlateletCount': 'Platelet count',
167 | 'Potassium': 'Potassium',
168 | 'RDW': 'RDW',
169 | 'RedBloodCells': 'Red blood cells',
170 | 'Sodium': 'Sodium',
171 | 'UreaNitrogen': 'Urea nitrogen',
172 | 'WhiteBloodCells': 'White blood cells',
173 | 'direct emrgency flag': 'Direct emergency',
174 | 'gender': 'Sex',
175 | 'last less than diff': 'Recent admission',
176 | 'night admission': 'Night admission',
177 | 'standardized age': 'Standardized age',
178 | }
179 |
180 | beta_units = {
181 | 'Admissions number 2': '2',
182 | 'Admissions number 3+': '3+',
183 | 'Anion gap': 'Abnormal',
184 | 'Bicarbonate': 'Abnormal',
185 | 'Calcium total': 'Abnormal',
186 | 'Chloride': 'Abnormal',
187 | 'Creatinine': 'Abnormal',
188 | 'Ethnicity black': 'Black',
189 | 'Ethnicity hispanic': 'Hispanic',
190 | 'Ethnicity other': 'Other',
191 | 'Ethnicity white': 'White',
192 | 'Glucose': 'Abnormal',
193 | 'Hematocrit': 'Abnormal',
194 | 'Hemoglobin': 'Abnormal',
195 | 'Insurance medicare': 'Medicare',
196 | 'Insurance other': 'Other',
197 | 'MCH': 'Abnormal',
198 | 'MCHC': 'Abnormal',
199 | 'MCV': 'Abnormal',
200 | 'Magnesium': 'Abnormal',
201 | 'Marital married': 'Married',
202 | 'Marital single': 'Single',
203 | 'Marital widowed': 'Widowed',
204 | 'Phosphate': 'Abnormal',
205 | 'Platelet count': 'Abnormal',
206 | 'Potassium': 'Abnormal',
207 | 'RDW': 'Abnormal',
208 | 'Red blood cells': 'Abnormal',
209 | 'Sodium': 'Abnormal',
210 | 'Urea nitrogen': 'Abnormal',
211 | 'White blood cells': 'Abnormal',
212 | 'Direct emergency': 'Yes',
213 | 'Sex': 'Female',
214 | 'Recent admission': 'Yes',
215 | 'Night admission': 'Yes',
216 | 'Standardized age': '',
217 | }
218 |
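219 | 
220 | # A minimal, self-contained sketch of how the renaming dictionaries above are
221 | # meant to be used when formatting a coefficients table. The two coefficient
222 | # entries below are made-up values for illustration only.
223 | if __name__ == '__main__':
224 |     import pandas as pd
225 | 
226 |     betas = pd.Series({'AnionGap': 0.12, 'gender': -0.05}, name='coef')
227 |     betas = betas.rename(index=rename_beta_index).to_frame()
228 |     # Attach the unit/level each coefficient refers to (e.g. 'Abnormal', 'Female'):
229 |     betas['unit'] = betas.index.map(beta_units)
230 |     print(betas)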
--------------------------------------------------------------------------------
/src/pydts/examples_utils/simulations_data_config.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 | DEFAULT_N_PATIENTS = 10000
5 |
6 | min_year = 2000
7 | max_year = 2015
8 |
9 | bmi_coef = 0.003
10 | age_coef = 0.002
11 | gender_coef = -0.15
12 | smk_coef = 0.1
13 | max_p = 0.65
14 |
15 | PATIENT_NO_COL = 'ID'
16 | AGE_COL = 'Age'
17 | GENDER_COL = 'Sex'
18 | ADMISSION_YEAR_COL = 'Admyear'
19 | FIRST_ADMISSION_COL = 'Firstadm'
20 | ADMISSION_SERIAL_COL = 'Admserial'
21 | WEIGHT_COL = 'Weight'
22 | HEIGHT_COL = 'Height'
23 | BMI_COL = 'BMI'
24 | SMOKING_COL = 'Smoking'
25 | HYPERTENSION_COL = 'Hypertension'
26 | DIABETES_COL = 'Diabetes'
27 | ART_FIB_COL = 'AF' # Atrial Fibrillation
28 | COPD_COL = 'COPD' # Chronic Obstructive Pulmonary Disease
29 | CRF_COL = 'CRF' # Chronic Renal Failure
30 | IN_HOSPITAL_DEATH_COL = 'In_hospital_death'
31 | DISCHARGE_RELATIVE_COL = 'Discharge_relative_date'
32 | DEATH_RELATIVE_COL = 'Death_relative_date_in_hosp'
33 | DEATH_MISSING_COL = 'Death_date_in_hosp_missing'
34 | RETURNING_PATIENT_COL = 'Returning_patient'
35 |
36 | COEFS = pd.Series({
37 | AGE_COL: 0.1,
38 | GENDER_COL: -0.1,
39 | BMI_COL: 0.2,
40 | SMOKING_COL: 0.2,
41 | HYPERTENSION_COL: 0.2,
42 | DIABETES_COL: 0.2,
43 | ART_FIB_COL: 0.2,
44 | COPD_COL: 0.2,
45 | CRF_COL: 0.2,
46 | })
47 |
48 | ADMISSION_SERIAL_BINS = [0, 1.5, 4.5, 8.5, 21]
49 | ADMISSION_SERIAL_LABELS = [0, 1, 2, 3]
50 | SIMULATED_DATA_FILENAME = 'simulated_data.csv'
51 |
52 | preconditions = [SMOKING_COL, HYPERTENSION_COL, DIABETES_COL, ART_FIB_COL, COPD_COL, CRF_COL]
53 | font_sz = 14
54 | title_sz = 18
55 | AGE_BINS = list(range(0, 125, 5))
56 | AGE_LABELS = [f'{AGE_BINS[a]}' for a in range(len(AGE_BINS)-1)]
57 |
58 | DEFAULT_REAL_COEF_DICT = {
59 | "alpha": {
60 | 1: lambda t: -1 - 0.3 * np.log(t),
61 | 2: lambda t: -1.75 - 0.15 * np.log(t)
62 | },
63 | "beta": {
64 | 1: -np.log([0.8, 3, 3, 2.5, 2]),
65 | 2: -np.log([1, 3, 4, 3, 2])
66 | }
67 | }
68 |
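69 | 
70 | # A minimal sketch of how DEFAULT_REAL_COEF_DICT parameterizes the
71 | # discrete-time hazard, lambda_j(t | Z) = expit(alpha_j(t) + Z @ beta_j),
72 | # matching pydts.utils.get_real_hazard. The covariate vector below is an
73 | # illustrative assumption.
74 | if __name__ == '__main__':
75 |     from scipy.special import expit
76 | 
77 |     z = np.array([0.2, 0.5, 0.1, 0.9, 0.3])  # one subject's five covariates
78 |     for j in (1, 2):
79 |         alpha_jt = DEFAULT_REAL_COEF_DICT["alpha"][j](3)  # time-dependent intercept at t = 3
80 |         hazard = expit(alpha_jt + z @ DEFAULT_REAL_COEF_DICT["beta"][j])
81 |         print(f"lambda_{j}(3 | z) = {hazard:.4f}")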
--------------------------------------------------------------------------------
/src/pydts/model_selection.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from itertools import product
4 | from .fitters import TwoStagesFitter, TwoStagesFitterExact
5 | import warnings
6 | from copy import deepcopy
7 | pd.set_option("display.max_rows", 500)
8 | warnings.filterwarnings('ignore')
9 | from time import time
10 | from typing import Union
11 | slicer = pd.IndexSlice
12 | from .evaluation import events_integrated_brier_score, global_brier_score, events_integrated_auc, global_auc
13 |
14 |
15 | class BasePenaltyGridSearch(object):
16 |
17 | """ This class implements the penalty parameter grid search. """
18 |
19 | def __init__(self):
20 | self.l1_ratio = None
21 | self.penalizers = []
22 | self.seed = None
23 | self.meta_models = {}
24 | self.train_df = None
25 | self.test_df = None
26 | self.global_auc = {}
27 | self.integrated_auc = {}
28 | self.global_bs = {}
29 | self.integrated_bs = {}
30 | self.TwoStagesFitter_type = 'CoxPHFitter'
31 |
32 | def evaluate(self,
33 | train_df: pd.DataFrame,
34 | test_df: pd.DataFrame,
35 | l1_ratio: float,
36 | penalizers: list,
37 | metrics: Union[list, str] = ['IBS', 'GBS', 'IAUC', 'GAUC'],
38 | seed: Union[None, int] = None,
39 | event_type_col: str = 'J',
40 | duration_col: str = 'X',
41 | pid_col: str = 'pid',
42 | twostages_fit_kwargs: dict = {}) -> tuple:
43 |
44 | """
45 |         This function fits a model to train_df for each of the given penalizers, and evaluates the requested metrics on test_df for all possible per-event combinations of penalizers.
46 |
47 | Args:
48 | train_df (pd.DataFrame): training data for fitting the model.
49 | test_df (pd.DataFrame): testing data for evaluating the estimated model's performance.
50 | l1_ratio (float): regularization ratio for the CoxPHFitter (see lifelines.fitters.coxph_fitter.CoxPHFitter documentation).
51 | penalizers (list): penalizer options for each event (see lifelines.fitters.coxph_fitter.CoxPHFitter documentation).
52 |             metrics (str, list): Evaluation metrics to compute; any subset of ['IBS', 'GBS', 'IAUC', 'GAUC'].
53 | seed (int): pseudo random seed number for numpy.random.seed()
54 | event_type_col (str): The event type column name (must be a column in df), Right-censored sample (i) is indicated by event value 0, df.loc[i, event_type_col] = 0.
55 | duration_col (str): Last follow up time column name (must be a column in df).
56 | pid_col (str): Sample ID column name (must be a column in df).
57 | twostages_fit_kwargs (dict): keyword arguments to pass to the TwoStagesFitter.
58 |
59 | Returns:
60 |             output (Tuple): The per-event penalizers combination achieving the best Global-AUC, if 'GAUC' is in metrics; an empty list otherwise.
61 |
62 | """
63 |
64 | self.l1_ratio = l1_ratio
65 | self.penalizers = penalizers
66 | self.seed = seed
67 | np.random.seed(seed)
68 |
69 | for idp, penalizer in enumerate(penalizers):
70 |
71 | fit_beta_kwargs = self._get_model_fit_kwargs(penalizer, l1_ratio)
72 |
73 | if self.TwoStagesFitter_type == 'Exact':
74 | self.meta_models[penalizer] = TwoStagesFitterExact()
75 | else:
76 | self.meta_models[penalizer] = TwoStagesFitter()
77 | print(f"Started estimating the coefficients for penalizer {penalizer} ({idp+1}/{len(penalizers)})")
78 | start = time()
79 | self.meta_models[penalizer].fit(df=train_df, fit_beta_kwargs=fit_beta_kwargs,
80 | pid_col=pid_col, event_type_col=event_type_col, duration_col=duration_col,
81 | **twostages_fit_kwargs)
82 | end = time()
83 | print(f"Finished estimating the coefficients for penalizer {penalizer} ({idp+1}/{len(penalizers)}), {int(end - start)} seconds")
84 |
85 | events = [j for j in sorted(train_df[event_type_col].unique()) if j != 0]
86 | grid = [penalizers for e in events]
87 | penalizers_combinations = list(product(*grid))
88 |
89 | for idc, combination in enumerate(penalizers_combinations):
90 | mixed_two_stages = self.get_mixed_two_stages_fitter(combination)
91 |
92 | pred_df = mixed_two_stages.predict_prob_events(test_df)
93 |
94 | for metric in metrics:
95 | if metric == 'IAUC':
96 | self.integrated_auc[combination] = events_integrated_auc(pred_df, event_type_col=event_type_col,
97 | duration_col=duration_col)
98 | elif metric == 'GAUC':
99 | self.global_auc[combination] = global_auc(pred_df, event_type_col=event_type_col,
100 | duration_col=duration_col)
101 | elif metric == 'IBS':
102 | self.integrated_bs[combination] = events_integrated_brier_score(pred_df,
103 | event_type_col=event_type_col,
104 | duration_col=duration_col)
105 | elif metric == 'GBS':
106 | self.global_bs[combination] = global_brier_score(pred_df, event_type_col=event_type_col,
107 | duration_col=duration_col)
108 |
109 | output = self.convert_results_dict_to_df(self.global_auc).idxmax().values[0] if 'GAUC' in metrics else []
110 | return output
111 |
112 | def convert_results_dict_to_df(self, results_dict):
113 | """
114 | This function converts a results dictionary to a pd.DataFrame format.
115 |
116 | Args:
117 | results_dict: one of the class attributes: global_auc, integrated_auc, global_bs, integrated_bs.
118 |
119 | Returns:
120 | df (pd.DataFrame): Results in a pd.DataFrame format.
121 | """
122 | df = pd.DataFrame(results_dict.values(), index=pd.MultiIndex.from_tuples(results_dict.keys()))
123 | return df
124 |
125 | def get_mixed_two_stages_fitter(self, penalizers_combination: list) -> TwoStagesFitter:
126 | """
127 | This function creates a mixed TwoStagesFitter from the estimated meta models for a specific penalizers combination.
128 |
129 | Args:
130 | penalizers_combination (list): List with length equals to the number of competing events. The penalizers value to each of the events. Each of the values must be one of the values that was previously passed to the evaluate() method.
131 |
132 | Returns:
133 | mixed_two_stages (pydts.fitters.TwoStagesFitter): TwoStagesFitter for the required penalty combination.
134 | """
135 | _validate_estimated_value = [p for p in penalizers_combination if p not in list(self.meta_models.keys())]
136 | assert len(_validate_estimated_value) == 0, \
137 | f"Values {_validate_estimated_value} were note estimated. All the penalizers in penalizers_combination must be estimated using evaluate() before a mixed combination can be generated."
138 |
139 | events = self.meta_models[penalizers_combination[0]].events
140 | event_type_col = self.meta_models[penalizers_combination[0]].event_type_col
141 | if self.TwoStagesFitter_type == 'Exact':
142 | mixed_two_stages = TwoStagesFitterExact()
143 | else:
144 | mixed_two_stages = TwoStagesFitter()
145 |
146 | for ide, event in enumerate(sorted(events)):
147 | if ide == 0:
148 | mixed_two_stages.covariates = self.meta_models[penalizers_combination[ide]].covariates
149 | mixed_two_stages.duration_col = self.meta_models[penalizers_combination[ide]].duration_col
150 | mixed_two_stages.event_type_col = self.meta_models[penalizers_combination[ide]].event_type_col
151 | mixed_two_stages.events = self.meta_models[penalizers_combination[ide]].events
152 | mixed_two_stages.pid_col = self.meta_models[penalizers_combination[ide]].pid_col
153 | mixed_two_stages.times = self.meta_models[penalizers_combination[ide]].times
154 |
155 | mixed_two_stages.beta_models[event] = self.meta_models[penalizers_combination[ide]].beta_models[event]
156 | mixed_two_stages.event_models[event] = []
157 | mixed_two_stages.event_models[event].append(self.meta_models[penalizers_combination[ide]].beta_models[event])
158 |
159 | event_alpha = self.meta_models[penalizers_combination[ide]].alpha_df.copy()
160 | event_alpha = event_alpha[event_alpha[event_type_col] == event]
161 | mixed_two_stages.alpha_df = pd.concat([mixed_two_stages.alpha_df, event_alpha])
162 | mixed_two_stages.event_models[event].append(event_alpha)
163 |
164 | return mixed_two_stages
165 |
166 | def _get_model_fit_kwargs(self, penalizer, l1_ratio):
167 | if self.TwoStagesFitter_type == 'Exact':
168 | fit_beta_kwargs = {
169 | 'model_fit_kwargs': {
170 | 'alpha': penalizer,
171 | 'L1_wt': l1_ratio
172 | }
173 | }
174 | else:
175 | fit_beta_kwargs = {
176 | 'model_kwargs': {
177 | 'penalizer': penalizer,
178 | 'l1_ratio': l1_ratio
179 | },
180 | }
181 | return fit_beta_kwargs
182 |
183 |
184 | class PenaltyGridSearch(BasePenaltyGridSearch):
185 |
186 | def __init__(self):
187 | super().__init__()
188 | self.TwoStagesFitter_type = 'CoxPHFitter'
189 |
190 |
191 | class PenaltyGridSearchExact(BasePenaltyGridSearch):
192 |
193 | def __init__(self):
194 | super().__init__()
195 | self.TwoStagesFitter_type = 'Exact'
196 |
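197 | 
198 | # A minimal usage sketch, not part of the library API (run with
199 | # `python -m pydts.model_selection`, assuming the package is installed as
200 | # `pydts`). The simulated-data helper and the coefficient dictionary are the
201 | # ones used in this package's tests; the penalizer grid and the train/test
202 | # split are illustrative assumptions.
203 | if __name__ == '__main__':
204 |     from pydts.examples_utils.generate_simulations_data import generate_quick_start_df
205 | 
206 |     real_coef_dict = {
207 |         "alpha": {1: lambda t: -1 - 0.3 * np.log(t),
208 |                   2: lambda t: -1.75 - 0.15 * np.log(t)},
209 |         "beta": {1: -np.log([0.8, 3, 3, 2.5, 2]),
210 |                  2: -np.log([1, 3, 4, 3, 2])},
211 |     }
212 |     df = generate_quick_start_df(n_patients=2000, n_cov=5, d_times=10, j_events=2,
213 |                                  pid_col='pid', seed=0, real_coef_dict=real_coef_dict,
214 |                                  censoring_prob=0.8).drop(columns=['C', 'T'])
215 |     train_df, test_df = df.iloc[:1500], df.iloc[1500:]
216 | 
217 |     pgs = PenaltyGridSearch()
218 |     # evaluate() fits one TwoStagesFitter per penalizer, then scores every
219 |     # per-event penalizer combination on test_df; with 'GAUC' requested it
220 |     # returns the combination with the highest Global-AUC.
221 |     best = pgs.evaluate(train_df, test_df, l1_ratio=1.0,
222 |                         penalizers=[0.001, 0.01], metrics=['GAUC'])
223 |     mixed_fitter = pgs.get_mixed_two_stages_fitter(list(best))
224 |     print(best)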
--------------------------------------------------------------------------------
/src/pydts/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, Optional
2 |
3 | import pandas as pd
4 | import numpy as np
5 | from scipy.special import expit
6 |
7 |
26 | def get_expanded_df(
27 | df: pd.DataFrame,
28 | event_type_col: str = 'J',
29 | duration_col: str = 'X',
30 | pid_col: str = 'pid') -> pd.DataFrame:
31 | """
32 |     Expands a discrete-time survival dataset into a long-format dataframe suitable for modeling, as described in [1]-[2]. This function receives a dataframe where each row corresponds to a subject with an observed event type and duration. It returns an expanded dataframe where each subject is represented by multiple rows, one for each time point up to their observed time. Right censoring is allowed and should be indicated by event type 0.
33 |
34 | Args:
35 | df (pd.DataFrame): Original input dataframe containing one row per subject.
36 | event_type_col (str): Name of the column indicating event type. Censoring is marked by 0.
37 | duration_col (str): Name of the column indicating event or censoring time.
38 | pid_col (str): Name of the column indicating subject/patient ID.
39 |
40 | Returns:
41 |         pd.DataFrame: Expanded dataframe in long format, with one row per subject-time pair.
42 | 
43 |     References:
44 |         [1] Meir, Tomer and Gorfine, Malka, "Discrete-time Competing-Risks Regression with or without Penalization", https://arxiv.org/abs/2303.01186
45 |         [2] Meir, Tomer and Gutman, Rom and Gorfine, Malka, "PyDTS: A Python Package for Discrete-Time Survival (Regularized) Regression with Competing Risks", https://arxiv.org/abs/2204.05731
46 |     """
43 | unique_times = df[duration_col].sort_values().unique()
44 | result_df = df.reindex(df.index.repeat(df[duration_col]))
45 | result_df[duration_col] = result_df.groupby(pid_col).cumcount() + 1
46 | # drop times that didn't happen
47 | result_df.drop(index=result_df.loc[~result_df[duration_col].isin(unique_times)].index, inplace=True)
48 | result_df.reset_index(drop=True, inplace=True)
49 | last_idx = result_df.drop_duplicates(subset=[pid_col], keep='last').index
50 | events = sorted(df[event_type_col].unique())
51 | result_df.loc[last_idx, [f'j_{e}' for e in events]] = pd.get_dummies(
52 | result_df.loc[last_idx, event_type_col]).values
53 | result_df[[f'j_{e}' for e in events]] = result_df[[f'j_{e}' for e in events]].fillna(0)
54 |     result_df['j_0'] = 1 - result_df[[f'j_{e}' for e in events if e > 0]].sum(axis=1)
55 | return result_df
56 |
57 |
58 | def compare_models_coef_per_event(first_model: pd.Series,
59 | second_model: pd.Series,
60 |                                   real_values: np.ndarray,
61 | event: int,
62 | first_model_label:str = "first",
63 | second_model_label:str = "second"
64 | ) -> pd.DataFrame:
65 | event_suffix = f"_{event}"
66 |     assert (first_model.index == second_model.index).all(), "Both models must share the same index"
67 | models = pd.concat([first_model.to_frame(first_model_label),
68 | second_model.to_frame(second_model_label)], axis=1)
69 | models.index += event_suffix
70 | real_values_s = pd.Series(real_values, index=models.index)
71 |
72 | return pd.concat([models, real_values_s.to_frame("real")], axis=1)
73 |
74 |
75 | def present_coefs(res_dict):
76 | from IPython.display import display
77 | for coef_type, events_dict in res_dict.items():
78 | print(f"for coef: {coef_type.capitalize()}")
79 | df = pd.concat([temp_df for temp_df in events_dict.values()])
80 | display(df)
81 |
82 |
83 | def get_real_hazard(df, real_coef_dict, times, events):
84 | a_t = {event: {t: real_coef_dict['alpha'][event](t) for t in times} for event in events}
85 | b = pd.concat([df.dot(real_coef_dict['beta'][j]) for j in events], axis=1, keys=events)
86 |
87 | for j in events:
88 | df[[f'hazard_j{j}_t{t}' for t in times]] = pd.concat([expit(a_t[j][t] + b[j]) for t in times],
89 | axis=1).values
90 | return df
91 |
92 |
93 | def assert_fit(event_df, times, event_type_col='J', duration_col='X'):
94 |     if not event_df['success'].all():
95 |         problematic_times = event_df.loc[~event_df['success'], duration_col].tolist()
96 |         event = event_df[event_type_col].max()  # all the events in the dataframe are the same
97 |         raise RuntimeError(f"Number of observed events of type {event} at time points {problematic_times} is too small. Consider collapsing neighboring time points."
98 |                            f"\n See https://tomer1812.github.io/pydts/UsageExample-RegroupingData/ for more details.")
99 |     if not np.all([t in event_df[duration_col].values for t in times]):
100 |         event = event_df[event_type_col].max()  # all the events in the dataframe are the same
101 |         problematic_times = pd.Index(event_df[duration_col]).symmetric_difference(times).tolist()
102 |         raise RuntimeError(f"No observed events of type {event} at time points {problematic_times}. Consider collapsing neighboring time points."
103 |                            f"\n See https://tomer1812.github.io/pydts/UsageExample-RegroupingData/ for more details.")
104 |
105 |
106 | def create_df_for_cif_plots(df: pd.DataFrame, field: str,
107 | covariates: Iterable,
108 | vals: Optional[Iterable] = None,
109 | quantiles: Optional[Iterable] = None,
110 | zero_others: Optional[bool] = False
111 | ) -> pd.DataFrame:
112 | """
113 | This method creates df for cif plot, where it zeros
114 |
115 | Args:
116 | df (pd.DataFrame): Dataframe which we yield the statiscal propetrics (means, quantiles, etc) and stacture
117 | field (str): The field which will represent the change
118 | covariates (Iterable): The covariates of the given model
119 | vals (Optional[Iterable]): The values to use for the field
120 | quantiles (Optional[Iterable]): The quantiles to use as values for the field
121 | zero_others (bool): Whether to zero the other covarites or to zero them
122 |
123 | Returns:
124 | df (pd.DataFrame): A dataframe that contains records per value for cif ploting
125 | """
126 |
127 | cov_not_fitted = [cov for cov in covariates if cov not in df.columns]
128 | assert len(cov_not_fitted) == 0, \
129 | f"Required covariates are missing from df: {cov_not_fitted}"
130 |
131 |     df_for_plotting = df.copy()
132 |     if vals is not None:
133 |         pass
134 |     elif quantiles is not None:
135 |         vals = df_for_plotting[field].quantile(quantiles).values
136 |     else:
137 |         raise NotImplementedError("Only quantiles or specific values are supported")
138 |     temp_series = []
139 |     template_s = df_for_plotting.iloc[0][covariates].copy()
140 |     if zero_others:
141 |         impute_val = 0
142 |     else:
143 |         impute_val = df_for_plotting[covariates].mean().values
144 | for val in vals:
145 | temp_s = template_s.copy()
146 | temp_s[covariates] = impute_val
147 | temp_s[field] = val
148 | temp_series.append(temp_s)
149 |
150 | return pd.concat(temp_series, axis=1).T
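151 | 
152 | 
153 | # A minimal worked example of the long-format expansion performed by
154 | # get_expanded_df(), followed by create_df_for_cif_plots(); the three-subject
155 | # dataframe is illustrative only.
156 | if __name__ == '__main__':
157 |     toy = pd.DataFrame({'pid': [1, 2, 3],
158 |                         'X': [2, 3, 1],
159 |                         'J': [1, 0, 2],
160 |                         'Z1': [0.5, 0.1, 0.9],
161 |                         'Z2': [1.0, 0.0, 1.0]})
162 |     expanded = get_expanded_df(toy)
163 |     # Subject 1 contributes rows for t=1,2 with j_1 = 1 only at t=2; subject 2
164 |     # (censored, J=0) contributes t=1,2,3 with j_0 = 1 throughout; subject 3
165 |     # contributes a single row at t=1 with j_2 = 1.
166 |     print(expanded[['pid', 'X', 'j_0', 'j_1', 'j_2']])
167 | 
168 |     # One record per quantile of Z1, with the other covariates mean-imputed:
169 |     cif_df = create_df_for_cif_plots(toy, field='Z1', covariates=['Z1', 'Z2'],
170 |                                      quantiles=[0.25, 0.75])
171 |     print(cif_df)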
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomer1812/pydts/578929a457c111efe009d3461aab531b793b33d0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_DataExpansionFitter.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from src.pydts.examples_utils.generate_simulations_data import generate_quick_start_df
3 | from src.pydts.fitters import DataExpansionFitter
4 | import numpy as np
5 | from src.pydts.utils import get_real_hazard
6 |
7 |
8 | class TestDataExpansionFitter(unittest.TestCase):
9 | def setUp(self):
10 | self.real_coef_dict = {
11 | "alpha": {
12 | 1: lambda t: -1 - 0.3 * np.log(t),
13 | 2: lambda t: -1.75 - 0.15 * np.log(t)
14 | },
15 | "beta": {
16 | 1: -np.log([0.8, 3, 3, 2.5, 2]),
17 | 2: -np.log([1, 3, 4, 3, 2])
18 | }
19 | }
20 | self.df = generate_quick_start_df(n_patients=1000, n_cov=5, d_times=10, j_events=2, pid_col='pid', seed=0,
21 | real_coef_dict=self.real_coef_dict, censoring_prob=0.8)
22 | self.m = DataExpansionFitter()
23 | self.fitted_model = DataExpansionFitter()
24 | self.fitted_model.fit(df=self.df.drop(['C', 'T'], axis=1))
25 |
26 | def test_fit_case_C_in_df(self):
27 | # 'C' named column cannot be passed in df to .fit()
28 | with self.assertRaises(ValueError):
29 | m = DataExpansionFitter()
30 | m.fit(df=self.df)
31 |
32 | def test_fit_case_event_col_not_in_df(self):
33 | # Event column (here 'J') must be passed in df to .fit()
34 | with self.assertRaises(AssertionError):
35 | self.m.fit(df=self.df.drop(['C', 'J', 'T'], axis=1))
36 |
37 | def test_fit_case_duration_col_not_in_df(self):
38 | # Duration column (here 'X') must be passed in df to .fit()
39 | with self.assertRaises(AssertionError):
40 | self.m.fit(df=self.df.drop(['C', 'X', 'T'], axis=1))
41 |
42 | def test_fit_case_pid_col_not_in_df(self):
43 |         # Sample ID column (here 'pid') must be passed in df to .fit()
44 | with self.assertRaises(AssertionError):
45 | self.m.fit(df=self.df.drop(['C', 'pid', 'T'], axis=1))
46 |
47 | def test_fit_case_cov_col_not_in_df(self):
48 | # Covariates columns (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .fit()
49 | with self.assertRaises(ValueError):
50 | self.m.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=['Z6'])
51 |
52 | def test_fit_case_correct_fit(self):
53 | # Fit should be successful
54 | m = DataExpansionFitter()
55 | m.fit(df=self.df.drop(['C', 'T'], axis=1))
56 |
57 | def test_fit_with_kwargs(self):
58 | import statsmodels.api as sm
59 | m = DataExpansionFitter()
60 | m.fit(df=self.df.drop(columns=['C', 'T']), models_kwargs=dict(family=sm.families.Binomial()))
61 |
62 | def test_print_summary(self):
63 | self.fitted_model.print_summary()
64 |
65 | def test_get_beta_SE(self):
66 | self.fitted_model.get_beta_SE()
67 |
68 |     def test_get_alpha_df(self):
69 | self.fitted_model.get_alpha_df()
70 |
71 | def test_predict_hazard_jt_case_covariate_not_in_df(self):
72 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict()
73 | with self.assertRaises(AssertionError):
74 | self.fitted_model.predict_hazard_jt(
75 | df=self.df.drop(['C', 'T', 'Z1'], axis=1),
76 | event=self.fitted_model.events[0],
77 | t=self.fitted_model.times[0])
78 |
79 | def test_predict_hazard_jt_case_event_not_in_events(self):
80 | # Event passed to .predict() must be in fitted events
81 | with self.assertRaises(AssertionError):
82 | self.fitted_model.predict_hazard_jt(
83 | df=self.df.drop(['C', 'T'], axis=1), event=100, t=self.fitted_model.times[0])
84 |
85 | def test_predict_hazard_jt_case_time_not_in_times(self):
86 |         # Time passed to .predict() must be in fitted times
87 | with self.assertRaises(AssertionError):
88 | self.fitted_model.predict_hazard_jt(
89 | df=self.df.drop(['C', 'T'], axis=1), event=self.fitted_model.events[0], t=1000)
90 |
91 | def test_predict_hazard_jt_case_successful_predict(self):
92 | self.fitted_model.predict_hazard_jt(
93 | df=self.df.drop(['C', 'T'], axis=1),
94 | event=self.fitted_model.events[0], t=self.fitted_model.times[0])
95 |
96 | def test_predict_hazard_t_case_successful_predict(self):
97 | self.fitted_model.predict_hazard_t(df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[:3])
98 |
99 | def test_predict_hazard_all_case_successful_predict(self):
100 | self.fitted_model.predict_hazard_all(df=self.df.drop(['C', 'T'], axis=1))
101 |
102 | def test_predict_overall_survival_case_successful_predict(self):
103 | self.fitted_model.predict_overall_survival(
104 | df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[5], return_hazards=True)
105 |
106 | def test_predict_prob_event_j_at_t_case_successful_predict(self):
107 | self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1),
108 | event=self.fitted_model.events[0],
109 | t=self.fitted_model.times[3])
110 |
111 | def test_predict_prob_event_j_all_case_successful_predict(self):
112 | self.fitted_model.predict_prob_event_j_all(df=self.df.drop(['C', 'T'], axis=1),
113 | event=self.fitted_model.events[0])
114 |
115 | def test_predict_prob_events_case_successful_predict(self):
116 | self.fitted_model.predict_prob_events(df=self.df.drop(['C', 'T'], axis=1))
117 |
118 | def test_predict_event_cumulative_incident_function_case_successful_predict(self):
119 | self.fitted_model.predict_event_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1),
120 | event=self.fitted_model.events[0])
121 |
122 | def test_predict_cumulative_incident_function_case_successful_predict(self):
123 | self.fitted_model.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1))
124 |
125 | def test_predict_hazard_jt_case_hazard_already_on_df(self):
126 | df_temp = get_real_hazard(self.df.drop(['C', 'T', 'X', 'J'], axis=1).set_index('pid').copy(),
127 | real_coef_dict=self.real_coef_dict,
128 | times=self.fitted_model.times,
129 | events=self.fitted_model.events)
130 | assert (df_temp == self.fitted_model.predict_hazard_jt(df=df_temp,
131 | event=self.fitted_model.events[0],
132 | t=self.fitted_model.times
133 | )
134 | ).all().all()
--------------------------------------------------------------------------------
/tests/test_EventTimesSampler.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from src.pydts.data_generation import EventTimesSampler
3 | from src.pydts.fitters import TwoStagesFitter, DataExpansionFitter
4 | from time import time
5 | import numpy as np
6 | import pandas as pd
7 | slicer = pd.IndexSlice
8 | COEF_COL = ' coef '
9 | STDERR_COL = ' std err '
10 |
11 |
12 | class TestEventTimesSampler(unittest.TestCase):
13 |
14 | def test_sample_2_events_5_covariates(self):
15 | n_cov = 5
16 | n_patients = 1000
17 | seed = 0
18 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
19 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]),
20 | columns=covariates)
21 |
22 | ets = EventTimesSampler(d_times=15, j_event_types=2)
23 | real_coef_dict = {
24 | "alpha": {
25 | 1: lambda t: -1 - 0.3 * np.log(t),
26 | 2: lambda t: -1.75 - 0.15 * np.log(t),
27 | },
28 | "beta": {
29 | 1: -np.log([0.8, 3, 3, 2.5, 2]),
30 | 2: -np.log([1, 3, 4, 3, 2]),
31 | }
32 | }
33 | ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
34 |
35 | def test_sample_3_events_4_covariates(self):
36 | n_cov = 4
37 | n_patients = 1000
38 | seed = 0
39 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
40 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]),
41 | columns=covariates)
42 |
43 |         ets = EventTimesSampler(d_times=15, j_event_types=3)  # matches the three events defined below
44 | real_coef_dict = {
45 | "alpha": {
46 | 1: lambda t: -1 - 0.3 * np.log(t),
47 | 2: lambda t: -1.75 - 0.15 * np.log(t),
48 | 3: lambda t: -1.5 - 0.2 * np.log(t),
49 | },
50 | "beta": {
51 | 1: -np.log([0.8, 3, 3, 2.5]),
52 | 2: -np.log([1, 3, 4, 3]),
53 | 3: -np.log([0.5, 2, 3, 4]),
54 | }
55 | }
56 | ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
57 |
58 | def test_sample_hazard_censoring(self):
59 | seed = 0
60 | n_cov = 5
61 | n_patients = 10000
62 | np.random.seed(seed)
63 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
64 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]),
65 | columns=covariates)
66 |
67 | ets = EventTimesSampler(d_times=15, j_event_types=4)
68 | censoring_coef_dict = {
69 | "alpha": {
70 | 0: lambda t: -1 - 0.3 * np.log(t),
71 | },
72 | "beta": {
73 | 0: -np.log([0.8, 3, 3, 2.5, 2]),
74 | }
75 | }
76 | ets.sample_hazard_lof_censoring(patients_df, censoring_coef_dict, seed)
77 |
78 | def test_sample_independent_censoring(self):
79 | n_cov = 4
80 | n_patients = 1000
81 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
82 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]),
83 | columns=covariates)
84 |
85 | d_times = 15
86 | ets = EventTimesSampler(d_times=15, j_event_types=2)
87 | ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.03 * np.ones(d_times))
88 |
89 | def test_update_event_or_lof(self):
90 | n_cov = 5
91 | n_patients = 1000
92 | seed = 0
93 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
94 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]),
95 | columns=covariates)
96 |
97 | ets = EventTimesSampler(d_times=15, j_event_types=3)
98 | real_coef_dict = {
99 | "alpha": {
100 | 1: lambda t: -1 - 0.3 * np.log(t),
101 | 2: lambda t: -1.75 - 0.15 * np.log(t),
102 | 3: lambda t: -1.75 - 0.15 * np.log(t),
103 | },
104 | "beta": {
105 | 1: -np.log([0.8, 3, 3, 2.5, 2]),
106 | 2: -np.log([1, 3, 4, 3, 2]),
107 | 3: -np.log([1, 3, 4, 3, 2]),
108 | }
109 | }
110 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
111 | censoring_coef_dict = {
112 | "alpha": {
113 | 0: lambda t: -1 - 0.3 * np.log(t),
114 | },
115 | "beta": {
116 | 0: -np.log([0.8, 3, 3, 2.5, 2]),
117 | }
118 | }
119 |
120 | patients_df = ets.sample_hazard_lof_censoring(patients_df, censoring_coef_dict, seed=seed)
121 | patients_df = ets.update_event_or_lof(patients_df)
122 |
123 | def test_update_event_or_lof_T_assertion(self):
124 | with self.assertRaises(AssertionError):
125 | seed = 0
126 | n_cov = 5
127 | n_patients = 10000
128 | np.random.seed(seed)
129 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
130 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]),
131 | columns=covariates)
132 |
133 | ets = EventTimesSampler(d_times=15, j_event_types=3)
134 | censoring_coef_dict = {
135 | "alpha": {
136 | 0: lambda t: -1 - 0.3 * np.log(t),
137 | },
138 | "beta": {
139 | 0: -np.log([0.8, 3, 3, 2.5, 2]),
140 | }
141 | }
142 |
143 | patients_df = ets.sample_hazard_lof_censoring(patients_df, censoring_coef_dict, seed=seed)
144 | patients_df = ets.update_event_or_lof(patients_df)
145 |
146 | def test_update_event_or_lof_C_assertion(self):
147 | with self.assertRaises(AssertionError):
148 | n_cov = 5
149 | n_patients = 1000
150 | seed = 0
151 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
152 | patients_df = pd.DataFrame(data=np.random.uniform(low=0.0, high=1.0, size=[n_patients, n_cov]),
153 | columns=covariates)
154 |
155 | ets = EventTimesSampler(d_times=15, j_event_types=3)
156 | real_coef_dict = {
157 | "alpha": {
158 | 1: lambda t: -1 - 0.3 * np.log(t),
159 | 2: lambda t: -1.75 - 0.15 * np.log(t),
160 | 3: lambda t: -1.75 - 0.15 * np.log(t),
161 | },
162 | "beta": {
163 | 1: -np.log([0.8, 3, 3, 2.5, 2]),
164 | 2: -np.log([1, 3, 4, 3, 2]),
165 | 3: -np.log([1, 3, 4, 3, 2]),
166 | }
167 | }
168 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
169 | patients_df = ets.update_event_or_lof(patients_df)
170 |
171 | # def test_sample_and_fit_from_multinormal(self):
172 | # # real_coef_dict = {
173 | # # "alpha": {
174 | # # 1: lambda t: -3 - 3 * np.log(t),
175 | # # 2: lambda t: -3 - 0.15 * np.log(t)
176 | # # },
177 | # # "beta": {
178 | # # # 1: [0.3, -1, -0.5, 0.7, 0.9],
179 | # # # 2: [-0.5, 0.5, 0.7, -0.7, 0.5]
180 | # # 1: -np.log([0.8, 3, 3, 2.5, 2]),
181 | # # 2: -np.log([1, 3, 4, 3, 2])
182 | # # }
183 | # # }
184 | #
185 | # real_coef_dict = {
186 | # "alpha": {
187 | # 1: lambda t: -1.75 - 0.3 * np.log(t),
188 | # 2: lambda t: -1.5 - 0.15 * np.log(t)
189 | # },
190 | # "beta": {
191 | # 1: -0.5*np.log([0.8, 3, 3, 2.5, 2]),
192 | # 2: -0.5*np.log([1, 3, 4, 3, 2])
193 | # }
194 | # }
195 | #
196 | # censoring_hazard_coef_dict = {
197 | # "alpha": {
198 | # 0: lambda t: -1.75 - 0.3 * np.log(t),
199 | # },
200 | # "beta": {
201 | # 0: -0.5*np.log([1, 3, 4, 3, 2]),
202 | # }
203 | # }
204 | #
205 | # n_patients = 15000
206 | # n_cov = 5
207 | # d_times = 50
208 | # j_events = 2
209 | # clip_value = 1
210 | # means_vector = np.zeros(n_cov)
211 | # covariance_matrix = 0.4*np.identity(n_cov)
212 | #
213 | # ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
214 | # seed = 0
215 | # covariates = [f'Z{i + 1}' for i in range(n_cov)]
216 | #
217 | # np.random.seed(seed)
218 | #
219 | # patients_df = pd.DataFrame(data=pd.DataFrame(
220 | # data=np.random.multivariate_normal(means_vector, covariance_matrix, size=n_patients),
221 | # columns=covariates))
222 | # # patients_df = pd.DataFrame(data=pd.DataFrame(
223 | # # data=np.random.uniform(0, 1, size=[n_patients, n_cov]),
224 | # # columns=covariates))
225 | #
226 | # patients_df.clip(lower=-1*clip_value, upper=clip_value, inplace=True)
227 | # patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
228 | # patients_df = ets.sample_hazard_lof_censoring(patients_df,
229 | # censoring_hazard_coefs=censoring_hazard_coef_dict,
230 | # seed=seed + 1, events=[0])
231 | # patients_df = ets.update_event_or_lof(patients_df)
232 | # patients_df.index.name = 'pid'
233 | # patients_df = patients_df.reset_index()
234 | #
235 | # # Two step fitter
236 | # new_fitter = TwoStagesFitter()
237 | # two_step_start = time()
238 | # new_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), nb_workers=1)
239 | # two_step_end = time()
240 | #
241 | # # Lee et al fitter
242 | # lee_fitter = DataExpansionFitter()
243 | # lee_start = time()
244 | # lee_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1))
245 | # lee_end = time()
246 | # lee_alpha_results = lee_fitter.get_alpha_df().loc[:,
247 | # slicer[:, [COEF_COL, STDERR_COL]]].unstack().to_frame()
248 | # lee_beta_results = lee_fitter.get_beta_SE().loc[:, slicer[:, [COEF_COL, STDERR_COL]]].unstack().to_frame()
249 | #
250 | # # Save results only if both fitters were successful
251 | # two_step_fit_time = two_step_end - two_step_start
252 | # lee_fit_time = lee_end - lee_start
253 | #
254 | # two_step_alpha_results = new_fitter.alpha_df[['J', 'X', 'alpha_jt']].set_index(['J', 'X'])
255 | # two_step_beta_results = new_fitter.get_beta_SE().unstack().to_frame()
256 | # print('x')
257 |
258 | # def test_sample_and_fit_normal(self):
259 | #
260 | # real_coef_dict = {
261 | # "alpha": {
262 | # 1: lambda t: -3 - 0.3 * np.log(t),
263 | # 2: lambda t: -3 - 0.15 * np.log(t)
264 | # },
265 | # "beta": {
266 | # 1: -np.log([0.8, 3, 3, 2.5, 2]),
267 | # 2: -np.log([1, 3, 4, 3, 2])
268 | # }
269 | # }
270 | #
271 | # n_patients = 10000
272 | # n_cov = 5
273 | # d_times = 50
274 | # j_events = 2
275 | #
276 | # ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
277 | # seed = 0
278 | # covariates = [f'Z{i + 1}' for i in range(n_cov)]
279 | #
280 | # np.random.seed(seed)
281 | #
282 | # patients_df = pd.DataFrame(data=pd.DataFrame(
283 | # data=np.random.uniform(0, 1, size=[n_patients, n_cov]),
284 | # columns=covariates))
285 | # patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
286 | # patients_df['X'] = patients_df['T']
287 | # patients_df['C'] = 51
288 | # patients_df.index.name = 'pid'
289 | # patients_df = patients_df.reset_index()
290 | #
291 | # # Two step fitter
292 | # new_fitter = TwoStagesFitter()
293 | # two_step_start = time()
294 | # new_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), nb_workers=1) # , x0=-3
295 | # two_step_end = time()
296 | #
297 | # # Lee et al fitter
298 | # lee_fitter = DataExpansionFitter()
299 | # lee_start = time()
300 | # lee_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1))
301 | # lee_end = time()
302 | # lee_alpha_results = lee_fitter.get_alpha_df().loc[:,
303 | # slicer[:, [COEF_COL, STDERR_COL]]].unstack().to_frame()
304 | # lee_beta_results = lee_fitter.get_beta_SE().loc[:, slicer[:, [COEF_COL, STDERR_COL]]].unstack().to_frame()
305 | #
306 | # # Save results only if both fitters were successful
307 | # two_step_fit_time = two_step_end - two_step_start
308 | # lee_fit_time = lee_end - lee_start
309 | #
310 | # two_step_alpha_results = new_fitter.alpha_df[['J', 'X', 'alpha_jt']].set_index(['J', 'X'])
311 | # two_step_beta_results = new_fitter.get_beta_SE().unstack().to_frame()
312 | # print('x')
313 |
314 |
315 | def test_raise_negative_values_overall_survival_assertion(self):
316 | with self.assertRaises(ValueError):
317 | real_coef_dict = {
318 | "alpha": {
319 | 1: lambda t: -9 + 3 * np.log(t),
320 | 2: lambda t: -7 + 2.5 * np.log(t)
321 | },
322 | "beta": {
323 | 1: [1.3, 1.7, -1.5, 0.5, 1.6],
324 | 2: [-1.5, 1.5, 1.8, -1, 1.2]
325 | }
326 | }
327 |
328 | censoring_hazard_coef_dict = {
329 | "alpha": {
330 | 0: lambda t: -8 + 2.1 * np.log(t),
331 | },
332 | "beta": {
333 | 0: [2, 1, -1.5, 1.5, -1.3],
334 | }
335 | }
336 |
337 | n_patients = 25000
338 | n_cov = 5
339 | d_times = 12
340 | j_events = 2
341 | means_vector = np.zeros(n_cov)
342 | covariance_matrix = np.identity(n_cov)
343 |
344 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
345 | seed = 0
346 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
347 |
348 | np.random.seed(seed)
349 |
350 | patients_df = pd.DataFrame(data=pd.DataFrame(
351 | data=np.random.multivariate_normal(means_vector, covariance_matrix, size=n_patients),
352 | columns=covariates))
353 |
354 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
355 |
356 |
--------------------------------------------------------------------------------
/tests/test_TwoStagesFitter.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import matplotlib.pyplot as plt
4 | from unittest.mock import patch
5 | from src.pydts.examples_utils.generate_simulations_data import generate_quick_start_df
6 | from src.pydts.fitters import TwoStagesFitter
7 | from src.pydts.utils import get_real_hazard
8 | import numpy as np
9 |
10 |
11 | class TestTwoStagesFitter(unittest.TestCase):
12 |
13 | def setUp(self):
14 | self.real_coef_dict = {
15 | "alpha": {
16 | 1: lambda t: -1 - 0.3 * np.log(t),
17 | 2: lambda t: -1.75 - 0.15 * np.log(t)
18 | },
19 | "beta": {
20 | 1: -np.log([0.8, 3, 3, 2.5, 2]),
21 | 2: -np.log([1, 3, 4, 3, 2])
22 | }
23 | }
24 | self.df = generate_quick_start_df(n_patients=5000, n_cov=5, d_times=10, j_events=2, pid_col='pid', seed=0,
25 | real_coef_dict=self.real_coef_dict, censoring_prob=.8)
26 | self.m = TwoStagesFitter()
27 | self.fitted_model = TwoStagesFitter()
28 | self.fitted_model.fit(df=self.df.drop(['C', 'T'], axis=1))
29 |
30 | def test_fit_case_event_col_not_in_df(self):
31 | # Event column (here 'J') must be passed in df to .fit()
32 | with self.assertRaises(AssertionError):
33 | self.m.fit(df=self.df.drop(['C', 'J', 'T'], axis=1))
34 |
35 | def test_fit_case_duration_col_not_in_df(self):
36 | # Duration column (here 'X') must be passed in df to .fit()
37 | with self.assertRaises(AssertionError):
38 | self.m.fit(df=self.df.drop(['C', 'X', 'T'], axis=1))
39 |
40 | def test_fit_case_pid_col_not_in_df(self):
41 |         # Sample ID column (here 'pid') must be passed in df to .fit()
42 | with self.assertRaises(AssertionError):
43 | self.m.fit(df=self.df.drop(['C', 'pid', 'T'], axis=1))
44 |
45 | def test_fit_case_cov_col_not_in_df(self):
46 | # Covariates columns (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .fit()
47 | with self.assertRaises(ValueError):
48 | self.m.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=['Z6'])
49 |
50 | def test_fit_case_missing_jt(self):
51 | # df to .fit() should contain observations to all (event, times)
52 |
53 | # drop (events[1], times[2])
54 | tmp_df = self.df[
55 | ~((self.df[self.fitted_model.duration_col] == self.fitted_model.times[2]) &
56 | (self.df[self.fitted_model.event_type_col] == self.fitted_model.events[1]))
57 | ]
58 |
59 | with self.assertRaises(RuntimeError):
60 | self.m.fit(df=tmp_df.drop(['C', 'T'], axis=1))
61 |
62 | def test_fit_case_correct_fit(self):
63 | # Fit should be successful
64 | m = TwoStagesFitter()
65 | m.fit(df=self.df.drop(['C', 'T'], axis=1))
66 |
67 | def test_print_summary(self):
68 | self.fitted_model.print_summary()
69 |
70 | def test_plot_event_alpha_case_correct_event(self):
71 | self.fitted_model.plot_event_alpha(event=self.fitted_model.events[0], show=False)
72 |
73 | def test_plot_event_alpha_case_correct_event_and_show(self):
74 | with patch('matplotlib.pyplot.show') as p:
75 | self.fitted_model.plot_event_alpha(event=self.fitted_model.events[0], show=True)
76 |
77 | def test_plot_event_alpha_case_incorrect_event(self):
78 | with self.assertRaises(AssertionError):
79 | self.fitted_model.plot_event_alpha(event=100, show=False)
80 |
81 | def test_plot_all_events_alpha(self):
82 | self.fitted_model.plot_all_events_alpha(show=False)
83 |
84 | def test_plot_all_events_alpha_case_show(self):
85 | with patch('matplotlib.pyplot.show') as p:
86 | self.fitted_model.plot_all_events_alpha(show=True)
87 |
88 | def test_plot_all_events_alpha_case_axis_exits(self):
89 | fig, ax = plt.subplots()
90 | self.fitted_model.plot_all_events_alpha(show=False, ax=ax)
91 |
92 | def test_get_beta_SE(self):
93 | self.fitted_model.get_beta_SE()
94 |
95 | def test_get_alpha_df(self):
96 | self.fitted_model.get_alpha_df()
97 |
98 | def test_plot_all_events_beta(self):
99 | self.fitted_model.plot_all_events_beta(show=False)
100 |
101 | def test_plot_all_events_beta_case_show(self):
102 | with patch('matplotlib.pyplot.show') as p:
103 | self.fitted_model.plot_all_events_beta(show=True)
104 |
105 | def test_plot_all_events_beta_case_ax_exists(self):
106 | fig, ax = plt.subplots()
107 | self.fitted_model.plot_all_events_beta(show=False, ax=ax)
108 |
109 | def test_predict_hazard_jt_case_covariate_not_in_df(self):
110 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict()
111 | with self.assertRaises(AssertionError):
112 | self.fitted_model.predict_hazard_jt(
113 | df=self.df.drop(['C', 'T', 'Z1'], axis=1),
114 | event=self.fitted_model.events[0],
115 | t=self.fitted_model.times[0])
116 |
117 | def test_predict_hazard_jt_case_hazard_already_on_df(self):
119 | df_temp = get_real_hazard(self.df.drop(['C', 'T', 'X', 'J'], axis=1).set_index('pid').copy(),
120 | real_coef_dict=self.real_coef_dict,
121 | times=self.fitted_model.times,
122 | events=self.fitted_model.events)
123 | assert (df_temp == self.fitted_model.predict_hazard_jt(df=df_temp,
124 | event=self.fitted_model.events[0],
125 | t=self.fitted_model.times
126 | )
127 | ).all().all()
128 |
129 | def test_hazard_transformation_result(self):
130 | from scipy.special import logit
131 | num = np.array([0.5])
132 | a = logit(num)
133 | print(a)
134 | assert (a == self.fitted_model._hazard_transformation(num)).all()
135 |
136 | def test_predict_hazard_jt_case_event_not_in_events(self):
137 | # Event passed to .predict() must be in fitted events
138 | with self.assertRaises(AssertionError):
139 | self.fitted_model.predict_hazard_jt(
140 | df=self.df.drop(['C', 'T'], axis=1), event=100, t=self.fitted_model.times[0])
141 |
142 | def test_predict_hazard_jt_case_time_not_in_times(self):
143 |         # Time passed to .predict() must be in fitted times
144 | with self.assertRaises(AssertionError):
145 | self.fitted_model.predict_hazard_jt(
146 | df=self.df.drop(['C', 'T'], axis=1), event=self.fitted_model.events[0], t=1000)
147 |
148 | def test_predict_hazard_jt_case_successful_predict(self):
149 | self.fitted_model.predict_hazard_jt(
150 | df=self.df.drop(['C', 'T'], axis=1),
151 | event=self.fitted_model.events[0], t=self.fitted_model.times[0])
152 |
153 | def test_predict_hazard_t_case_successful_predict(self):
154 | self.fitted_model.predict_hazard_t(df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[:3])
155 |
156 | def test_predict_hazard_all_case_successful_predict(self):
157 | self.fitted_model.predict_hazard_all(df=self.df.drop(['C', 'T'], axis=1))
158 |
159 | def test_predict_overall_survival_case_successful_predict(self):
160 | self.fitted_model.predict_overall_survival(
161 | df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[5], return_hazards=True)
162 |
163 | def test_predict_prob_event_j_at_t_case_successful_predict(self):
164 | self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1),
165 | event=self.fitted_model.events[0],
166 | t=self.fitted_model.times[3])
167 |
168 | def test_predict_prob_event_j_all_case_successful_predict(self):
169 | self.fitted_model.predict_prob_event_j_all(df=self.df.drop(['C', 'T'], axis=1),
170 | event=self.fitted_model.events[0])
171 |
172 | def test_predict_prob_events_case_successful_predict(self):
173 | self.fitted_model.predict_prob_events(df=self.df.drop(['C', 'T'], axis=1))
174 |
175 | def test_predict_event_cumulative_incident_function_case_successful_predict(self):
176 | self.fitted_model.predict_event_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1),
177 | event=self.fitted_model.events[0])
178 |
179 | def test_predict_cumulative_incident_function_case_successful_predict(self):
180 | self.fitted_model.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1))
181 |
182 | def test_predict_marginal_prob_function_case_successful(self):
183 | self.fitted_model.predict_marginal_prob_event_j(df=self.df.drop(columns=['C', 'T']),
184 | event=1)
185 |
186 | def test_predict_marginal_prob_function_case_event_not_exists(self):
187 |         # Event passed to .predict_marginal_prob_event_j() must be in fitted events
188 | with self.assertRaises(AssertionError):
189 | self.fitted_model.predict_marginal_prob_event_j(
190 | df=self.df.drop(['C', 'T'], axis=1), event=100)
191 |
192 | def test_predict_marginal_prob_function_cov_not_exists(self):
193 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict()
194 | with self.assertRaises(AssertionError):
195 | self.fitted_model.predict_marginal_prob_event_j(
196 | df=self.df.drop(columns=['C', 'T', 'Z1']),
197 | event=self.fitted_model.events[0])
198 |
199 | def test_predict_marginal_prob_all_events_function_successful(self):
200 | self.fitted_model.predict_marginal_prob_all_events(df=self.df.drop(columns=['C', 'T']))
201 |
202 | def test_predict_marginal_prob_all_events_cov_not_exits(self):
203 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict()
204 | with self.assertRaises(AssertionError):
205 | self.fitted_model.predict_marginal_prob_all_events(
206 | df=self.df.drop(columns=['C', 'T', 'Z1']))
207 |
208 | def test_alpha_jt_function_value(self):
209 | t = 1
210 | j = 1
211 | row = self.fitted_model.alpha_df.query("X == @t and J == @j")
212 | x = row['alpha_jt'].item()
213 | y_t = (self.df["X"]
214 | .value_counts()
215 |                .sort_index(ascending=False)  # reverse-cumulative count: y_t = number of samples at risk at time t (X >= t)
216 | .cumsum()
217 | .sort_index()
218 | )
219 | rel_y_t = y_t.loc[t]
220 | rel_beta = self.fitted_model.beta_models[j].params_
221 | n_jt = row['n_jt']
222 | df = self.df.drop(columns=['C', 'T'])
223 | partial_df = df[df["X"] >= t]
224 | expit_add = np.dot(partial_df[self.fitted_model.covariates], rel_beta)
225 | from scipy.special import expit
226 | a_jt = ((1 / rel_y_t) * np.sum(expit(x + expit_add)) - (n_jt / rel_y_t)) ** 2
227 | a_jt_from_func = self.fitted_model._alpha_jt(x=x, df=df,
228 | y_t=rel_y_t, beta_j=rel_beta,
229 | n_jt=n_jt, t=t, event=j)
230 | self.assertEqual(a_jt.item(), a_jt_from_func.item())
231 |
232 | def test_predict_event_jt_case_t1_not_hazard(self):
233 | self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1),
234 | event=self.fitted_model.events[0],
235 | t=self.fitted_model.times[0])
236 |
237 | def test_predict_event_jt_case_t3_not_hazard(self):
238 | temp_df = self.df.drop(['C', 'T'], axis=1)
239 | temp_df = self.fitted_model.predict_overall_survival(df=temp_df,
240 | t=self.fitted_model.times[3],
241 | return_hazards=False)
242 | self.fitted_model.predict_prob_event_j_at_t(df=temp_df,
243 | event=self.fitted_model.events[0],
244 | t=self.fitted_model.times[3])
245 |
246 | def test_regularization_same_for_all_beta_models(self):
247 | L1_regularized_fitter = TwoStagesFitter()
248 |
249 | fit_beta_kwargs = {
250 | 'model_kwargs': {
251 | 'penalizer': 0.003,
252 | 'l1_ratio': 1
253 | }
254 | }
255 |
256 | L1_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)
257 |
258 | def test_regularization_different_to_each_beta_model(self):
259 | L1_regularized_fitter = TwoStagesFitter()
260 |
261 | fit_beta_kwargs = {
262 | 'model_kwargs': {
263 | 1: {
264 | 'penalizer': 0.003,
265 | 'l1_ratio': 1
266 | },
267 | 2: {
268 | 'penalizer': 0.005,
269 | 'l1_ratio': 1
270 | }
271 | }
272 | }
273 |
274 | L1_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)
275 |
276 | def test_different_covariates_to_each_beta_model(self):
277 | twostages_fitter = TwoStagesFitter()
278 | covariates = {
279 | 1: ['Z1', 'Z2', 'Z3'],
280 | 2: ['Z2', 'Z3', 'Z4', 'Z5']
281 | }
282 | twostages_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=covariates)
283 |
284 | def test_different_covariates_to_each_beta_model_prediction(self):
285 | twostages_fitter = TwoStagesFitter()
286 | covariates = {
287 | 1: ['Z1', 'Z2', 'Z3'],
288 | 2: ['Z2', 'Z3', 'Z4', 'Z5']
289 | }
290 | twostages_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=covariates)
291 | twostages_fitter.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1))
292 |
293 |
--------------------------------------------------------------------------------
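
The regularization and covariate tests above pin down the two ways TwoStagesFitter.fit accepts settings: one 'model_kwargs' dict shared by all event-specific beta models, or a dict keyed by event type; the same per-event pattern applies to covariate subsets. A minimal usage sketch (assuming pydts is installed as a package, so imports drop the src. prefix used by the test suite; the data-generation arguments are illustrative):

    import numpy as np
    from pydts.fitters import TwoStagesFitter
    from pydts.examples_utils.generate_simulations_data import generate_quick_start_df

    # Simulated competing-risks data with two event types and five covariates
    real_coef_dict = {
        "alpha": {1: lambda t: -1 - 0.3 * np.log(t),
                  2: lambda t: -1.75 - 0.15 * np.log(t)},
        "beta": {1: -np.log([0.8, 3, 3, 2.5, 2]),
                 2: -np.log([1, 3, 4, 3, 2])},
    }
    df = generate_quick_start_df(n_patients=1000, n_cov=5, d_times=10, j_events=2,
                                 pid_col='pid', seed=0,
                                 real_coef_dict=real_coef_dict, censoring_prob=0.8)

    # Same L1 penalty for every event-specific beta model:
    shared = TwoStagesFitter()
    shared.fit(df=df.drop(['C', 'T'], axis=1),
               fit_beta_kwargs={'model_kwargs': {'penalizer': 0.003, 'l1_ratio': 1}})

    # A different penalty per event type (dict keys are the event labels):
    per_event = TwoStagesFitter()
    per_event.fit(df=df.drop(['C', 'T'], axis=1),
                  fit_beta_kwargs={'model_kwargs': {
                      1: {'penalizer': 0.003, 'l1_ratio': 1},
                      2: {'penalizer': 0.005, 'l1_ratio': 1}}})

    # Event-specific covariate subsets are passed the same way:
    subset = TwoStagesFitter()
    subset.fit(df=df.drop(['C', 'T'], axis=1),
               covariates={1: ['Z1', 'Z2', 'Z3'], 2: ['Z2', 'Z3', 'Z4', 'Z5']})
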
/tests/test_TwoStagesFitterExact.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import matplotlib.pyplot as plt
4 | from unittest.mock import patch
5 | from src.pydts.examples_utils.generate_simulations_data import generate_quick_start_df
6 | from src.pydts.fitters import TwoStagesFitterExact
7 | from src.pydts.data_generation import EventTimesSampler
8 | from src.pydts.utils import get_real_hazard
9 | import numpy as np
10 | import pandas as pd
11 |
12 |
13 | class TestTwoStagesFitterExact(unittest.TestCase):
14 |
15 | def setUp(self):
16 | self.real_coef_dict = {
17 | "alpha": {
18 | 1: lambda t: -1. + 0.4 * np.log(t),
19 | 2: lambda t: -1. + 0.4 * np.log(t),
20 | },
21 | "beta": {
22 | 1: -0.4*np.log([0.8, 3, 3, 2.5, 2]),
23 | 2: -0.3*np.log([1, 3, 4, 3, 2]),
24 | }
25 | }
26 | self.df = generate_quick_start_df(n_patients=300, n_cov=5, d_times=4, j_events=2, pid_col='pid', seed=0,
27 | real_coef_dict=self.real_coef_dict, censoring_prob=0.1)
28 |
29 | self.m = TwoStagesFitterExact()
30 | self.fitted_model = TwoStagesFitterExact()
31 | self.fitted_model.fit(df=self.df.drop(['C', 'T'], axis=1))
32 |
33 |     def test_print_beta_SE(self):
34 |         # get_beta_SE() should return the estimated coefficients and their standard errors
35 |         print(self.fitted_model.get_beta_SE())
36 |
37 |
38 | def test_fit_case_event_col_not_in_df(self):
39 | # Event column (here 'J') must be passed in df to .fit()
40 | with self.assertRaises(AssertionError):
41 | self.m.fit(df=self.df.drop(['C', 'J', 'T'], axis=1))
42 |
43 | def test_fit_case_duration_col_not_in_df(self):
44 | # Duration column (here 'X') must be passed in df to .fit()
45 | with self.assertRaises(AssertionError):
46 | self.m.fit(df=self.df.drop(['C', 'X', 'T'], axis=1))
47 |
48 | def test_fit_case_pid_col_not_in_df(self):
49 |         # Patient id column (here 'pid') must be passed in df to .fit()
50 | with self.assertRaises(AssertionError):
51 | self.m.fit(df=self.df.drop(['C', 'pid', 'T'], axis=1))
52 |
53 | def test_fit_case_cov_col_not_in_df(self):
54 | # Covariates columns (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .fit()
55 | with self.assertRaises(ValueError):
56 | self.m.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=['Z6'])
57 |
58 | def test_fit_case_missing_jt(self):
59 | # df to .fit() should contain observations to all (event, times)
60 |
61 | # drop (events[1], times[2])
62 | tmp_df = self.df[
63 | ~((self.df[self.fitted_model.duration_col] == self.fitted_model.times[2]) &
64 | (self.df[self.fitted_model.event_type_col] == self.fitted_model.events[1]))
65 | ]
66 |
67 | with self.assertRaises(RuntimeError):
68 | self.m.fit(df=tmp_df.drop(['C', 'T'], axis=1))
69 |
70 | def test_fit_case_correct_fit(self):
71 | # Fit should be successful
72 | m = TwoStagesFitterExact()
73 | m.fit(df=self.df.drop(['C', 'T'], axis=1))
74 |
75 | def test_print_summary(self):
76 | self.fitted_model.print_summary()
77 |
78 | def test_plot_event_alpha_case_correct_event(self):
79 | self.fitted_model.plot_event_alpha(event=self.fitted_model.events[0], show=False)
80 |
81 | def test_plot_event_alpha_case_correct_event_and_show(self):
82 | with patch('matplotlib.pyplot.show') as p:
83 | self.fitted_model.plot_event_alpha(event=self.fitted_model.events[0], show=True)
84 |
85 | def test_plot_event_alpha_case_incorrect_event(self):
86 | with self.assertRaises(AssertionError):
87 | self.fitted_model.plot_event_alpha(event=100, show=False)
88 |
89 | def test_plot_all_events_alpha(self):
90 | self.fitted_model.plot_all_events_alpha(show=False)
91 |
92 | def test_plot_all_events_alpha_case_show(self):
93 | with patch('matplotlib.pyplot.show') as p:
94 | self.fitted_model.plot_all_events_alpha(show=True)
95 |
96 | def test_plot_all_events_alpha_case_axis_exits(self):
97 | fig, ax = plt.subplots()
98 | self.fitted_model.plot_all_events_alpha(show=False, ax=ax)
99 |
100 | def test_get_beta_SE(self):
101 | self.fitted_model.get_beta_SE()
102 |
103 | def test_get_alpha_df(self):
104 | self.fitted_model.get_alpha_df()
105 |
106 | # def test_plot_all_events_beta(self):
107 | # self.fitted_model.plot_all_events_beta(show=False)
108 | #
109 | # def test_plot_all_events_beta_case_show(self):
110 | # with patch('matplotlib.pyplot.show') as p:
111 | # self.fitted_model.plot_all_events_beta(show=True)
112 | #
113 | # def test_plot_all_events_beta_case_ax_exists(self):
114 | # fig, ax = plt.subplots()
115 | # self.fitted_model.plot_all_events_beta(show=False, ax=ax)
116 |
117 | def test_predict_hazard_jt_case_covariate_not_in_df(self):
118 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict()
119 | with self.assertRaises(AssertionError):
120 | self.fitted_model.predict_hazard_jt(
121 | df=self.df.drop(['C', 'T', 'Z1'], axis=1),
122 | event=self.fitted_model.events[0],
123 | t=self.fitted_model.times[0])
124 |
125 | def test_predict_hazard_jt_case_hazard_already_on_df(self):
126 | # print(self.df.drop(['C', 'T', 'X', 'J'], axis=1).set_index('pid').copy())
127 | df_temp = get_real_hazard(self.df.drop(['C', 'T', 'X', 'J'], axis=1).set_index('pid').copy(),
128 | real_coef_dict=self.real_coef_dict,
129 | times=self.fitted_model.times,
130 | events=self.fitted_model.events)
131 | assert (df_temp == self.fitted_model.predict_hazard_jt(df=df_temp,
132 | event=self.fitted_model.events[0],
133 | t=self.fitted_model.times
134 | )
135 | ).all().all()
136 |
137 | def test_hazard_transformation_result(self):
138 | from scipy.special import logit
139 | num = np.array([0.5])
140 | a = logit(num)
141 |         # logit(0.5) == 0.0
142 | assert (a == self.fitted_model._hazard_transformation(num)).all()
143 |
144 | def test_predict_hazard_jt_case_event_not_in_events(self):
145 | # Event passed to .predict() must be in fitted events
146 | with self.assertRaises(AssertionError):
147 | self.fitted_model.predict_hazard_jt(
148 | df=self.df.drop(['C', 'T'], axis=1), event=100, t=self.fitted_model.times[0])
149 |
150 | def test_predict_hazard_jt_case_time_not_in_times(self):
151 | # Event passed to .predict() must be in fitted events
152 | with self.assertRaises(AssertionError):
153 | self.fitted_model.predict_hazard_jt(
154 | df=self.df.drop(['C', 'T'], axis=1), event=self.fitted_model.events[0], t=1000)
155 |
156 | def test_predict_hazard_jt_case_successful_predict(self):
157 | self.fitted_model.predict_hazard_jt(
158 | df=self.df.drop(['C', 'T'], axis=1),
159 | event=self.fitted_model.events[0], t=self.fitted_model.times[0])
160 |
161 | def test_predict_hazard_t_case_successful_predict(self):
162 | self.fitted_model.predict_hazard_t(df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[:3])
163 |
164 | def test_predict_hazard_all_case_successful_predict(self):
165 | self.fitted_model.predict_hazard_all(df=self.df.drop(['C', 'T'], axis=1))
166 |
167 | def test_predict_overall_survival_case_successful_predict(self):
168 | self.fitted_model.predict_overall_survival(
169 | df=self.df.drop(['C', 'T'], axis=1), t=self.fitted_model.times[3], return_hazards=True)
170 |
171 | def test_predict_prob_event_j_at_t_case_successful_predict(self):
172 | self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1),
173 | event=self.fitted_model.events[0],
174 | t=self.fitted_model.times[3])
175 |
176 | def test_predict_prob_event_j_all_case_successful_predict(self):
177 | self.fitted_model.predict_prob_event_j_all(df=self.df.drop(['C', 'T'], axis=1),
178 | event=self.fitted_model.events[0])
179 |
180 | def test_predict_prob_events_case_successful_predict(self):
181 | self.fitted_model.predict_prob_events(df=self.df.drop(['C', 'T'], axis=1))
182 |
183 | def test_predict_event_cumulative_incident_function_case_successful_predict(self):
184 | self.fitted_model.predict_event_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1),
185 | event=self.fitted_model.events[0])
186 |
187 | def test_predict_cumulative_incident_function_case_successful_predict(self):
188 | self.fitted_model.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1))
189 |
190 | def test_predict_marginal_prob_function_case_successful(self):
191 | self.fitted_model.predict_marginal_prob_event_j(df=self.df.drop(columns=['C', 'T']),
192 | event=1)
193 |
194 | def test_predict_marginal_prob_function_case_event_not_exists(self):
195 |         # Event passed to .predict_marginal_prob_event_j() must be in fitted events
196 | with self.assertRaises(AssertionError):
197 | self.fitted_model.predict_marginal_prob_event_j(
198 | df=self.df.drop(['C', 'T'], axis=1), event=100)
199 |
200 | def test_predict_marginal_prob_function_cov_not_exists(self):
201 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict()
202 | with self.assertRaises(AssertionError):
203 | self.fitted_model.predict_marginal_prob_event_j(
204 | df=self.df.drop(columns=['C', 'T', 'Z1']),
205 | event=self.fitted_model.events[0])
206 |
207 | def test_predict_marginal_prob_all_events_function_successful(self):
208 | self.fitted_model.predict_marginal_prob_all_events(df=self.df.drop(columns=['C', 'T']))
209 |
210 | def test_predict_marginal_prob_all_events_cov_not_exits(self):
211 | # Covariates columns used in fit (here ['Z1','Z2','Z3','Z4','Z5']) must be passed in df to .predict()
212 | with self.assertRaises(AssertionError):
213 | self.fitted_model.predict_marginal_prob_all_events(
214 | df=self.df.drop(columns=['C', 'T', 'Z1']))
215 |
216 | def test_alpha_jt_function_value(self):
217 | t = 1
218 | j = 1
219 | row = self.fitted_model.alpha_df.query("X == @t and J == @j")
220 | x = row['alpha_jt'].item()
221 | y_t = (self.df["X"]
222 | .value_counts()
223 |                .sort_index(ascending=False)  # reversed cumulative counts give the number at risk at each time t (X >= t)
224 | .cumsum()
225 | .sort_index()
226 | )
227 | rel_y_t = y_t.loc[t]
228 |         # TwoStagesFitterExact stores statsmodels results, so coefficients are read via beta_models_params_attr (not .params_)
229 | rel_beta = getattr(self.fitted_model.beta_models[j], self.fitted_model.beta_models_params_attr)
230 | n_jt = row['n_jt']
231 | df = self.df.drop(columns=['C', 'T'])
232 | partial_df = df[df["X"] >= t]
233 | expit_add = np.dot(partial_df[self.fitted_model.covariates], rel_beta)
234 | from scipy.special import expit
235 | a_jt = ((1 / rel_y_t) * np.sum(expit(x + expit_add)) - (n_jt / rel_y_t)) ** 2
236 | a_jt_from_func = self.fitted_model._alpha_jt(x=x, df=df,
237 | y_t=rel_y_t, beta_j=rel_beta,
238 | n_jt=n_jt, t=t, event=j)
239 | self.assertEqual(a_jt.item(), a_jt_from_func.item())
240 |
241 | def test_predict_event_jt_case_t1_not_hazard(self):
242 | self.fitted_model.predict_prob_event_j_at_t(df=self.df.drop(['C', 'T'], axis=1),
243 | event=self.fitted_model.events[0],
244 | t=self.fitted_model.times[0])
245 |
246 | def test_predict_event_jt_case_t3_not_hazard(self):
247 | temp_df = self.df.drop(['C', 'T'], axis=1)
248 | temp_df = self.fitted_model.predict_overall_survival(df=temp_df,
249 | t=self.fitted_model.times[3],
250 | return_hazards=False)
251 | self.fitted_model.predict_prob_event_j_at_t(df=temp_df,
252 | event=self.fitted_model.events[0],
253 | t=self.fitted_model.times[3])
254 |
255 | def test_regularization_same_for_all_beta_models(self):
256 | L1_regularized_fitter = TwoStagesFitterExact()
257 |
258 | fit_beta_kwargs = {
259 | 'model_kwargs': {
260 | 'alpha': 0.03,
261 | 'L1_wt': 1
262 | }
263 | }
264 |
265 | L1_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)
266 | print(L1_regularized_fitter.get_beta_SE())
267 |
268 | def test_regularization_different_to_each_beta_model(self):
269 | L1_regularized_fitter = TwoStagesFitterExact()
270 |
271 | fit_beta_kwargs = {
272 | 'model_fit_kwargs': {
273 | 1: {
274 | 'alpha': 0.003,
275 | 'L1_wt': 1
276 | },
277 | 2: {
278 | 'alpha': 0.005,
279 | 'L1_wt': 1
280 | }
281 | }
282 | }
283 |
284 | L1_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)
285 | # print(L1_regularized_fitter.get_beta_SE())
286 |
287 | L2_regularized_fitter = TwoStagesFitterExact()
288 |
289 | fit_beta_kwargs = {
290 | 'model_fit_kwargs': {
291 | 1: {
292 | 'alpha': 0.003,
293 | 'L1_wt': 0
294 | },
295 | 2: {
296 | 'alpha': 0.005,
297 | 'L1_wt': 0
298 | }
299 | }
300 | }
301 |
302 | L2_regularized_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)
303 | # print(L2_regularized_fitter.get_beta_SE())
304 |
305 | def test_different_covariates_to_each_beta_model(self):
306 | twostages_fitter = TwoStagesFitterExact()
307 | covariates = {
308 | 1: ['Z1', 'Z2', 'Z3'],
309 | 2: ['Z2', 'Z3', 'Z4', 'Z5']
310 | }
311 | twostages_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=covariates)
312 |
313 | def test_different_covariates_to_each_beta_model_prediction(self):
314 | twostages_fitter = TwoStagesFitterExact()
315 | covariates = {
316 | 1: ['Z1', 'Z2', 'Z3'],
317 | 2: ['Z2', 'Z3', 'Z4', 'Z5']
318 | }
319 | twostages_fitter.fit(df=self.df.drop(['C', 'T'], axis=1), covariates=covariates)
320 | twostages_fitter.predict_cumulative_incident_function(df=self.df.drop(['C', 'T'], axis=1))
321 |
--------------------------------------------------------------------------------
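
TwoStagesFitterExact takes statsmodels-style penalty arguments (alpha, L1_wt) rather than the lifelines-style (penalizer, l1_ratio) used by TwoStagesFitter, and the tests above pass them either under 'model_kwargs' (shared across events) or 'model_fit_kwargs' (per event). A hedged sketch mirroring the setUp above (package-style imports assumed; values are illustrative):

    import numpy as np
    from pydts.fitters import TwoStagesFitterExact
    from pydts.examples_utils.generate_simulations_data import generate_quick_start_df

    real_coef_dict = {
        "alpha": {1: lambda t: -1. + 0.4 * np.log(t),
                  2: lambda t: -1. + 0.4 * np.log(t)},
        "beta": {1: -0.4 * np.log([0.8, 3, 3, 2.5, 2]),
                 2: -0.3 * np.log([1, 3, 4, 3, 2])},
    }
    df = generate_quick_start_df(n_patients=300, n_cov=5, d_times=4, j_events=2,
                                 pid_col='pid', seed=0,
                                 real_coef_dict=real_coef_dict, censoring_prob=0.1)

    # L1 (L1_wt=1) for event 1 and L2 (L1_wt=0) for event 2, set per event type:
    fitter = TwoStagesFitterExact()
    fitter.fit(df=df.drop(['C', 'T'], axis=1),
               fit_beta_kwargs={'model_fit_kwargs': {
                   1: {'alpha': 0.003, 'L1_wt': 1},
                   2: {'alpha': 0.005, 'L1_wt': 0}}})
    print(fitter.get_beta_SE())

Coefficient access also differs between the two fitters: the exact fitter exposes statsmodels results, which is why test_alpha_jt_function_value reads them through beta_models_params_attr rather than .params_.
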
/tests/test_basefitter.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | from src.pydts.base_fitters import BaseFitter, ExpansionBasedFitter
6 | from src.pydts.examples_utils.generate_simulations_data import generate_quick_start_df
7 |
8 |
9 | class TestBaseFitter(unittest.TestCase):
10 | def setUp(self):
11 | self.real_coef_dict = {
12 | "alpha": {
13 | 1: lambda t: -1 - 0.3 * np.log(t),
14 | 2: lambda t: -1.75 - 0.15 * np.log(t)
15 | },
16 | "beta": {
17 | 1: -np.log([0.8, 3, 3, 2.5, 2]),
18 | 2: -np.log([1, 3, 4, 3, 2])
19 | }
20 | }
21 | self.df = generate_quick_start_df(n_patients=1000, n_cov=5, d_times=10, j_events=2, pid_col='pid', seed=0,
22 | real_coef_dict=self.real_coef_dict, censoring_prob=0.8)
23 |
24 | self.base_fitter = BaseFitter()
25 | self.expansion_fitter = ExpansionBasedFitter()
26 |
27 | def test_base_fit(self):
28 | with self.assertRaises(NotImplementedError):
29 | self.base_fitter.fit(self.df)
30 |
31 | def test_base_predict(self):
32 | with self.assertRaises(NotImplementedError):
33 | self.base_fitter.predict(self.df)
34 |
35 | def test_base_evaluate(self):
36 | with self.assertRaises(NotImplementedError):
37 | self.base_fitter.evaluate(self.df)
38 |
39 | def test_base_print_summary(self):
40 | with self.assertRaises(NotImplementedError):
41 | self.base_fitter.print_summary()
42 |
43 | def test_expansion_predict_hazard_not_implemented(self):
44 | with self.assertRaises(NotImplementedError):
45 | self.expansion_fitter.predict_hazard_jt(self.df, event=1, t=100)
--------------------------------------------------------------------------------
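
These tests document the abstract contract only: BaseFitter.fit/predict/evaluate/print_summary and ExpansionBasedFitter.predict_hazard_jt raise NotImplementedError until a subclass overrides them. A hypothetical sketch of a conforming subclass (ConstantRateFitter and its use of the 'J' event-type column are illustrative assumptions, not part of the library):

    from pydts.base_fitters import BaseFitter

    class ConstantRateFitter(BaseFitter):
        """Hypothetical minimal subclass satisfying the BaseFitter contract."""

        def fit(self, df, **kwargs):
            # Fraction of observations experiencing any event; J != 0 is assumed
            # to mark an observed event, as in the simulated datasets above.
            self.event_rate_ = (df['J'] != 0).mean()
            return self

        def print_summary(self):
            print(f"observed event rate: {self.event_rate_:.3f}")
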
/tests/test_cross_validation.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import pandas as pd
4 | from src.pydts.data_generation import EventTimesSampler
5 | from src.pydts.cross_validation import TwoStagesCV, PenaltyGridSearchCV, TwoStagesCVExact, PenaltyGridSearchCVExact
6 |
7 |
8 | class TestCrossValidation(unittest.TestCase):
9 |
10 | def setUp(self):
11 | n_cov = 6
12 | beta1 = np.zeros(n_cov)
13 | beta1[:5] = (-0.5 * np.log([0.8, 3, 3, 2.5, 2]))
14 | beta2 = np.zeros(n_cov)
15 | beta2[:5] = (-0.5 * np.log([1, 3, 4, 3, 2]))
16 |
17 | real_coef_dict = {
18 | "alpha": {
19 | 1: lambda t: -2.0 - 0.2 * np.log(t),
20 | 2: lambda t: -2.2 - 0.2 * np.log(t)
21 | },
22 | "beta": {
23 | 1: beta1,
24 | 2: beta2
25 | }
26 | }
27 | n_patients = 1000
28 | d_times = 5
29 | j_events = 2
30 |
31 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
32 |
33 | seed = 0
34 | means_vector = np.zeros(n_cov)
35 | covariance_matrix = 0.5 * np.identity(n_cov)
36 | clip_value = 1
37 |
38 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
39 |
40 |         patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix,
41 |                                                                        size=n_patients),
42 |                                    columns=covariates)
43 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True)
44 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
45 | patients_df.index.name = 'pid'
46 | patients_df = patients_df.reset_index()
47 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times))
48 | self.patients_df = ets.update_event_or_lof(patients_df)
49 | self.tscv = TwoStagesCV()
50 |
51 | def test_cross_validation_bs(self):
52 | self.tscv.cross_validate(self.patients_df, metrics='BS', n_splits=2)
53 |
54 | def test_cross_validation_auc(self):
55 | self.tscv.cross_validate(self.patients_df, metrics='AUC', n_splits=2)
56 |
57 | def test_cross_validation_iauc(self):
58 | self.tscv.cross_validate(self.patients_df, metrics='IAUC', n_splits=2)
59 |
60 | def test_cross_validation_gauc(self):
61 | self.tscv.cross_validate(self.patients_df, metrics='GAUC', n_splits=2)
62 |
63 | def test_cross_validation_ibs(self):
64 | self.tscv.cross_validate(self.patients_df, metrics='IBS', n_splits=2)
65 |
66 | def test_cross_validation_gbs(self):
67 | self.tscv.cross_validate(self.patients_df, metrics='GBS', n_splits=3)
68 |
69 |
70 | class TestCrossValidationExact(TestCrossValidation):
71 |
72 | def setUp(self):
73 | n_cov = 6
74 | beta1 = np.zeros(n_cov)
75 | beta1[:5] = (-0.25 * np.log([0.8, 3, 3, 2.5, 2]))
76 | beta2 = np.zeros(n_cov)
77 | beta2[:5] = (-0.25 * np.log([1, 3, 4, 3, 2]))
78 |
79 | real_coef_dict = {
80 | "alpha": {
81 | 1: lambda t: -1.7 - 0.2 * np.log(t),
82 | 2: lambda t: -1.8 - 0.2 * np.log(t)
83 | },
84 | "beta": {
85 | 1: beta1,
86 | 2: beta2
87 | }
88 | }
89 | n_patients = 500
90 | d_times = 4
91 | j_events = 2
92 |
93 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
94 |
95 | seed = 0
96 | means_vector = np.zeros(n_cov)
97 | covariance_matrix = 0.5 * np.identity(n_cov)
98 | clip_value = 1
99 |
100 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
101 |
102 |         patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix,
103 |                                                                        size=n_patients),
104 |                                    columns=covariates)
105 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True)
106 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
107 | patients_df.index.name = 'pid'
108 | patients_df = patients_df.reset_index()
109 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times))
110 | self.patients_df = ets.update_event_or_lof(patients_df)
111 | self.tscv = TwoStagesCVExact()
112 |
113 |
114 | class TestPenaltyGridSearchCV(unittest.TestCase):
115 |
116 | def setUp(self):
117 | n_cov = 6
118 | beta1 = np.zeros(n_cov)
119 | beta1[:5] = (-0.5 * np.log([0.8, 3, 3, 2.5, 2]))
120 | beta2 = np.zeros(n_cov)
121 | beta2[:5] = (-0.5 * np.log([1, 3, 4, 3, 2]))
122 |
123 | real_coef_dict = {
124 | "alpha": {
125 | 1: lambda t: -2.0 - 0.2 * np.log(t),
126 | 2: lambda t: -2.2 - 0.2 * np.log(t)
127 | },
128 | "beta": {
129 | 1: beta1,
130 | 2: beta2
131 | }
132 | }
133 | n_patients = 2000
134 | d_times = 5
135 | j_events = 2
136 |
137 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
138 |
139 | seed = 0
140 | means_vector = np.zeros(n_cov)
141 | covariance_matrix = 0.5 * np.identity(n_cov)
142 | clip_value = 1
143 |
144 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
145 |
146 |         patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix,
147 |                                                                        size=n_patients),
148 |                                    columns=covariates)
149 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True)
150 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
151 | patients_df.index.name = 'pid'
152 | patients_df = patients_df.reset_index()
153 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times))
154 | self.patients_df = ets.update_event_or_lof(patients_df)
155 | self.pgscv = PenaltyGridSearchCV()
156 |
157 | def test_penalty_grid_search_cross_validate(self):
158 | self.pgscv.cross_validate(full_df=self.patients_df,
159 | l1_ratio=1,
160 | n_splits=2,
161 | penalizers=[0.0001, 0.02],
162 | seed=0)
163 |
164 |
165 | class TestPenaltyGridSearchCVExact(TestPenaltyGridSearchCV):
166 |
167 | def setUp(self):
168 | n_cov = 6
169 | beta1 = np.zeros(n_cov)
170 | beta1[:5] = (-0.3 * np.log([0.8, 3, 3, 2.5, 2]))
171 | beta2 = np.zeros(n_cov)
172 | beta2[:5] = (-0.3 * np.log([1, 3, 4, 3, 2]))
173 |
174 | real_coef_dict = {
175 | "alpha": {
176 | 1: lambda t: -1.9 + 0.2 * np.log(t),
177 | 2: lambda t: -1.9 + 0.2 * np.log(t)
178 | },
179 | "beta": {
180 | 1: beta1,
181 | 2: beta2
182 | }
183 | }
184 | n_patients = 400
185 | d_times = 4
186 | j_events = 2
187 |
188 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
189 |
190 | seed = 0
191 | means_vector = np.zeros(n_cov)
192 | covariance_matrix = 0.5 * np.identity(n_cov)
193 | clip_value = 1
194 |
195 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
196 |
197 |         patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix,
198 |                                                                        size=n_patients),
199 |                                    columns=covariates)
200 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True)
201 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
202 | patients_df.index.name = 'pid'
203 | patients_df = patients_df.reset_index()
204 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times))
205 | self.patients_df = ets.update_event_or_lof(patients_df)
206 | self.pgscv = PenaltyGridSearchCVExact()
207 |
--------------------------------------------------------------------------------
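
All the cross-validation tests share one simulation pipeline: draw covariates, sample event times with EventTimesSampler, add independent loss-of-follow-up censoring, and resolve the two with update_event_or_lof. A condensed sketch of that pipeline feeding TwoStagesCV (values mirror the setUp above; package-style imports assumed, and the explicit numpy seed is an addition for reproducibility):

    import numpy as np
    import pandas as pd
    from pydts.data_generation import EventTimesSampler
    from pydts.cross_validation import TwoStagesCV

    n_cov, n_patients, d_times, j_events = 6, 1000, 5, 2
    beta1 = np.zeros(n_cov)
    beta1[:5] = -0.5 * np.log([0.8, 3, 3, 2.5, 2])
    beta2 = np.zeros(n_cov)
    beta2[:5] = -0.5 * np.log([1, 3, 4, 3, 2])
    real_coef_dict = {
        "alpha": {1: lambda t: -2.0 - 0.2 * np.log(t),
                  2: lambda t: -2.2 - 0.2 * np.log(t)},
        "beta": {1: beta1, 2: beta2},
    }

    np.random.seed(0)  # added for reproducibility of the covariate draw
    ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
    covariates = [f'Z{i + 1}' for i in range(n_cov)]
    patients_df = pd.DataFrame(np.random.multivariate_normal(np.zeros(n_cov),
                                                             0.5 * np.identity(n_cov),
                                                             size=n_patients),
                               columns=covariates).clip(-1, 1)
    patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=0)
    patients_df.index.name = 'pid'
    patients_df = patients_df.reset_index()
    patients_df = ets.sample_independent_lof_censoring(patients_df,
                                                       prob_lof_at_t=0.01 * np.ones(d_times))
    patients_df = ets.update_event_or_lof(patients_df)

    # 2-fold cross-validation under the global Brier score ('GBS'); the tests
    # also exercise 'BS', 'AUC', 'IAUC', 'GAUC', and 'IBS'.
    tscv = TwoStagesCV()
    tscv.cross_validate(patients_df, metrics='GBS', n_splits=2)
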
/tests/test_model_selection.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import pandas as pd
4 | from src.pydts.data_generation import EventTimesSampler
5 | from src.pydts.model_selection import PenaltyGridSearch, PenaltyGridSearchExact
6 | from sklearn.model_selection import train_test_split
7 | from src.pydts.fitters import TwoStagesFitter, TwoStagesFitterExact
8 |
9 |
10 | class TestPenaltyGridSearch(unittest.TestCase):
11 |
12 | def setUp(self):
13 | n_cov = 6
14 | beta1 = np.zeros(n_cov)
15 | beta1[:5] = (-0.5 * np.log([0.8, 3, 3, 2.5, 2]))
16 | beta2 = np.zeros(n_cov)
17 | beta2[:5] = (-0.5 * np.log([1, 3, 4, 3, 2]))
18 |
19 | real_coef_dict = {
20 | "alpha": {
21 | 1: lambda t: -2.0 - 0.2 * np.log(t),
22 | 2: lambda t: -2.2 - 0.2 * np.log(t)
23 | },
24 | "beta": {
25 | 1: beta1,
26 | 2: beta2
27 | }
28 | }
29 | n_patients = 1000
30 | d_times = 5
31 | j_events = 2
32 |
33 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
34 |
35 | seed = 0
36 | means_vector = np.zeros(n_cov)
37 | covariance_matrix = 0.5 * np.identity(n_cov)
38 | clip_value = 1
39 |
40 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
41 |
42 |         patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix,
43 |                                                                        size=n_patients),
44 |                                    columns=covariates)
45 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True)
46 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
47 | patients_df.index.name = 'pid'
48 | patients_df = patients_df.reset_index()
49 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times))
50 | self.patients_df = ets.update_event_or_lof(patients_df)
51 | self.pgs = PenaltyGridSearch()
52 |
53 | def test_get_mixed_two_stages_fitter(self):
54 | train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0)
55 | self.pgs.evaluate(train_df=train_df,
56 | test_df=test_df,
57 | l1_ratio=1,
58 | metrics=[],
59 | penalizers=[0.005, 0.02],
60 | seed=0)
61 | mixed_two_stages = self.pgs.get_mixed_two_stages_fitter([0.005, 0.02])
62 |
63 | mixed_params = pd.concat([mixed_two_stages.beta_models[1].params_,
64 | mixed_two_stages.beta_models[2].params_], axis=1)
65 |
66 | fit_beta_kwargs = {
67 | 'model_kwargs': {
68 | 1: {'penalizer': 0.005, 'l1_ratio': 1},
69 | 2: {'penalizer': 0.02, 'l1_ratio': 1},
70 | }
71 | }
72 |
73 | two_stages_fitter = TwoStagesFitter()
74 | two_stages_fitter.fit(df=train_df, fit_beta_kwargs=fit_beta_kwargs)
75 |
76 | two_stages_params = pd.concat([two_stages_fitter.beta_models[1].params_,
77 | two_stages_fitter.beta_models[2].params_], axis=1)
78 |
79 | pd.testing.assert_frame_equal(mixed_params, two_stages_params)
80 | pd.testing.assert_frame_equal(mixed_two_stages.alpha_df, two_stages_fitter.alpha_df)
81 |
82 |     def test_assertion_get_mixed_two_stages_fitter_not_included(self):
83 |         train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0)
84 |         self.pgs.evaluate(train_df=train_df,
85 |                           test_df=test_df,
86 |                           l1_ratio=1,
87 |                           metrics=[],
88 |                           penalizers=[0.005, 0.02],
89 |                           seed=0)
90 | 
91 |         # Only the lookup of penalizers that were never evaluated should raise
92 |         with self.assertRaises(AssertionError):
93 |             mixed_two_stages = self.pgs.get_mixed_two_stages_fitter([0.1, 0.02])
94 |
95 | def test_assertion_get_mixed_two_stages_fitter_empty(self):
96 | with self.assertRaises(AssertionError):
97 | mixed_two_stages = self.pgs.get_mixed_two_stages_fitter([0.001, 0.02])
98 |
99 | def test_evaluate(self):
100 | train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0)
101 | idx_max = self.pgs.evaluate(train_df=train_df, test_df=test_df, l1_ratio=1,
102 | penalizers=[0.0001, 0.005, 0.02],
103 | seed=0)
104 |
105 | def test_convert_results_dict_to_df(self):
106 | train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0)
107 | idx_max = self.pgs.evaluate(train_df=train_df, test_df=test_df, l1_ratio=1,
108 | penalizers=[0.0001, 0.005],
109 | seed=0)
110 | self.pgs.convert_results_dict_to_df(self.pgs.global_bs)
111 |
112 |
113 | class TestPenaltyGridSearchExact(TestPenaltyGridSearch):
114 |
115 | def setUp(self):
116 | n_cov = 6
117 | beta1 = np.zeros(n_cov)
118 | beta1[:5] = (-0.5 * np.log([0.8, 3, 3, 2.5, 2]))
119 | beta2 = np.zeros(n_cov)
120 | beta2[:5] = (-0.5 * np.log([1, 3, 4, 3, 2]))
121 |
122 | real_coef_dict = {
123 | "alpha": {
124 | 1: lambda t: -2.0 - 0.2 * np.log(t),
125 | 2: lambda t: -2.2 - 0.2 * np.log(t)
126 | },
127 | "beta": {
128 | 1: beta1,
129 | 2: beta2
130 | }
131 | }
132 | n_patients = 300
133 | d_times = 4
134 | j_events = 2
135 |
136 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
137 |
138 | seed = 0
139 | means_vector = np.zeros(n_cov)
140 | covariance_matrix = 0.5 * np.identity(n_cov)
141 | clip_value = 1
142 |
143 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
144 |
145 |         patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix,
146 |                                                                        size=n_patients),
147 |                                    columns=covariates)
148 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True)
149 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
150 | patients_df.index.name = 'pid'
151 | patients_df = patients_df.reset_index()
152 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times))
153 | self.patients_df = ets.update_event_or_lof(patients_df)
154 | self.pgs = PenaltyGridSearchExact()
155 |
156 | def test_get_mixed_two_stages_fitter(self):
157 | train_df, test_df = train_test_split(self.patients_df.drop(['C', 'T'], axis=1), random_state=0)
158 | self.pgs.evaluate(train_df=train_df,
159 | test_df=test_df,
160 | l1_ratio=1,
161 | metrics=[],
162 | penalizers=[0.005, 0.02],
163 | seed=0)
164 | mixed_two_stages = self.pgs.get_mixed_two_stages_fitter([0.005, 0.02])
165 |
166 | mixed_params = pd.concat([mixed_two_stages.beta_models[1].params,
167 | mixed_two_stages.beta_models[2].params], axis=1)
168 |
169 |
170 | fit_beta_kwargs = {
171 | 'model_fit_kwargs': {
172 | 1: {
173 | 'alpha': 0.005,
174 | 'L1_wt': 1
175 | },
176 | 2: {
177 | 'alpha': 0.02,
178 | 'L1_wt': 1
179 | }
180 | }
181 | }
182 |
183 | two_stages_fitter = TwoStagesFitterExact()
184 | two_stages_fitter.fit(df=train_df, fit_beta_kwargs=fit_beta_kwargs)
185 |
186 | two_stages_params = pd.concat([two_stages_fitter.beta_models[1].params,
187 | two_stages_fitter.beta_models[2].params], axis=1)
188 |
189 | pd.testing.assert_frame_equal(mixed_params, two_stages_params)
190 | pd.testing.assert_frame_equal(mixed_two_stages.alpha_df, two_stages_fitter.alpha_df)
191 |
--------------------------------------------------------------------------------
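
PenaltyGridSearch first evaluates a grid of penalizers on a train/test split, then lets you assemble a single two-stages fitter that mixes a (possibly different) evaluated penalizer per event; the equality test above verifies that this mixed fitter reproduces a direct per-event fit. A hedged sketch of that workflow (package-style imports assumed; the simulated-data arguments are illustrative):

    import numpy as np
    from sklearn.model_selection import train_test_split
    from pydts.model_selection import PenaltyGridSearch
    from pydts.examples_utils.generate_simulations_data import generate_quick_start_df

    real_coef_dict = {
        "alpha": {1: lambda t: -2.0 - 0.2 * np.log(t),
                  2: lambda t: -2.2 - 0.2 * np.log(t)},
        "beta": {1: -0.5 * np.log([0.8, 3, 3, 2.5, 2]),
                 2: -0.5 * np.log([1, 3, 4, 3, 2])},
    }
    df = generate_quick_start_df(n_patients=1000, n_cov=5, d_times=5, j_events=2,
                                 pid_col='pid', seed=0,
                                 real_coef_dict=real_coef_dict, censoring_prob=0.1)
    train_df, test_df = train_test_split(df.drop(['C', 'T'], axis=1), random_state=0)

    # Evaluate the penalizer grid once on the split:
    pgs = PenaltyGridSearch()
    pgs.evaluate(train_df=train_df, test_df=test_df, l1_ratio=1,
                 penalizers=[0.005, 0.02], seed=0)

    # Combine evaluated penalizers (one per event) into a single fitter,
    # and tabulate the stored global Brier scores:
    mixed = pgs.get_mixed_two_stages_fitter([0.005, 0.02])
    print(pgs.convert_results_dict_to_df(pgs.global_bs))

As the assertion tests above show, get_mixed_two_stages_fitter raises an AssertionError if it is called before evaluate() or with a penalizer that was not part of the evaluated grid.
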
/tests/test_repetative_fitter.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | #
3 | # import numpy as np
4 | #
5 | # from src.pydts.fitters import repetitive_fitters
6 | #
7 | #
8 | class TestRepFitters(unittest.TestCase):
9 | def setUp(self):
10 | pass
11 | # self.real_coef_dict = {
12 | # "alpha": {
13 | # 1: lambda t: -1 - 0.3 * np.log(t),
14 | # 2: lambda t: -1.75 - 0.15 * np.log(t)
15 | # },
16 | # "beta": {
17 | # 1: -np.log([0.8, 3, 3, 2.5, 2]),
18 | # 2: -np.log([1, 3, 4, 3, 2])
19 | # }
20 | # }
21 | #
22 | # self.n_patients = 10000
23 | # self.n_cov = 5
24 | # self.d = 15
25 | #
26 | # # def test_fit_function_case_successful(self):
27 | # # _ = repetitive_fitters(rep=5, n_patients=self.n_patients, n_cov=self.n_cov,
28 | # # d_times=self.d, j_events=2, pid_col='pid', verbose=0,
29 | # # allow_fails=20, real_coef_dict=self.real_coef_dict,
30 | # # censoring_prob=.8)
31 | # #
32 | # # def test_fit_not_sending_coef(self):
33 | # # # event where fit are sent without real coefficient dict
34 | # # with self.assertRaises(AssertionError):
35 | # # _ = repetitive_fitters(rep=5, n_patients=self.n_patients, n_cov=self.n_cov,
36 | # # d_times=self.d, j_events=2, pid_col='pid', verbose=0,
37 | # # allow_fails=20, censoring_prob=.8)
38 | # #
39 | # # def test_fit_repetitive_function_case_j_event_not_equal_to_real_coef(self):
40 | # # # event where fit are sent with wrong j_events, causing except to print it,
41 | # # # but not deal with value error in the end
42 | # # with self.assertRaises(ValueError):
43 | # # _ = repetitive_fitters(rep=2, n_patients=self.n_patients, n_cov=self.n_cov,
44 | # # d_times=self.d, j_events=3, pid_col='pid', verbose=0,
45 | # # allow_fails=0, real_coef_dict=self.real_coef_dict,
46 | # # censoring_prob=.8)
47 | # #
48 | # # def test_fit_function_case_second_model_is_not_twoStages(self):
49 | # # from src.pydts.fitters import DataExpansionFitter
50 | # # _ = repetitive_fitters(rep=2, n_patients=self.n_patients, n_cov=self.n_cov,
51 | # # d_times=self.d, j_events=2, pid_col='pid',
52 | # # model2=DataExpansionFitter, verbose=0,
53 | # # allow_fails=20, real_coef_dict=self.real_coef_dict,
54 | # # censoring_prob=.8)
--------------------------------------------------------------------------------
/tests/test_screening.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from src.pydts.data_generation import EventTimesSampler
3 | from src.pydts.screening import SISTwoStagesFitterExact, SISTwoStagesFitter, get_expanded_df
4 | import numpy as np
5 | import pandas as pd
6 |
7 |
8 | class TestScreening(unittest.TestCase):
9 |
10 | def setUp(self):
11 | n_cov = 50
12 | beta1 = np.zeros(n_cov)
13 | beta1[:5] = np.array([-0.6, 0.5, -0.5, 0.6, -0.6])
14 | beta2 = np.zeros(n_cov)
15 | beta2[:5] = np.array([0.5, -0.7, 0.7, -0.5, -0.7])
16 |
17 | real_coef_dict = {
18 | "alpha": {
19 | 1: lambda t: -2.6 + 0.1 * np.log(t),
20 | 2: lambda t: -2.7 + 0.2 * np.log(t)
21 | },
22 | "beta": {
23 | 1: beta1,
24 | 2: beta2
25 | }
26 | }
27 |
28 | n_patients = 400
29 | d_times = 7
30 | j_events = 2
31 |
32 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
33 |
34 | seed = 2
35 | means_vector = np.zeros(n_cov)
36 | covariance_matrix = np.identity(n_cov)
37 |
38 | clip_value = 3
39 |
40 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
41 |
42 |         patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix,
43 |                                                                        size=n_patients),
44 |                                    columns=covariates)
45 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True)
46 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
47 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times),
48 | seed=seed + 1)
49 | patients_df = ets.update_event_or_lof(patients_df)
50 | patients_df.index.name = 'pid'
51 | self.patients_df = patients_df.reset_index()
52 | self.covariates = covariates
53 | self.fitter = SISTwoStagesFitter()
54 |
55 | def test_psis_permute_df(self):
56 | self.fitter.permute_df(df=self.patients_df)
57 |
58 | def test_psis_fit_marginal_model(self):
59 | expanded_df = get_expanded_df(self.patients_df.drop(['C', 'T'], axis=1))
60 | self.fitter.fit_marginal_model(expanded_df, covariate='Z1')
61 |
62 | def test_psis_get_marginal_estimates(self):
63 | expanded_df = get_expanded_df(self.patients_df.drop(['C', 'T'], axis=1))
64 | self.fitter.get_marginal_estimates(expanded_df)
65 |
66 |     def test_psis_get_data_driven_threshold(self):
67 | self.fitter.get_data_driven_threshold(df=self.patients_df.drop(['C', 'T'], axis=1))
68 |
69 | def test_psis_fit_data_driven_threshold(self):
70 | self.fitter.fit(df=self.patients_df.drop(['C', 'T'], axis=1), quantile=0.95)
71 |
72 | def test_psis_fit_user_defined_threshold(self):
73 | self.fitter.fit(df=self.patients_df.drop(['C', 'T'], axis=1), threshold=0.15)
74 |
75 | def test_psis_covs_dict(self):
76 | with self.assertRaises(ValueError):
77 | self.fitter.fit(df=self.patients_df.drop(['C', 'T'], axis=1),
78 | covariates={1: self.covariates[:-3], 2: self.covariates[:-8]})
79 |
80 |
81 | class TestScreeningExact(TestScreening):
82 |
83 | def setUp(self):
84 | n_cov = 30
85 | beta1 = np.zeros(n_cov)
86 | beta1[:5] = np.array([-0.6, 0.5, -0.5, 0.6, -0.6])
87 | beta2 = np.zeros(n_cov)
88 | beta2[:5] = np.array([0.5, -0.7, 0.7, -0.5, -0.7])
89 |
90 | real_coef_dict = {
91 | "alpha": {
92 | 1: lambda t: -3.1 + 0.1 * np.log(t),
93 | 2: lambda t: -3.2 + 0.2 * np.log(t)
94 | },
95 | "beta": {
96 | 1: beta1,
97 | 2: beta2
98 | }
99 | }
100 |
101 | n_patients = 400
102 | d_times = 7
103 | j_events = 2
104 |
105 | ets = EventTimesSampler(d_times=d_times, j_event_types=j_events)
106 |
107 | seed = 2
108 | means_vector = np.zeros(n_cov)
109 | covariance_matrix = np.identity(n_cov)
110 |
111 | clip_value = 3
112 |
113 | covariates = [f'Z{i + 1}' for i in range(n_cov)]
114 |
115 |         patients_df = pd.DataFrame(data=np.random.multivariate_normal(means_vector, covariance_matrix,
116 |                                                                        size=n_patients),
117 |                                    columns=covariates)
118 | patients_df.clip(lower=-1 * clip_value, upper=clip_value, inplace=True)
119 | patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=seed)
120 | patients_df = ets.sample_independent_lof_censoring(patients_df, prob_lof_at_t=0.01 * np.ones(d_times),
121 | seed=seed + 1)
122 | patients_df = ets.update_event_or_lof(patients_df)
123 | patients_df.index.name = 'pid'
124 | self.patients_df = patients_df.reset_index()
125 | self.covariates = covariates
126 | self.fitter = SISTwoStagesFitterExact()
127 |
--------------------------------------------------------------------------------
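
The screening tests exercise SIS-style screening for the two-stages model: marginal models are fitted per covariate on the expanded data (the fitter exposes fit_marginal_model and get_expanded_df for this), and covariates are retained when their marginal estimate exceeds either a user-defined threshold or a data-driven one built from permuted data (permute_df, get_data_driven_threshold). A condensed sketch mirroring TestScreening.setUp (package-style imports assumed; the explicit numpy seed is an addition for reproducibility):

    import numpy as np
    import pandas as pd
    from pydts.data_generation import EventTimesSampler
    from pydts.screening import SISTwoStagesFitter

    # Wide design: 50 covariates, of which only the first 5 are informative
    n_cov, n_patients, d_times = 50, 400, 7
    beta1 = np.zeros(n_cov)
    beta1[:5] = np.array([-0.6, 0.5, -0.5, 0.6, -0.6])
    beta2 = np.zeros(n_cov)
    beta2[:5] = np.array([0.5, -0.7, 0.7, -0.5, -0.7])
    real_coef_dict = {
        "alpha": {1: lambda t: -2.6 + 0.1 * np.log(t),
                  2: lambda t: -2.7 + 0.2 * np.log(t)},
        "beta": {1: beta1, 2: beta2},
    }

    np.random.seed(2)  # added for reproducibility of the covariate draw
    ets = EventTimesSampler(d_times=d_times, j_event_types=2)
    covariates = [f'Z{i + 1}' for i in range(n_cov)]
    patients_df = pd.DataFrame(np.random.multivariate_normal(np.zeros(n_cov),
                                                             np.identity(n_cov),
                                                             size=n_patients),
                               columns=covariates).clip(-3, 3)
    patients_df = ets.sample_event_times(patients_df, hazard_coefs=real_coef_dict, seed=2)
    patients_df = ets.sample_independent_lof_censoring(patients_df,
                                                       prob_lof_at_t=0.01 * np.ones(d_times),
                                                       seed=3)
    patients_df = ets.update_event_or_lof(patients_df)
    patients_df.index.name = 'pid'
    patients_df = patients_df.reset_index()

    # Screen at the 95th percentile of the data-driven threshold distribution,
    # then fit the two-stages model on the retained covariates:
    sis = SISTwoStagesFitter()
    sis.fit(df=patients_df.drop(['C', 'T'], axis=1), quantile=0.95)

    # Alternatively, screen with a fixed user-defined threshold:
    sis_fixed = SISTwoStagesFitter()
    sis_fixed.fit(df=patients_df.drop(['C', 'T'], axis=1), threshold=0.15)
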