├── .github
└── workflows
│ ├── publish_pypi.yml
│ └── pythonpackage.yml
├── .gitignore
├── LICENSE
├── README.md
├── examples
├── 01_introduction.ipynb
├── 02_introduction.ipynb
├── 03_network_architectures.ipynb
├── 04_mnist_dataloaders_cnn.ipynb
├── README.md
├── administrative_brier_score.ipynb
├── cox-cc.ipynb
├── cox-ph.ipynb
├── cox-time.ipynb
├── deephit.ipynb
├── deephit_competing_risks.ipynb
├── mtlr.ipynb
├── pc-hazard.ipynb
└── pmf.ipynb
├── figures
└── logo.svg
├── pycox
├── __init__.py
├── datasets
│ ├── __init__.py
│ ├── _dataset_loader.py
│ ├── from_deepsurv.py
│ ├── from_kkbox.py
│ ├── from_rdatasets.py
│ └── from_simulations.py
├── evaluation
│ ├── __init__.py
│ ├── admin.py
│ ├── concordance.py
│ ├── eval_surv.py
│ ├── ipcw.py
│ └── metrics.py
├── models
│ ├── __init__.py
│ ├── base.py
│ ├── bce_surv.py
│ ├── cox.py
│ ├── cox_cc.py
│ ├── cox_time.py
│ ├── data.py
│ ├── deephit.py
│ ├── interpolation.py
│ ├── logistic_hazard.py
│ ├── loss.py
│ ├── mtlr.py
│ ├── pc_hazard.py
│ ├── pmf.py
│ └── utils.py
├── preprocessing
│ ├── __init__.py
│ ├── discretization.py
│ ├── feature_transforms.py
│ └── label_transforms.py
├── simulations
│ ├── __init__.py
│ ├── base.py
│ ├── discrete_logit_hazard.py
│ └── relative_risk.py
└── utils.py
├── requirements-dev.txt
├── setup.cfg
├── setup.py
└── tests
├── evaluation
└── test_admin.py
├── models
├── test_bce_surv.py
├── test_cox.py
├── test_cox_cc.py
├── test_cox_time.py
├── test_deephit.py
├── test_interpolation.py
├── test_logistic_hazard.py
├── test_loss.py
├── test_models_utils.py
├── test_mtlr.py
├── test_pc_hazard.py
├── test_pmf.py
└── utils_model_testing.py
└── test_utils.py
/.github/workflows/publish_pypi.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python Package to PyPI
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - name: Set up Python
13 | uses: actions/setup-python@v5
14 | with:
15 | python-version: 3.9
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install --upgrade pip
19 | pip install setuptools wheel twine
20 | - name: Build
21 | run: |
22 | python setup.py sdist bdist_wheel
23 | - name: Publish to PyPI
24 | uses: pypa/gh-action-pypi-publish@master
25 | with:
26 | password: ${{ secrets.PYPI_PASSWORD }}
27 |
--------------------------------------------------------------------------------
/.github/workflows/pythonpackage.yml:
--------------------------------------------------------------------------------
1 | name: Python package
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 |
9 | jobs:
10 | build_test:
11 | name: Test on ${{ matrix.config.os }} with Python ${{ matrix.python-version }}
12 | runs-on: ${{ matrix.config.os }}
13 | strategy:
14 | max-parallel: 4
15 | fail-fast: false
16 | matrix:
17 | python-version: ['3.8', '3.9', '3.10']
18 | config:
19 | - { os: ubuntu-latest, torch-version: "torch --index-url https://download.pytorch.org/whl/cpu"}
20 | - { os: windows-latest, torch-version: "torch"}
21 | - { os: macOS-latest, torch-version: "torch"}
22 | steps:
23 | - uses: actions/checkout@v4
24 | - name: Set up Python ${{ matrix.python-version }}
25 | uses: actions/setup-python@v5
26 | with:
27 | python-version: ${{ matrix.python-version }}
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install ${{ matrix.config.torch-version }}
32 | # python setup.py install
33 | pip install .
34 | pip install -r requirements-dev.txt
35 | - name: Lint with flake8
36 | run: |
37 | pip install flake8
38 | # stop the build if there are Python syntax errors or undefined names
39 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
40 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
41 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
42 | - name: Test with pytest
43 | run: |
44 | pytest
45 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # datasets
104 | /datasets/
105 |
106 | # vscode
107 | .vscode/
108 |
109 | # torch files
110 | *.torch
111 |
112 | # data files
113 | /pycox/datasets/data/
114 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2018, Haavard Kvamme
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | The notebooks in this directory describe how the methods can be used.
4 |
5 |
--------------------------------------------------------------------------------
/figures/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/pycox/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """Top-level package for pycox."""
4 |
5 | __author__ = """Haavard Kvamme"""
6 | __email__ = 'haavard.kvamme@gmail.com'
7 | __version__ = '0.3.0'
8 |
9 | import pycox.datasets
10 | import pycox.evaluation
11 | import pycox.preprocessing
12 | import pycox.simulations
13 | import pycox.utils
14 | import pycox.models
15 |
--------------------------------------------------------------------------------
/pycox/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from pycox.datasets import from_deepsurv
2 | from pycox.datasets import from_rdatasets
3 | from pycox.datasets import from_kkbox
4 | from pycox.datasets import from_simulations
5 |
6 |
7 | support = from_deepsurv._Support()
8 | metabric = from_deepsurv._Metabric()
9 | gbsg = from_deepsurv._Gbsg()
10 | flchain = from_rdatasets._Flchain()
11 | nwtco = from_rdatasets._Nwtco()
12 | kkbox_v1 = from_kkbox._DatasetKKBoxChurn()
13 | kkbox = from_kkbox._DatasetKKBoxAdmin()
14 | sac3 = from_simulations._SAC3()
15 | rr_nl_nhp = from_simulations._RRNLNPH()
16 | sac_admin5 = from_simulations._SACAdmin5()
17 |
--------------------------------------------------------------------------------
/pycox/datasets/_dataset_loader.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import pandas as pd
3 | import pycox
4 | import os
5 |
6 | _DATA_OVERRIDE = os.environ.get('PYCOX_DATA_DIR', None)
7 | if _DATA_OVERRIDE:
8 | _PATH_DATA = Path(_DATA_OVERRIDE)
9 | else:
10 | _PATH_ROOT = Path(pycox.__file__).parent
11 | _PATH_DATA = _PATH_ROOT / 'datasets' / 'data'
12 | _PATH_DATA.mkdir(parents=True, exist_ok=True)
13 |
14 | class _DatasetLoader:
15 | """Abstract class for loading data sets.
16 | """
17 | name = NotImplemented
18 | _checksum = None
19 |
20 | def __init__(self):
21 | self.path = _PATH_DATA / f"{self.name}.feather"
22 |
23 | def read_df(self):
24 | if not self.path.exists():
25 | print(f"Dataset '{self.name}' not locally available. Downloading...")
26 | self._download()
27 | print(f"Done")
28 | df = pd.read_feather(self.path)
29 | df = self._label_cols_at_end(df)
30 | return df
31 |
32 | def _download(self):
33 | raise NotImplementedError
34 |
35 | def delete_local_copy(self):
36 | if not self.path.exists():
37 | raise RuntimeError("File does not exists.")
38 | self.path.unlink()
39 |
40 | def _label_cols_at_end(self, df):
41 | if hasattr(self, 'col_duration') and hasattr(self, 'col_event'):
42 | col_label = [self.col_duration, self.col_event]
43 | df = df[list(df.columns.drop(col_label)) + col_label]
44 | return df
45 |
46 | def checksum(self):
47 | """Checks that the dataset is correct.
48 |
49 | Returns:
50 | bool -- If the check passed.
51 | """
52 | if self._checksum is None:
53 | raise NotImplementedError("No available comparison for this dataset.")
54 | df = self.read_df()
55 | return self._checksum_df(df)
56 |
57 | def _checksum_df(self, df):
58 | if self._checksum is None:
59 | raise NotImplementedError("No available comparison for this dataset.")
60 | import hashlib
61 | val = get_checksum(df)
62 | return val == self._checksum
63 |
64 |
65 | def get_checksum(df):
66 | import hashlib
67 | val = hashlib.sha256(df.to_csv().encode()).hexdigest()
68 | return val
69 |
--------------------------------------------------------------------------------
/pycox/datasets/from_deepsurv.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import requests
3 | import h5py
4 | import pandas as pd
5 | from pycox.datasets._dataset_loader import _DatasetLoader
6 |
7 |
8 | class _DatasetDeepSurv(_DatasetLoader):
9 | _dataset_url = "https://raw.githubusercontent.com/jaredleekatzman/DeepSurv/master/experiments/data/"
10 | _datasets = {
11 | 'support': "support/support_train_test.h5",
12 | 'metabric': "metabric/metabric_IHC4_clinical_train_test.h5",
13 | 'gbsg': "gbsg/gbsg_cancer_train_test.h5",
14 | }
15 | col_duration = 'duration'
16 | col_event = 'event'
17 | def _download(self):
18 | url = self._dataset_url + self._datasets[self.name]
19 | path = self.path.parent / f"{self.name}.h5"
20 | with requests.Session() as s:
21 | r = s.get(url)
22 | with open(path, 'wb') as f:
23 | f.write(r.content)
24 |
25 | data = defaultdict(dict)
26 | with h5py.File(path) as f:
27 | for ds in f:
28 | for array in f[ds]:
29 | data[ds][array] = f[ds][array][:]
30 |
31 | path.unlink()
32 | train = _make_df(data['train'])
33 | test = _make_df(data['test'])
34 | df = pd.concat([train, test]).reset_index(drop=True)
35 | df.to_feather(self.path)
36 |
37 |
38 | def _make_df(data):
39 | x = data['x']
40 | t = data['t']
41 | d = data['e']
42 |
43 | colnames = ['x'+str(i) for i in range(x.shape[1])]
44 | df = (pd.DataFrame(x, columns=colnames)
45 | .assign(duration=t)
46 | .assign(event=d))
47 | return df
48 |
49 |
50 | class _Support(_DatasetDeepSurv):
51 | """Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT).
52 |
53 | A study of survival for seriously ill hospitalized adults.
54 |
55 | This is the processed data set used in the DeepSurv paper (Katzman et al. 2018), and details
56 | can be found at https://doi.org/10.1186/s12874-018-0482-1
57 |
58 | See https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data
59 | for original data.
60 |
61 | Variables:
62 | x0, ..., x13:
63 | numerical covariates.
64 | duration: (duration)
65 | the right-censored event-times.
66 | event: (event)
67 | event indicator {1: event, 0: censoring}.
68 | """
69 | name = 'support'
70 | _checksum = 'b07a9d216bf04501e832084e5b7955cb84dfef834810037c548dee82ea251f8d'
71 |
72 |
73 | class _Metabric(_DatasetDeepSurv):
74 | """The Molecular Taxonomy of Breast Cancer International Consortium (METABRIC).
75 |
76 | Gene and protein expression profiles to determine new breast cancer subgroups in
77 | order to help physicians provide better treatment recommendations.
78 |
79 | This is the processed data set used in the DeepSurv paper (Katzman et al. 2018), and details
80 | can be found at https://doi.org/10.1186/s12874-018-0482-1
81 |
82 | See https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data
83 | for original data.
84 |
85 | Variables:
86 | x0, ..., x8:
87 | numerical covariates.
88 | duration: (duration)
89 | the right-censored event-times.
90 | event: (event)
91 | event indicator {1: event, 0: censoring}.
92 | """
93 | name = 'metabric'
94 | _checksum = '310b74b97cc37c9eddd29f253ae3c06015dc63a17a71e4a68ff339dbe265f417'
95 |
96 |
97 | class _Gbsg(_DatasetDeepSurv):
98 | """ Rotterdam & German Breast Cancer Study Group (GBSG)
99 |
100 | A combination of the Rotterdam tumor bank and the German Breast Cancer Study Group.
101 |
102 | This is the processed data set used in the DeepSurv paper (Katzman et al. 2018), and details
103 | can be found at https://doi.org/10.1186/s12874-018-0482-1
104 |
105 | See https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data
106 | for original data.
107 |
108 | Variables:
109 | x0, ..., x6:
110 | numerical covariates.
111 | duration: (duration)
112 | the right-censored event-times.
113 | event: (event)
114 | event indicator {1: event, 0: censoring}.
115 | """
116 | name = 'gbsg'
117 | _checksum = 'de2359bee62bf36b9e3f901fea4a9fbef2d145e26e9384617d0d3f75892fe5ce'
118 |
--------------------------------------------------------------------------------
/pycox/datasets/from_rdatasets.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from pycox.datasets._dataset_loader import _DatasetLoader
3 |
4 | def download_from_rdatasets(package, name):
5 | datasets = (pd.read_csv("https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/datasets.csv")
6 | .loc[lambda x: x['Package'] == package].set_index('Item'))
7 | if not name in datasets.index:
8 | raise ValueError(f"Dataset {name} not found.")
9 | info = datasets.loc[name]
10 | url = info.CSV
11 | return pd.read_csv(url), info
12 |
13 |
14 | class _DatasetRdatasetsSurvival(_DatasetLoader):
15 | """Data sets from Rdataset survival.
16 | """
17 | def _download(self):
18 | df, info = download_from_rdatasets('survival', self.name)
19 | self.info = info
20 | df.to_feather(self.path)
21 |
22 |
23 | class _Flchain(_DatasetRdatasetsSurvival):
24 | """Assay of serum free light chain (FLCHAIN).
25 | Obtained from Rdatasets (https://github.com/vincentarelbundock/Rdatasets).
26 |
27 | A study of the relationship between serum free light chain (FLC) and mortality.
28 | The original sample contains samples on approximately 2/3 of the residents of Olmsted
29 | County aged 50 or greater.
30 |
31 | For details see http://vincentarelbundock.github.io/Rdatasets/doc/survival/flchain.html
32 |
33 | Variables:
34 | age:
35 | age in years.
36 | sex:
37 | F=female, M=male.
38 | sample.yr:
39 | the calendar year in which a blood sample was obtained.
40 | kappa:
41 | serum free light chain, kappa portion.
42 | lambda:
43 | serum free light chain, lambda portion.
44 | flc.grp:
45 | the FLC group for the subject, as used in the original analysis.
46 | creatinine:
47 | serum creatinine.
48 | mgus:
49 | 1 if the subject had been diagnosed with monoclonal gammapothy (MGUS).
50 | futime: (duration)
51 | days from enrollment until death. Note that there are 3 subjects whose sample
52 | was obtained on their death date.
53 | death: (event)
54 | 0=alive at last contact date, 1=dead.
55 | chapter:
56 | for those who died, a grouping of their primary cause of death by chapter headings
57 | of the International Code of Diseases ICD-9.
58 |
59 | """
60 | name = 'flchain'
61 | col_duration = 'futime'
62 | col_event = 'death'
63 | _checksum = 'ec12748a1aa5790457c09793387337bb03b1dc45a22a2d58a8c2b9ad1f2648dd'
64 |
65 | def read_df(self, processed=True):
66 | """Get dataset.
67 |
68 | If 'processed' is False, return the raw data set.
69 | See the code for processing.
70 |
71 | Keyword Arguments:
72 | processed {bool} -- If 'False' get raw data, else get processed (see '??flchain.read_df').
73 | (default: {True})
74 | """
75 | df = super().read_df()
76 | if processed:
77 | df = (df
78 | .drop(['chapter', 'Unnamed: 0'], axis=1)
79 | .loc[lambda x: x['creatinine'].isna() == False]
80 | .reset_index(drop=True)
81 | .assign(sex=lambda x: (x['sex'] == 'M')))
82 |
83 | categorical = ['sample.yr', 'flc.grp']
84 | for col in categorical:
85 | df[col] = df[col].astype('category')
86 | for col in df.columns.drop(categorical):
87 | df[col] = df[col].astype('float32')
88 | return df
89 |
90 |
91 | class _Nwtco(_DatasetRdatasetsSurvival):
92 | """Data from the National Wilm's Tumor Study (NWTCO)
93 | Obtained from Rdatasets (https://github.com/vincentarelbundock/Rdatasets).
94 |
95 | Measurement error example. Tumor histology predicts survival, but prediction is stronger
96 | with central lab histology than with the local institution determination.
97 |
98 | For details see http://vincentarelbundock.github.io/Rdatasets/doc/survival/nwtco.html
99 |
100 | Variables:
101 | seqno:
102 | id number
103 | instit:
104 | histology from local institution
105 | histol:
106 | histology from central lab
107 | stage:
108 | disease stage
109 | study:
110 | study
111 | rel: (event)
112 | indicator for relapse
113 | edrel: (duration)
114 | time to relapse
115 | age:
116 | age in months
117 | in.subcohort:
118 | included in the subcohort for the example in the paper
119 |
120 | References
121 | NE Breslow and N Chatterjee (1999), Design and analysis of two-phase studies with binary
122 | outcome applied to Wilms tumor prognosis. Applied Statistics 48, 457–68.
123 | """
124 | name = 'nwtco'
125 | col_duration = 'edrel'
126 | col_event = 'rel'
127 | _checksum = '5aa3de698dadb60154dd59196796e382739ff56dc6cbd39cfc2fda50d69d118e'
128 |
129 | def read_df(self, processed=True):
130 | """Get dataset.
131 |
132 | If 'processed' is False, return the raw data set.
133 | See the code for processing.
134 |
135 | Keyword Arguments:
136 | processed {bool} -- If 'False' get raw data, else get processed (see '??nwtco.read_df').
137 | (default: {True})
138 | """
139 | df = super().read_df()
140 | if processed:
141 | df = (df
142 | .assign(instit_2=df['instit'] - 1,
143 | histol_2=df['histol'] - 1,
144 | study_4=df['study'] - 3,
145 | stage=df['stage'].astype('category'))
146 | .drop(['Unnamed: 0', 'seqno', 'instit', 'histol', 'study'], axis=1))
147 | for col in df.columns.drop('stage'):
148 | df[col] = df[col].astype('float32')
149 | df = self._label_cols_at_end(df)
150 | return df
151 |
--------------------------------------------------------------------------------
/pycox/datasets/from_simulations.py:
--------------------------------------------------------------------------------
1 | """Make dataset from the simulations, so we don't have to compute over again.
2 | """
3 |
4 | import numpy as np
5 | import pandas as pd
6 | from pycox import simulations
7 | from pycox.datasets._dataset_loader import _DatasetLoader
8 |
9 | class _SimDataset(_DatasetLoader):
10 | col_duration = 'duration'
11 | col_event = 'event'
12 | cols_true = ['duration_true', 'censoring_true']
13 |
14 | def read_df(self, add_true=True):
15 | if not self.path.exists():
16 | print(f"Dataset '{self.name}' not created yet. Making dataset...")
17 | self._simulate_data()
18 | print(f"Done")
19 | df = super().read_df()
20 | if add_true is False:
21 | df = self._drop_true(df)
22 | return df
23 |
24 | def _simulate_data(self):
25 | raise NotImplementedError
26 |
27 | def _download(self):
28 | raise NotImplementedError("There is no `_download` for simulated data.")
29 |
30 | def _drop_true(self, df):
31 | return df.drop(columns=self.cols_true)
32 |
33 |
34 | class _RRNLNPH(_SimDataset):
35 | """Dataset from simulation study in "Time-to-Event Prediction with Neural
36 | Networks and Cox Regression" [1].
37 |
38 | This is a continuous-time simulation study with event times drawn from a
39 | relative risk non-linear non-proportional hazards model (RRNLNPH).
40 | The full details are given in the paper [1].
41 |
42 | The dataset is created with `pycox.simulations.SimStudyNonLinearNonPH` (see
43 | `rr_nl_nph._simulate_data`).
44 |
45 | Variables:
46 | x0, x1, x2:
47 | numerical covariates.
48 | duration: (duration)
49 | the right-censored event-times.
50 | event: (event)
51 | event indicator {1: event, 0: censoring}.
52 | duration_true:
53 | the uncensored event times.
54 | event_true:
55 | if `duration_true` is an event.
56 | censoring_true:
57 | the censoring times.
58 |
59 | To generate more data:
60 | >>> from pycox.simulations import SimStudyNonLinearNonPH
61 | >>> n = 10000
62 | >>> sim = SimStudyNonLinearNonPH()
63 | >>> data = sim.simulate(n)
64 | >>> df = sim.dict2df(data, True)
65 |
66 | References:
67 | [1] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel.
68 | Time-to-event prediction with neural networks and Cox regression.
69 | Journal of Machine Learning Research, 20(129):1–30, 2019.
70 | http://jmlr.org/papers/v20/18-424.html
71 | """
72 | name = 'rr_nl_nph'
73 | _checksum = '4952a8712403f7222d1bec58e36cdbfcd46aa31ddf87c5fb2c455565fc3f7068'
74 |
75 | def _simulate_data(self):
76 | np.random.seed(1234)
77 | sim = simulations.SimStudyNonLinearNonPH()
78 | data = sim.simulate(25000)
79 | df = sim.dict2df(data, True)
80 | df.to_feather(self.path)
81 |
82 |
83 | class _SAC3(_SimDataset):
84 | """Dataset from simulation study in "Continuous and Discrete-Time Survival Prediction
85 | with Neural Networks" [1].
86 |
87 | The dataset is created with `pycox.simulations.SimStudySACConstCensor`
88 | (see `sac3._simulate_data`).
89 |
90 | The full details are given in Appendix A.1 in [1].
91 |
92 | Variables:
93 | x0, ..., x44:
94 | numerical covariates.
95 | duration: (duration)
96 | the right-censored event-times.
97 | event: (event)
98 | event indicator {1: event, 0: censoring}.
99 | duration_true:
100 | the uncensored event times (only censored at max-time 100.)
101 | event_true:
102 | if `duration_true` is an event.
103 | censoring_true:
104 | the censoring times.
105 |
106 | To generate more data:
107 | >>> from pycox.simulations import SimStudySACCensorConst
108 | >>> n = 10000
109 | >>> sim = SimStudySACCensorConst()
110 | >>> data = sim.simulate(n)
111 | >>> df = sim.dict2df(data, True, False)
112 |
113 | References:
114 | [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
115 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
116 | https://arxiv.org/pdf/1910.06724.pdf
117 | """
118 | name = 'sac3'
119 | _checksum = '2941d46baf0fbae949933565dc88663adbf1d8f5a58f989baf915d6586641fea'
120 |
121 | def _simulate_data(self):
122 | np.random.seed(1234)
123 | sim = simulations.SimStudySACCensorConst()
124 | data = sim.simulate(100000)
125 | df = sim.dict2df(data, True, False)
126 | df.to_feather(self.path)
127 |
128 |
129 | class _SACAdmin5(_SimDataset):
130 | """Dataset from simulation study in [1].
131 | The survival function is the same as in sac3, but the censoring is administrative
132 | and determined by five covariates.
133 |
134 | Variables:
135 | x0, ..., x22:
136 | numerical covariates.
137 | duration: (duration)
138 | the right-censored event-times.
139 | event: (event)
140 | event indicator {1: event, 0: censoring}.
141 | duration_true:
142 | the uncensored event times (only censored at max-time 100.)
143 | event_true:
144 | if `duration_true` is an event or right-censored at time 100.
145 | censoring_true:
146 | the censoring times.
147 |
148 | To generate more data:
149 | >>> from pycox.simulations import SimStudySACAdmin
150 | >>> n = 10000
151 | >>> sim = SimStudySACAdmin()
152 | >>> data = sim.simulate(n)
153 | >>> df = sim.dict2df(data, True, True)
154 |
155 | References:
156 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems
157 | and Solutions. arXiv preprint arXiv:1912.08581, 2019.
158 | https://arxiv.org/pdf/1912.08581.pdf
159 | """
160 | name = 'sac_admin5'
161 | _checksum = '9882bc8651315bcd80cba20b5f11040d71e4a84865898d7c2ca7b82ccba56683'
162 |
163 | def _simulate_data(self):
164 | np.random.seed(1234)
165 | sim = simulations.SimStudySACAdmin()
166 | data = sim.simulate(50000)
167 | df = sim.dict2df(data, True, True)
168 | df.to_feather(self.path)
169 |
--------------------------------------------------------------------------------
/pycox/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from pycox.evaluation.eval_surv import EvalSurv
3 | # from pycox.evaluation import binomial_log_likelihood, brier_score,\
4 | # integrated_binomial_log_likelihood, integrated_brier_score
5 | # from pycox.evaluation.concordance import concordance_td
6 | # from pycox.evaluation.km_inverce_censor_weight import binomial_log_likelihood_km, brier_score_km,\
7 | # integrated_binomial_log_likelihood_km_numpy, integrated_brier_score_km_numpy
8 |
--------------------------------------------------------------------------------
/pycox/evaluation/admin.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import numba
4 | from pycox.utils import idx_at_times
5 |
6 |
7 | def administrative_scores(func):
8 | if not func.__class__.__module__.startswith('numba'):
9 | raise ValueError("Need to provide numba compiled function")
10 | def metric(time_grid, durations, durations_c, events, surv, index_surv, reduce=True, steps_surv='post'):
11 | if not hasattr(time_grid, '__iter__'):
12 | time_grid = np.array([time_grid])
13 | assert (type(time_grid) is type(durations) is type(events) is type(surv) is
14 | type(index_surv) is type(durations_c) is np.ndarray), 'Need all input to be np.ndarrays'
15 | assert (durations[events == 0] == durations_c[events == 0]).all(), 'Censored observations need same `durations` and `durations_c`'
16 | assert (durations[events == 1] <= durations_c[events == 1]).all(), '`durations` cannot be larger than `durations_c`'
17 | idx_ts_surv = idx_at_times(index_surv, time_grid, steps_surv, assert_sorted=True)
18 | scores, norm = _admin_scores(func, time_grid, durations, durations_c, events, surv, idx_ts_surv)
19 | if reduce is True:
20 | return scores.sum(axis=1) / norm
21 | return scores, norm.reshape(-1, 1)
22 | return metric
23 |
24 | @numba.njit(parallel=True)
25 | def _admin_scores(func, time_grid, durations, durations_c, events, surv, idx_ts_surv):
26 | def _single(func, ts, durations, durations_c, events, surv, idx_ts_surv_i,
27 | scores, n_indiv):
28 | for i in range(n_indiv):
29 | tt = durations[i]
30 | tc = durations_c[i]
31 | d = events[i]
32 | s = surv[idx_ts_surv_i, i]
33 | scores[i] = func(ts, tt, tc, d, s)
34 |
35 | n_times = len(time_grid)
36 | n_indiv = len(durations)
37 | scores = np.empty((n_times, n_indiv))
38 | scores.fill(np.nan)
39 | normalizer = np.empty(n_times)
40 | normalizer.fill(np.nan)
41 | for i in numba.prange(n_times):
42 | ts = time_grid[i]
43 | idx_ts_surv_i = idx_ts_surv[i]
44 | scores_i = scores[i]
45 | normalizer[i] = (durations_c >= ts).sum()
46 | _single(func, ts, durations, durations_c, events, surv, idx_ts_surv_i, scores_i, n_indiv)
47 | return scores, normalizer
48 |
49 | @numba.njit
50 | def _brier_score(ts, tt, tc, d, s):
51 | if (tt <= ts) and (d == 1) and (tc >= ts):
52 | return np.power(s, 2)
53 | if tt >= ts:
54 | return np.power(1 - s, 2)
55 | return 0.
56 |
57 | @numba.njit
58 | def _binomial_log_likelihood(ts, tt, tc, d, s, eps=1e-7):
59 | if s < eps:
60 | s = eps
61 | elif s > (1 - eps):
62 | s = 1 - eps
63 | if (tt <= ts) and (d == 1) and (tc >= ts):
64 | return np.log(1 - s)
65 | if tt >= ts:
66 | return np.log(s)
67 | return 0.
68 |
69 | brier_score = administrative_scores(_brier_score)
70 | binomial_log_likelihood = administrative_scores(_binomial_log_likelihood)
71 |
72 |
73 | def _integrated_admin_metric(func):
74 | def metric(time_grid, durations, durations_c, events, surv, index_surv, steps_surv='post'):
75 | scores = func(time_grid, durations, durations_c, events, surv, index_surv, True, steps_surv)
76 | integral = scipy.integrate.simps(scores, time_grid)
77 | return integral / (time_grid[-1] - time_grid[0])
78 | return metric
79 |
80 | integrated_brier_score = _integrated_admin_metric(brier_score)
81 | integrated_binomial_log_likelihood = _integrated_admin_metric(binomial_log_likelihood)
82 |
83 |
--------------------------------------------------------------------------------
/pycox/evaluation/concordance.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import pandas as pd
4 | import numba
5 |
6 |
7 | @numba.jit(nopython=True)
8 | def _is_comparable(t_i, t_j, d_i, d_j):
9 | return ((t_i < t_j) & d_i) | ((t_i == t_j) & (d_i | d_j))
10 |
11 | @numba.jit(nopython=True)
12 | def _is_comparable_antolini(t_i, t_j, d_i, d_j):
13 | return ((t_i < t_j) & d_i) | ((t_i == t_j) & d_i & (d_j == 0))
14 |
15 | @numba.jit(nopython=True)
16 | def _is_concordant(s_i, s_j, t_i, t_j, d_i, d_j):
17 | conc = 0.
18 | if t_i < t_j:
19 | conc = (s_i < s_j) + (s_i == s_j) * 0.5
20 | elif t_i == t_j:
21 | if d_i & d_j:
22 | conc = 1. - (s_i != s_j) * 0.5
23 | elif d_i:
24 | conc = (s_i < s_j) + (s_i == s_j) * 0.5 # different from RSF paper.
25 | elif d_j:
26 | conc = (s_i > s_j) + (s_i == s_j) * 0.5 # different from RSF paper.
27 | return conc * _is_comparable(t_i, t_j, d_i, d_j)
28 |
29 | @numba.jit(nopython=True)
30 | def _is_concordant_antolini(s_i, s_j, t_i, t_j, d_i, d_j):
31 | return (s_i < s_j) & _is_comparable_antolini(t_i, t_j, d_i, d_j)
32 |
33 | @numba.jit(nopython=True, parallel=True)
34 | def _sum_comparable(t, d, is_comparable_func):
35 | n = t.shape[0]
36 | count = 0.
37 | for i in numba.prange(n):
38 | for j in range(n):
39 | if j != i:
40 | count += is_comparable_func(t[i], t[j], d[i], d[j])
41 | return count
42 |
43 | @numba.jit(nopython=True, parallel=True)
44 | def _sum_concordant(s, t, d):
45 | n = len(t)
46 | count = 0.
47 | for i in numba.prange(n):
48 | for j in range(n):
49 | if j != i:
50 | count += _is_concordant(s[i, i], s[i, j], t[i], t[j], d[i], d[j])
51 | return count
52 |
53 | @numba.jit(nopython=True, parallel=True)
54 | def _sum_concordant_disc(s, t, d, s_idx, is_concordant_func):
55 | n = len(t)
56 | count = 0
57 | for i in numba.prange(n):
58 | idx = s_idx[i]
59 | for j in range(n):
60 | if j != i:
61 | count += is_concordant_func(s[idx, i], s[idx, j], t[i], t[j], d[i], d[j])
62 | return count
63 |
64 | def concordance_td(durations, events, surv, surv_idx, method='adj_antolini'):
65 | """Time dependent concorance index from
66 | Antolini, L.; Boracchi, P.; and Biganzoli, E. 2005. A timedependent discrimination
67 | index for survival data. Statistics in Medicine 24:3927–3944.
68 |
69 | If 'method' is 'antolini', the concordance from Antolini et al. is computed.
70 |
71 | If 'method' is 'adj_antolini' (default) we have made a small modifications
72 | for ties in predictions and event times.
73 | We have followed step 3. in Sec 5.1. in Random Survial Forests paper, except for the last
74 | point with "T_i = T_j, but not both are deaths", as that doesn't make much sense.
75 | See '_is_concordant'.
76 |
77 | Arguments:
78 | durations {np.array[n]} -- Event times (or censoring times.)
79 | events {np.array[n]} -- Event indicators (0 is censoring).
80 | surv {np.array[n_times, n]} -- Survival function (each row is a duraratoin, and each col
81 | is an individual).
82 | surv_idx {np.array[n_test]} -- Mapping of survival_func s.t. 'surv_idx[i]' gives index in
83 | 'surv' corresponding to the event time of individual 'i'.
84 |
85 | Keyword Arguments:
86 | method {str} -- Type of c-index 'antolini' or 'adj_antolini' (default {'adj_antolini'}).
87 |
88 | Returns:
89 | float -- Time dependent concordance index.
90 | """
91 | if np.isfortran(surv):
92 | surv = np.array(surv, order='C')
93 | assert durations.shape[0] == surv.shape[1] == surv_idx.shape[0] == events.shape[0]
94 | assert type(durations) is type(events) is type(surv) is type(surv_idx) is np.ndarray
95 | if events.dtype in ('float', 'float32'):
96 | events = events.astype('int32')
97 | if method == 'adj_antolini':
98 | is_concordant = _is_concordant
99 | is_comparable = _is_comparable
100 | return (_sum_concordant_disc(surv, durations, events, surv_idx, is_concordant) /
101 | _sum_comparable(durations, events, is_comparable))
102 | elif method == 'antolini':
103 | is_concordant = _is_concordant_antolini
104 | is_comparable = _is_comparable_antolini
105 | return (_sum_concordant_disc(surv, durations, events, surv_idx, is_concordant) /
106 | _sum_comparable(durations, events, is_comparable))
107 | return ValueError(f"Need 'method' to be e.g. 'antolini', got '{method}'.")
108 |
109 |
--------------------------------------------------------------------------------
/pycox/evaluation/eval_surv.py:
--------------------------------------------------------------------------------
1 |
2 | import warnings
3 | import numpy as np
4 | import pandas as pd
5 | from pycox.evaluation.concordance import concordance_td
6 | from pycox.evaluation import ipcw, admin
7 | from pycox import utils
8 |
9 |
10 | class EvalSurv:
11 | """Class for evaluating predictions.
12 |
13 | Arguments:
14 | surv {pd.DataFrame} -- Survival predictions.
15 | durations {np.array} -- Durations of test set.
16 | events {np.array} -- Events of test set.
17 |
18 | Keyword Arguments:
19 | censor_surv {str, pd.DataFrame, EvalSurv} -- Censoring distribution.
20 | If provided data frame (survival function for censoring) or EvalSurv object,
21 | this will be used.
22 | If 'km', we will fit a Kaplan-Meier to the dataset.
23 | (default: {None})
24 | censor_durations {np.array}: -- Administrative censoring times. (default: {None})
25 | steps {str} -- For durations between values of `surv.index` choose the higher index 'pre'
26 | or lower index 'post'. For a visualization see `help(EvalSurv.steps)`. (default: {'post'})
27 | """
28 | def __init__(self, surv, durations, events, censor_surv=None, censor_durations=None, steps='post'):
29 | assert (type(durations) == type(events) == np.ndarray), 'Need `durations` and `events` to be arrays'
30 | self.surv = surv
31 | self.durations = durations
32 | self.events = events
33 | self.censor_surv = censor_surv
34 | self.censor_durations = censor_durations
35 | self.steps = steps
36 | assert pd.Series(self.index_surv).is_monotonic_increasing
37 |
38 | @property
39 | def censor_surv(self):
40 | """Estimated survival for censorings.
41 | Also an EvalSurv object.
42 | """
43 | return self._censor_surv
44 |
45 | @censor_surv.setter
46 | def censor_surv(self, censor_surv):
47 | if isinstance(censor_surv, EvalSurv):
48 | self._censor_surv = censor_surv
49 | elif type(censor_surv) is str:
50 | if censor_surv == 'km':
51 | self.add_km_censor()
52 | else:
53 | raise ValueError(f"censor_surv cannot be {censor_surv}. Use e.g. 'km'")
54 | elif censor_surv is not None:
55 | self.add_censor_est(censor_surv)
56 | else:
57 | self._censor_surv = None
58 |
59 | @property
60 | def index_surv(self):
61 | return self.surv.index.values
62 |
63 | @property
64 | def steps(self):
65 | """How to handle predictions that are between two indexes in `index_surv`.
66 |
67 | For a visualization, run the following:
68 | ev = EvalSurv(pd.DataFrame(np.linspace(1, 0, 7)), np.empty(7), np.ones(7), steps='pre')
69 | ax = ev[0].plot_surv()
70 | ev.steps = 'post'
71 | ev[0].plot_surv(ax=ax, style='--')
72 | ax.legend(['pre', 'post'])
73 | """
74 | return self._steps
75 |
76 | @steps.setter
77 | def steps(self, steps):
78 | vals = ['post', 'pre']
79 | if steps not in vals:
80 | raise ValueError(f"`steps` needs to be {vals}, got {steps}")
81 | self._steps = steps
82 |
83 | def add_censor_est(self, censor_surv, steps='post'):
84 | """Add censoring estimates so one can use inverse censoring weighting.
85 | `censor_surv` are the survival estimates trained on (durations, 1-events),
86 |
87 | Arguments:
88 | censor_surv {pd.DataFrame} -- Censor survival curves.
89 |
90 | Keyword Arguments:
91 | round {str} -- For durations between values of `surv.index` choose the higher index 'pre'
92 | or lower index 'post'. If `None` use `self.steps` (default: {None})
93 | """
94 | if not isinstance(censor_surv, EvalSurv):
95 | censor_surv = self._constructor(censor_surv, self.durations, 1-self.events, None,
96 | steps=steps)
97 | self.censor_surv = censor_surv
98 | return self
99 |
100 | def add_km_censor(self, steps='post'):
101 | """Add censoring estimates obtained by Kaplan-Meier on the test set
102 | (durations, 1-events).
103 | """
104 | km = utils.kaplan_meier(self.durations, 1-self.events)
105 | surv = pd.DataFrame(np.repeat(km.values.reshape(-1, 1), len(self.durations), axis=1),
106 | index=km.index)
107 | return self.add_censor_est(surv, steps)
108 |
109 | @property
110 | def censor_durations(self):
111 | """Administrative censoring times."""
112 | return self._censor_durations
113 |
114 | @censor_durations.setter
115 | def censor_durations(self, val):
116 | if val is not None:
117 | assert (self.durations[self.events == 0] == val[self.events == 0]).all(),\
118 | 'Censored observations need same `durations` and `censor_durations`'
119 | assert (self.durations[self.events == 1] <= val[self.events == 1]).all(),\
120 | '`durations` cannot be larger than `censor_durations`'
121 | if (self.durations == val).all():
122 | warnings.warn("`censor_durations` are equal to `durations`." +
123 | " `censor_durations` are likely wrong!")
124 | self._censor_durations = val
125 | else:
126 | self._censor_durations = val
127 |
128 | @property
129 | def _constructor(self):
130 | return EvalSurv
131 |
132 | def __getitem__(self, index):
133 | if not (hasattr(index, '__iter__') or type(index) is slice) :
134 | index = [index]
135 | surv = self.surv.iloc[:, index]
136 | durations = self.durations[index]
137 | events = self.events[index]
138 | new = self._constructor(surv, durations, events, None, steps=self.steps)
139 | if self.censor_surv is not None:
140 | new.censor_surv = self.censor_surv[index]
141 | return new
142 |
143 | def plot_surv(self, **kwargs):
144 | """Plot survival estimates.
145 | kwargs are passed to `self.surv.plot`.
146 | """
147 | if len(self.durations) > 50:
148 | raise RuntimeError("We don't allow to plot more than 50 lines. Use e.g. `ev[1:5].plot()`")
149 | if 'drawstyle' in kwargs:
150 | raise RuntimeError(f"`drawstyle` is set by `self.steps`. Remove from **kwargs")
151 | return self.surv.plot(drawstyle=f"steps-{self.steps}", **kwargs)
152 |
153 | def idx_at_times(self, times):
154 | """Get the index (iloc) of the `surv.index` closest to `times`.
155 | I.e. surv.loc[tims] (almost)= surv.iloc[idx_at_times(times)].
156 |
157 | Useful for finding predictions at given durations.
158 | """
159 | return utils.idx_at_times(self.index_surv, times, self.steps)
160 |
161 | def _duration_idx(self):
162 | return self.idx_at_times(self.durations)
163 |
164 | def surv_at_times(self, times):
165 | idx = self.idx_at_times(times)
166 | return self.surv.iloc[idx]
167 |
168 | # def prob_alive(self, time_grid):
169 | # return self.surv_at_times(time_grid).values
170 |
171 | def concordance_td(self, method='adj_antolini'):
172 | """Time dependent concorance index from
173 | Antolini, L.; Boracchi, P.; and Biganzoli, E. 2005. A time-dependent discrimination
174 | index for survival data. Statistics in Medicine 24:3927–3944.
175 |
176 | If 'method' is 'antolini', the concordance from Antolini et al. is computed.
177 |
178 | If 'method' is 'adj_antolini' (default) we have made a small modifications
179 | for ties in predictions and event times.
180 | We have followed step 3. in Sec 5.1. in Random Survival Forests paper, except for the last
181 | point with "T_i = T_j, but not both are deaths", as that doesn't make much sense.
182 | See 'metrics._is_concordant'.
183 |
184 | Keyword Arguments:
185 | method {str} -- Type of c-index 'antolini' or 'adj_antolini' (default {'adj_antolini'}).
186 |
187 | Returns:
188 | float -- Time dependent concordance index.
189 | """
190 | return concordance_td(self.durations, self.events, self.surv.values,
191 | self._duration_idx(), method)
192 |
193 | def brier_score(self, time_grid, max_weight=np.inf):
194 | """Brier score weighted by the inverse censoring distribution.
195 | See Section 3.1.2 or [1] for details of the wighting scheme.
196 |
197 | Arguments:
198 | time_grid {np.array} -- Durations where the brier score should be calculated.
199 |
200 | Keyword Arguments:
201 | max_weight {float} -- Max weight value (max number of individuals an individual
202 | can represent (default {np.inf}).
203 |
204 | References:
205 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems
206 | and Solutions. arXiv preprint arXiv:1912.08581, 2019.
207 | https://arxiv.org/pdf/1912.08581.pdf
208 | """
209 | if self.censor_surv is None:
210 | raise ValueError("""Need to add censor_surv to compute Brier score. Use 'add_censor_est'
211 | or 'add_km_censor' for Kaplan-Meier""")
212 | bs = ipcw.brier_score(time_grid, self.durations, self.events, self.surv.values,
213 | self.censor_surv.surv.values, self.index_surv,
214 | self.censor_surv.index_surv, max_weight, True, self.steps,
215 | self.censor_surv.steps)
216 | return pd.Series(bs, index=time_grid).rename('brier_score')
217 |
218 | def nbll(self, time_grid, max_weight=np.inf):
219 | """Negative binomial log-likelihood weighted by the inverse censoring distribution.
220 | See Section 3.1.2 or [1] for details of the wighting scheme.
221 |
222 | Arguments:
223 | time_grid {np.array} -- Durations where the brier score should be calculated.
224 |
225 | Keyword Arguments:
226 | max_weight {float} -- Max weight value (max number of individuals an individual
227 | can represent (default {np.inf}).
228 |
229 | References:
230 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems
231 | and Solutions. arXiv preprint arXiv:1912.08581, 2019.
232 | https://arxiv.org/pdf/1912.08581.pdf
233 | """
234 | if self.censor_surv is None:
235 | raise ValueError("""Need to add censor_surv to compute the score. Use 'add_censor_est'
236 | or 'add_km_censor' for Kaplan-Meier""")
237 | bll = ipcw.binomial_log_likelihood(time_grid, self.durations, self.events, self.surv.values,
238 | self.censor_surv.surv.values, self.index_surv,
239 | self.censor_surv.index_surv, max_weight, True, self.steps,
240 | self.censor_surv.steps)
241 | return pd.Series(-bll, index=time_grid).rename('nbll')
242 |
243 | def integrated_brier_score(self, time_grid, max_weight=np.inf):
244 | """Integrated Brier score weighted by the inverse censoring distribution.
245 | Essentially an integral over values obtained from `brier_score(time_grid, max_weight)`.
246 |
247 | Arguments:
248 | time_grid {np.array} -- Durations where the brier score should be calculated.
249 |
250 | Keyword Arguments:
251 | max_weight {float} -- Max weight value (max number of individuals an individual
252 | can represent (default {np.inf}).
253 | """
254 | if self.censor_surv is None:
255 | raise ValueError("Need to add censor_surv to compute briser score. Use 'add_censor_est'")
256 | return ipcw.integrated_brier_score(time_grid, self.durations, self.events, self.surv.values,
257 | self.censor_surv.surv.values, self.index_surv,
258 | self.censor_surv.index_surv, max_weight, self.steps,
259 | self.censor_surv.steps)
260 |
261 | def integrated_nbll(self, time_grid, max_weight=np.inf):
262 | """Integrated negative binomial log-likelihood weighted by the inverse censoring distribution.
263 | Essentially an integral over values obtained from `nbll(time_grid, max_weight)`.
264 |
265 | Arguments:
266 | time_grid {np.array} -- Durations where the brier score should be calculated.
267 |
268 | Keyword Arguments:
269 | max_weight {float} -- Max weight value (max number of individuals an individual
270 | can represent (default {np.inf}).
271 | """
272 | if self.censor_surv is None:
273 | raise ValueError("Need to add censor_surv to compute the score. Use 'add_censor_est'")
274 | ibll = ipcw.integrated_binomial_log_likelihood(time_grid, self.durations, self.events, self.surv.values,
275 | self.censor_surv.surv.values, self.index_surv,
276 | self.censor_surv.index_surv, max_weight, self.steps,
277 | self.censor_surv.steps)
278 | return -ibll
279 |
280 | def brier_score_admin(self, time_grid):
281 | """The Administrative Brier score proposed by [1].
282 | Removes individuals as they are administratively censored, event if they have experienced an
283 | event.
284 |
285 | Arguments:
286 | time_grid {np.array} -- Durations where the brier score should be calculated.
287 |
288 | References:
289 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems
290 | and Solutions. arXiv preprint arXiv:1912.08581, 2019.
291 | https://arxiv.org/pdf/1912.08581.pdf
292 | """
293 | if self.censor_durations is None:
294 | raise ValueError("Need to provide `censor_durations` (censoring durations) to use this method")
295 | bs = admin.brier_score(time_grid, self.durations, self.censor_durations, self.events,
296 | self.surv.values, self.index_surv, True, self.steps)
297 | return pd.Series(bs, index=time_grid).rename('brier_score')
298 |
299 | def integrated_brier_score_admin(self, time_grid):
300 | """The Integrated administrative Brier score proposed by [1].
301 | Removes individuals as they are administratively censored, event if they have experienced an
302 | event.
303 |
304 | Arguments:
305 | time_grid {np.array} -- Durations where the brier score should be calculated.
306 |
307 | References:
308 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems
309 | and Solutions. arXiv preprint arXiv:1912.08581, 2019.
310 | https://arxiv.org/pdf/1912.08581.pdf
311 | """
312 | if self.censor_durations is None:
313 | raise ValueError("Need to provide `censor_durations` (censoring durations) to use this method")
314 | ibs = admin.integrated_brier_score(time_grid, self.durations, self.censor_durations, self.events,
315 | self.surv.values, self.index_surv, self.steps)
316 | return ibs
317 |
318 | def nbll_admin(self, time_grid):
319 | """The negative administrative binomial log-likelihood proposed by [1].
320 | Removes individuals as they are administratively censored, event if they have experienced an
321 | event.
322 |
323 | Arguments:
324 | time_grid {np.array} -- Durations where the brier score should be calculated.
325 |
326 | References:
327 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems
328 | and Solutions. arXiv preprint arXiv:1912.08581, 2019.
329 | https://arxiv.org/pdf/1912.08581.pdf
330 | """
331 | if self.censor_durations is None:
332 | raise ValueError("Need to provide `censor_durations` (censoring durations) to use this method")
333 | bll = admin.binomial_log_likelihood(time_grid, self.durations, self.censor_durations, self.events,
334 | self.surv.values, self.index_surv, True, self.steps)
335 | return pd.Series(-bll, index=time_grid).rename('nbll')
336 |
337 | def integrated_nbll_admin(self, time_grid):
338 | """The Integrated negative administrative binomial log-likelihood score proposed by [1].
339 | Removes individuals as they are administratively censored, event if they have experienced an
340 | event.
341 |
342 | Arguments:
343 | time_grid {np.array} -- Durations where the brier score should be calculated.
344 |
345 | References:
346 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems
347 | and Solutions. arXiv preprint arXiv:1912.08581, 2019.
348 | https://arxiv.org/pdf/1912.08581.pdf
349 | """
350 | if self.censor_durations is None:
351 | raise ValueError("Need to provide `censor_durations` (censoring durations) to use this method")
352 | ibll = admin.integrated_binomial_log_likelihood(time_grid, self.durations, self.censor_durations,
353 | self.events, self.surv.values, self.index_surv,
354 | self.steps)
355 | return -ibll
356 |
--------------------------------------------------------------------------------
/pycox/evaluation/ipcw.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy
3 | import numba
4 | from pycox import utils
5 |
6 | @numba.njit(parallel=True)
7 | def _inv_cens_scores(func, time_grid, durations, events, surv, censor_surv, idx_ts_surv, idx_ts_censor,
8 | idx_tt_censor, scores, weights, n_times, n_indiv, max_weight):
9 | def _inv_cens_score_single(func, ts, durations, events, surv, censor_surv, idx_ts_surv_i,
10 | idx_ts_censor_i, idx_tt_censor, scores, weights, n_indiv, max_weight):
11 | min_g = 1./max_weight
12 | for i in range(n_indiv):
13 | tt = durations[i]
14 | d = events[i]
15 | s = surv[idx_ts_surv_i, i]
16 | g_ts = censor_surv[idx_ts_censor_i, i]
17 | g_tt = censor_surv[idx_tt_censor[i], i]
18 | g_ts = max(g_ts, min_g)
19 | g_tt = max(g_tt, min_g)
20 | score, w = func(ts, tt, s, g_ts, g_tt, d)
21 | #w = min(w, max_weight)
22 | scores[i] = score * w
23 | weights[i] = w
24 |
25 | for i in numba.prange(n_times):
26 | ts = time_grid[i]
27 | idx_ts_surv_i = idx_ts_surv[i]
28 | idx_ts_censor_i = idx_ts_censor[i]
29 | scores_i = scores[i]
30 | weights_i = weights[i]
31 | _inv_cens_score_single(func, ts, durations, events, surv, censor_surv, idx_ts_surv_i,
32 | idx_ts_censor_i, idx_tt_censor, scores_i, weights_i, n_indiv, max_weight)
33 |
34 | def _inverse_censoring_weighted_metric(func):
35 | if not func.__class__.__module__.startswith('numba'):
36 | raise ValueError("Need to provide numba compiled function")
37 | def metric(time_grid, durations, events, surv, censor_surv, index_surv, index_censor, max_weight=np.inf,
38 | reduce=True, steps_surv='post', steps_censor='post'):
39 | if not hasattr(time_grid, '__iter__'):
40 | time_grid = np.array([time_grid])
41 | assert (type(time_grid) is type(durations) is type(events) is type(surv) is type(censor_surv) is
42 | type(index_surv) is type(index_censor) is np.ndarray), 'Need all input to be np.ndarrays'
43 | n_times = len(time_grid)
44 | n_indiv = len(durations)
45 | scores = np.zeros((n_times, n_indiv))
46 | weights = np.zeros((n_times, n_indiv))
47 | idx_ts_surv = utils.idx_at_times(index_surv, time_grid, steps_surv, assert_sorted=True)
48 | idx_ts_censor = utils.idx_at_times(index_censor, time_grid, steps_censor, assert_sorted=True)
49 | idx_tt_censor = utils.idx_at_times(index_censor, durations, 'pre', assert_sorted=True)
50 | if steps_censor == 'post':
51 | idx_tt_censor = (idx_tt_censor - 1).clip(0)
52 | # This ensures that we get G(tt-)
53 | _inv_cens_scores(func, time_grid, durations, events, surv, censor_surv, idx_ts_surv, idx_ts_censor,
54 | idx_tt_censor, scores, weights, n_times, n_indiv, max_weight)
55 | if reduce is True:
56 | return np.sum(scores, axis=1) / np.sum(weights, axis=1)
57 | return scores, weights
58 | return metric
59 |
60 | @numba.njit()
61 | def _brier_score(ts, tt, s, g_ts, g_tt, d):
62 | if (tt <= ts) and d == 1:
63 | return np.power(s, 2), 1./g_tt
64 | if tt > ts:
65 | return np.power(1 - s, 2), 1./g_ts
66 | return 0., 0.
67 |
68 | @numba.njit()
69 | def _binomial_log_likelihood(ts, tt, s, g_ts, g_tt, d, eps=1e-7):
70 | s = eps if s < eps else s
71 | s = (1-eps) if s > (1 - eps) else s
72 | if (tt <= ts) and d == 1:
73 | return np.log(1 - s), 1./g_tt
74 | if tt > ts:
75 | return np.log(s), 1./g_ts
76 | return 0., 0.
77 |
78 | brier_score = _inverse_censoring_weighted_metric(_brier_score)
79 | binomial_log_likelihood = _inverse_censoring_weighted_metric(_binomial_log_likelihood)
80 |
81 | def _integrated_inverce_censoring_weighed_metric(func):
82 | def metric(time_grid, durations, events, surv, censor_surv, index_surv, index_censor,
83 | max_weight=np.inf, steps_surv='post', steps_censor='post'):
84 | scores = func(time_grid, durations, events, surv, censor_surv, index_surv, index_censor,
85 | max_weight, True, steps_surv, steps_censor)
86 | integral = scipy.integrate.simps(scores, time_grid)
87 | return integral / (time_grid[-1] - time_grid[0])
88 | return metric
89 |
90 | integrated_brier_score = _integrated_inverce_censoring_weighed_metric(brier_score)
91 | integrated_binomial_log_likelihood = _integrated_inverce_censoring_weighed_metric(binomial_log_likelihood)
92 |
--------------------------------------------------------------------------------
/pycox/evaluation/metrics.py:
--------------------------------------------------------------------------------
1 | '''
2 | Some relevant metrics
3 | '''
4 | import numpy as np
5 | import pandas as pd
6 |
7 | def partial_log_likelihood_ph(log_partial_hazards, durations, events, mean=True):
8 | """Partial log-likelihood for PH models.
9 |
10 | Arguments:
11 | log_partial_hazards {np.array} -- Log partial hazards (e.g. x^T beta).
12 | durations {np.array} -- Durations.
13 | events {np.array} -- Events.
14 |
15 | Keyword Arguments:
16 | mean {bool} -- Return the mean. (default: {True})
17 |
18 | Returns:
19 | pd.Series or float -- partial log-likelihood or mean.
20 | """
21 |
22 | df = pd.DataFrame(dict(duration=durations, event=events, lph=log_partial_hazards))
23 | pll = (df
24 | .sort_values('duration', ascending=False)
25 | .assign(cum_ph=(lambda x: x['lph']
26 | .pipe(np.exp)
27 | .cumsum()
28 | .groupby(x['duration'])
29 | .transform('max')))
30 | .loc[lambda x: x['event'] == 1]
31 | .assign(pll=lambda x: x['lph'] - np.log(x['cum_ph']))
32 | ['pll'])
33 | if mean:
34 | return pll.mean()
35 | return pll
36 |
--------------------------------------------------------------------------------
/pycox/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from pycox.models import base, loss, utils, pmf, data
3 | from pycox.models.cox import CoxPH
4 | from pycox.models.cox_cc import CoxCC
5 | from pycox.models.cox_time import CoxTime
6 | from pycox.models.deephit import DeepHitSingle, DeepHit
7 | from pycox.models.pmf import PMF
8 | from pycox.models.logistic_hazard import LogisticHazard
9 | from pycox.models.pc_hazard import PCHazard
10 | from pycox.models.mtlr import MTLR
11 | from pycox.models.bce_surv import BCESurv
12 |
--------------------------------------------------------------------------------
/pycox/models/base.py:
--------------------------------------------------------------------------------
1 | import torchtuples as tt
2 | import warnings
3 |
4 |
5 | class SurvBase(tt.Model):
6 | """Base class for survival models.
7 | Essentially same as torchtuples.Model,
8 | """
9 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True,
10 | to_cpu=False, num_workers=0):
11 | """Predict the survival function for `input`.
12 | See `prediction_surv_df` to return a DataFrame instead.
13 |
14 | Arguments:
15 | input {dataloader, tuple, np.ndarray, or torch.tensor} -- Input to net.
16 |
17 | Keyword Arguments:
18 | batch_size {int} -- Batch size (default: {8224})
19 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
20 | (default: {None})
21 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
22 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
23 | (default: {False})
24 | num_workers {int} -- Number of workers in created dataloader (default: {0})
25 |
26 | Returns:
27 | [TupleTree, np.ndarray or tensor] -- Predictions
28 | """
29 | raise NotImplementedError
30 |
31 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):
32 | """Predict the survival function for `input` and return as a pandas DataFrame.
33 | See `predict_surv` to return tensor or np.array instead.
34 |
35 | Arguments:
36 | input {dataloader, tuple, np.ndarray, or torch.tensor} -- Input to net.
37 |
38 | Keyword Arguments:
39 | batch_size {int} -- Batch size (default: {8224})
40 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
41 | num_workers {int} -- Number of workers in created dataloader (default: {0})
42 |
43 | Returns:
44 | pd.DataFrame -- Predictions
45 | """
46 | raise NotImplementedError
47 |
48 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
49 | num_workers=0):
50 | """Predict the hazard function for `input`.
51 |
52 | Arguments:
53 | input {dataloader, tuple, np.ndarray, or torch.tensor} -- Input to net.
54 |
55 | Keyword Arguments:
56 | batch_size {int} -- Batch size (default: {8224})
57 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
58 | (default: {None})
59 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
60 | grads {bool} -- If gradients should be computed (default: {False})
61 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
62 | (default: {False})
63 | num_workers {int} -- Number of workers in created dataloader (default: {0})
64 |
65 | Returns:
66 | [np.ndarray or tensor] -- Predicted hazards
67 | """
68 | raise NotImplementedError
69 |
70 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
71 | num_workers=0):
72 | """Predict the probability mass function (PMF) for `input`.
73 |
74 | Arguments:
75 | input {dataloader, tuple, np.ndarray, or torch.tensor} -- Input to net.
76 |
77 | Keyword Arguments:
78 | batch_size {int} -- Batch size (default: {8224})
79 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
80 | (default: {None})
81 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
82 | grads {bool} -- If gradients should be computed (default: {False})
83 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
84 | (default: {False})
85 | num_workers {int} -- Number of workers in created dataloader (default: {0})
86 |
87 | Returns:
88 | [np.ndarray or tensor] -- Predictions
89 | """
90 | raise NotImplementedError
91 |
92 |
93 | class _SurvModelBase(tt.Model):
94 | """Base class for survival models.
95 | Essentially same as torchtuples.Model,
96 | """
97 | def __init__(self, net, loss=None, optimizer=None, device=None):
98 | warnings.warn('Will be removed shortly', DeprecationWarning)
99 | super().__init__(net, loss, optimizer, device)
100 |
101 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True,
102 | to_cpu=False, num_workers=0):
103 | """Predict the survival function for `input`.
104 | See `prediction_surv_df` to return a DataFrame instead.
105 |
106 | Arguments:
107 | input {tuple, np.ndarray, or torch.tensor} -- Input to net.
108 |
109 | Keyword Arguments:
110 | batch_size {int} -- Batch size (default: {8224})
111 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
112 | (default: {None})
113 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
114 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
115 | (default: {False})
116 | num_workers {int} -- Number of workers in created dataloader (default: {0})
117 |
118 | Returns:
119 | [TupleTree, np.ndarray or tensor] -- Predictions
120 | """
121 | raise NotImplementedError
122 |
123 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):
124 | """Predict the survival function for `input` and return as a pandas DataFrame.
125 | See `predict_surv` to return tensor or np.array instead.
126 |
127 | Arguments:
128 | input {tuple, np.ndarray, or torch.tensor} -- Input to net.
129 |
130 | Keyword Arguments:
131 | batch_size {int} -- Batch size (default: {8224})
132 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
133 | num_workers {int} -- Number of workers in created dataloader (default: {0})
134 |
135 | Returns:
136 | pd.DataFrame -- Predictions
137 | """
138 | raise NotImplementedError
139 |
--------------------------------------------------------------------------------
/pycox/models/bce_surv.py:
--------------------------------------------------------------------------------
1 | """Estimate survival curve with binomial log-likelihood.
2 |
3 | This method is not smart to use!!!!!!!
4 | """
5 | import pandas as pd
6 | import torch
7 | from pycox import models
8 | from pycox.preprocessing import label_transforms
9 |
10 |
11 | class BCESurv(models.base.SurvBase):
12 | """
13 | The BCESurv method is a discrete-time survival model that parametrize the survival function directly
14 | and disregards individuals as they are censored. Each output node represents a binary classifier at
15 | the corresponding time, where all censored individual are removed.
16 | See [1] for details.
17 |
18 | Arguments:
19 | net {torch.nn.Module} -- A torch module.
20 |
21 | Keyword Arguments:
22 | optimizer {Optimizer} -- A torch optimizer or similar. Preferably use torchtuples.optim instead of
23 | torch.optim, as this allows for reinitialization, etc. If 'None' set to torchtuples.optim.AdamW.
24 | (default: {None})
25 | device {str, int, torch.device} -- Device to compute on. (default: {None})
26 | Preferably pass a torch.device object.
27 | If 'None': use default gpu if available, else use cpu.
28 | If 'int': used that gpu: torch.device('cuda:').
29 | If 'string': string is passed to torch.device('string').
30 | duration_index {list, np.array} -- Array of durations that defines the discrete times.
31 | This is used to set the index of the DataFrame in `predict_surv_df`.
32 | loss {func} -- An alternative loss function (default: {None})
33 |
34 | References:
35 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems
36 | and Solutions. arXiv preprint arXiv:1912.08581, 2019.
37 | https://arxiv.org/pdf/1912.08581.pdf
38 | """
39 | label_transform = label_transforms.LabTransDiscreteTime
40 |
41 | def __init__(self, net, optimizer=None, device=None, duration_index=None, loss=None):
42 | self.duration_index = duration_index
43 | if loss is None:
44 | loss = models.loss.BCESurvLoss()
45 | super().__init__(net, loss, optimizer, device)
46 |
47 | @property
48 | def duration_index(self):
49 | """
50 | Array of durations that defines the discrete times. This is used to set the index
51 | of the DataFrame in `predict_surv_df`.
52 |
53 | Returns:
54 | np.array -- Duration index.
55 | """
56 | return self._duration_index
57 |
58 | @duration_index.setter
59 | def duration_index(self, val):
60 | self._duration_index = val
61 |
62 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0, is_dataloader=None):
63 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers, is_dataloader)
64 | return pd.DataFrame(surv.transpose(), self.duration_index)
65 |
66 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
67 | num_workers=0, is_dataloader=None):
68 | return self.predict(input, batch_size, numpy, eval_, False, to_cpu, num_workers,
69 | is_dataloader, torch.sigmoid)
70 |
71 | def interpolate(self, sub=10, scheme='const_pdf', duration_index=None):
72 | """Use interpolation for predictions.
73 | There is only one scheme:
74 | `const_pdf` and `lin_surv` which assumes pice-wise constant PMF in each interval (linear survival).
75 |
76 | Keyword Arguments:
77 | sub {int} -- Number of "sub" units in interpolation grid. If `sub` is 10 we have a grid with
78 | 10 times the number of grid points than the original `duration_index` (default: {10}).
79 | scheme {str} -- Type of interpolation {'const_pdf'}.
80 | See `InterpolateDiscrete` (default: {'const_pdf'})
81 | duration_index {np.array} -- Cuts used for discretization. Does not affect interpolation,
82 | only for setting index in `predict_surv_df` (default: {None})
83 |
84 | Returns:
85 | [InterpolateLogisticHazard] -- Object for prediction with interpolation.
86 | """
87 | if duration_index is None:
88 | duration_index = self.duration_index
89 | return models.interpolation.InterpolateDiscrete(self, scheme, duration_index, sub)
90 |
--------------------------------------------------------------------------------
/pycox/models/cox.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | import torchtuples as tt
7 | from pycox import models
8 |
9 | def search_sorted_idx(array, values):
10 | '''For sorted array, get index of values.
11 | If value not in array, give left index of value.
12 | '''
13 | n = len(array)
14 | idx = np.searchsorted(array, values)
15 | idx[idx == n] = n-1 # We can't have indexes higher than the length-1
16 | not_exact = values != array[idx]
17 | idx -= not_exact
18 | if any(idx < 0):
19 | warnings.warn('Given value smaller than first value')
20 | idx[idx < 0] = 0
21 | return idx
22 |
23 |
24 | class _CoxBase(models.base.SurvBase):
25 | duration_col = 'duration'
26 | event_col = 'event'
27 |
28 | def fit(self, input, target, batch_size=256, epochs=1, callbacks=None, verbose=True,
29 | num_workers=0, shuffle=True, metrics=None, val_data=None, val_batch_size=8224,
30 | **kwargs):
31 | """Fit model with inputs and targets. Where 'input' is the covariates, and
32 | 'target' is a tuple with (durations, events).
33 |
34 | Arguments:
35 | input {np.array, tensor or tuple} -- Input x passed to net.
36 | target {np.array, tensor or tuple} -- Target [durations, events].
37 |
38 | Keyword Arguments:
39 | batch_size {int} -- Elements in each batch (default: {256})
40 | epochs {int} -- Number of epochs (default: {1})
41 | callbacks {list} -- list of callbacks (default: {None})
42 | verbose {bool} -- Print progress (default: {True})
43 | num_workers {int} -- Number of workers used in the dataloader (default: {0})
44 | shuffle {bool} -- If we should shuffle the order of the dataset (default: {True})
45 | **kwargs are passed to 'make_dataloader' method.
46 |
47 | Returns:
48 | TrainingLogger -- Training log
49 | """
50 | self.training_data = tt.tuplefy(input, target)
51 | return super().fit(input, target, batch_size, epochs, callbacks, verbose,
52 | num_workers, shuffle, metrics, val_data, val_batch_size,
53 | **kwargs)
54 |
55 | def _compute_baseline_hazards(self, input, df, max_duration, batch_size, eval_=True, num_workers=0):
56 | raise NotImplementedError
57 |
58 | def target_to_df(self, target):
59 | durations, events = tt.tuplefy(target).to_numpy()
60 | df = pd.DataFrame({self.duration_col: durations, self.event_col: events})
61 | return df
62 |
63 | def compute_baseline_hazards(self, input=None, target=None, max_duration=None, sample=None, batch_size=8224,
64 | set_hazards=True, eval_=True, num_workers=0):
65 | """Computes the Breslow estimates form the data defined by `input` and `target`
66 | (if `None` use training data).
67 |
68 | Typically call
69 | model.compute_baseline_hazards() after fitting.
70 |
71 | Keyword Arguments:
72 | input -- Input data (train input) (default: {None})
73 | target -- Target data (train target) (default: {None})
74 | max_duration {float} -- Don't compute estimates for duration higher (default: {None})
75 | sample {float or int} -- Compute estimates of subsample of data (default: {None})
76 | batch_size {int} -- Batch size (default: {8224})
77 | set_hazards {bool} -- Set hazards in model object, or just return hazards. (default: {True})
78 |
79 | Returns:
80 | pd.Series -- Pandas series with baseline hazards. Index is duration_col.
81 | """
82 | if (input is None) and (target is None):
83 | if not hasattr(self, 'training_data'):
84 | raise ValueError("Need to give a 'input' and 'target' to this function.")
85 | input, target = self.training_data
86 | df = self.target_to_df(target)#.sort_values(self.duration_col)
87 | if sample is not None:
88 | if sample >= 1:
89 | df = df.sample(n=sample)
90 | else:
91 | df = df.sample(frac=sample)
92 | input = tt.tuplefy(input).to_numpy().iloc[df.index.values]
93 | base_haz = self._compute_baseline_hazards(input, df, max_duration, batch_size,
94 | eval_=eval_, num_workers=num_workers)
95 | if set_hazards:
96 | self.compute_baseline_cumulative_hazards(set_hazards=True, baseline_hazards_=base_haz)
97 | return base_haz
98 |
99 | def compute_baseline_cumulative_hazards(self, input=None, target=None, max_duration=None, sample=None,
100 | batch_size=8224, set_hazards=True, baseline_hazards_=None,
101 | eval_=True, num_workers=0):
102 | """See `compute_baseline_hazards. This is the cumulative version."""
103 | if ((input is not None) or (target is not None)) and (baseline_hazards_ is not None):
104 | raise ValueError("'input', 'target' and 'baseline_hazards_' can not both be different from 'None'.")
105 | if baseline_hazards_ is None:
106 | baseline_hazards_ = self.compute_baseline_hazards(input, target, max_duration, sample, batch_size,
107 | set_hazards=False, eval_=eval_, num_workers=num_workers)
108 | assert baseline_hazards_.index.is_monotonic_increasing,\
109 | 'Need index of baseline_hazards_ to be monotonic increasing, as it represents time.'
110 | bch = (baseline_hazards_
111 | .cumsum()
112 | .rename('baseline_cumulative_hazards'))
113 | if set_hazards:
114 | self.baseline_hazards_ = baseline_hazards_
115 | self.baseline_cumulative_hazards_ = bch
116 | return bch
117 |
118 | def predict_cumulative_hazards(self, input, max_duration=None, batch_size=8224, verbose=False,
119 | baseline_hazards_=None, eval_=True, num_workers=0):
120 | """See `predict_survival_function`."""
121 | if type(input) is pd.DataFrame:
122 | input = self.df_to_input(input)
123 | if baseline_hazards_ is None:
124 | if not hasattr(self, 'baseline_hazards_'):
125 | raise ValueError('Need to compute baseline_hazards_. E.g run `model.compute_baseline_hazards()`')
126 | baseline_hazards_ = self.baseline_hazards_
127 | assert baseline_hazards_.index.is_monotonic_increasing,\
128 | 'Need index of baseline_hazards_ to be monotonic increasing, as it represents time.'
129 | return self._predict_cumulative_hazards(input, max_duration, batch_size, verbose, baseline_hazards_,
130 | eval_, num_workers=num_workers)
131 |
132 | def _predict_cumulative_hazards(self, input, max_duration, batch_size, verbose, baseline_hazards_,
133 | eval_=True, num_workers=0):
134 | raise NotImplementedError
135 |
136 | def predict_surv_df(self, input, max_duration=None, batch_size=8224, verbose=False, baseline_hazards_=None,
137 | eval_=True, num_workers=0):
138 | """Predict survival function for `input`. S(x, t) = exp(-H(x, t))
139 | Require computed baseline hazards.
140 |
141 | Arguments:
142 | input {np.array, tensor or tuple} -- Input x passed to net.
143 |
144 | Keyword Arguments:
145 | max_duration {float} -- Don't compute estimates for duration higher (default: {None})
146 | batch_size {int} -- Batch size (default: {8224})
147 | baseline_hazards_ {pd.Series} -- Baseline hazards. If `None` used `model.baseline_hazards_` (default: {None})
148 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
149 | num_workers {int} -- Number of workers in created dataloader (default: {0})
150 |
151 | Returns:
152 | pd.DataFrame -- Survival estimates. One columns for each individual.
153 | """
154 | return np.exp(-self.predict_cumulative_hazards(input, max_duration, batch_size, verbose, baseline_hazards_,
155 | eval_, num_workers))
156 |
157 | def predict_surv(self, input, max_duration=None, batch_size=8224, numpy=None, verbose=False,
158 | baseline_hazards_=None, eval_=True, num_workers=0):
159 | """Predict survival function for `input`. S(x, t) = exp(-H(x, t))
160 | Require compueted baseline hazards.
161 |
162 | Arguments:
163 | input {np.array, tensor or tuple} -- Input x passed to net.
164 |
165 | Keyword Arguments:
166 | max_duration {float} -- Don't compute estimates for duration higher (default: {None})
167 | batch_size {int} -- Batch size (default: {8224})
168 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
169 | (default: {None})
170 | baseline_hazards_ {pd.Series} -- Baseline hazards. If `None` used `model.baseline_hazards_` (default: {None})
171 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
172 | num_workers {int} -- Number of workers in created dataloader (default: {0})
173 |
174 | Returns:
175 | pd.DataFrame -- Survival estimates. One columns for each individual.
176 | """
177 | surv = self.predict_surv_df(input, max_duration, batch_size, verbose, baseline_hazards_,
178 | eval_, num_workers)
179 | surv = torch.from_numpy(surv.values.transpose())
180 | return tt.utils.array_or_tensor(surv, numpy, input)
181 |
182 | def save_net(self, path, **kwargs):
183 | """Save self.net and baseline hazards to file.
184 |
185 | Arguments:
186 | path {str} -- Path to file.
187 | **kwargs are passed to torch.save
188 |
189 | Returns:
190 | None
191 | """
192 | path, extension = os.path.splitext(path)
193 | if extension == "":
194 | extension = '.pt'
195 | super().save_net(path+extension, **kwargs)
196 | if hasattr(self, 'baseline_hazards_'):
197 | self.baseline_hazards_.to_pickle(path+'_blh.pickle')
198 |
199 | def load_net(self, path, **kwargs):
200 | """Load net and hazards from file.
201 |
202 | Arguments:
203 | path {str} -- Path to file.
204 | **kwargs are passed to torch.load
205 |
206 | Returns:
207 | None
208 | """
209 | path, extension = os.path.splitext(path)
210 | if extension == "":
211 | extension = '.pt'
212 | super().load_net(path+extension, **kwargs)
213 | blh_path = path+'_blh.pickle'
214 | if os.path.isfile(blh_path):
215 | self.baseline_hazards_ = pd.read_pickle(blh_path)
216 | self.baseline_cumulative_hazards_ = self.baseline_hazards_.cumsum()
217 |
218 | def df_to_input(self, df):
219 | input = df[self.input_cols].values
220 | return input
221 |
222 |
223 | class _CoxPHBase(_CoxBase):
224 | def _compute_baseline_hazards(self, input, df_target, max_duration, batch_size, eval_=True, num_workers=0):
225 | if max_duration is None:
226 | max_duration = np.inf
227 |
228 | # Here we are computing when expg when there are no events.
229 | # Could be made faster, by only computing when there are events.
230 | return (df_target
231 | .assign(expg=np.exp(self.predict(input, batch_size, True, eval_, num_workers=num_workers)))
232 | .groupby(self.duration_col)
233 | .agg({'expg': 'sum', self.event_col: 'sum'})
234 | .sort_index(ascending=False)
235 | .assign(expg=lambda x: x['expg'].cumsum())
236 | .pipe(lambda x: x[self.event_col]/x['expg'])
237 | .fillna(0.)
238 | .iloc[::-1]
239 | .loc[lambda x: x.index <= max_duration]
240 | .rename('baseline_hazards'))
241 |
242 | def _predict_cumulative_hazards(self, input, max_duration, batch_size, verbose, baseline_hazards_,
243 | eval_=True, num_workers=0):
244 | max_duration = np.inf if max_duration is None else max_duration
245 | if baseline_hazards_ is self.baseline_hazards_:
246 | bch = self.baseline_cumulative_hazards_
247 | else:
248 | bch = self.compute_baseline_cumulative_hazards(set_hazards=False,
249 | baseline_hazards_=baseline_hazards_)
250 | bch = bch.loc[lambda x: x.index <= max_duration]
251 | expg = np.exp(self.predict(input, batch_size, True, eval_, num_workers=num_workers)).reshape(1, -1)
252 | return pd.DataFrame(bch.values.reshape(-1, 1).dot(expg),
253 | index=bch.index)
254 |
255 | def partial_log_likelihood(self, input, target, g_preds=None, batch_size=8224, eps=1e-7, eval_=True,
256 | num_workers=0):
257 | '''Calculate the partial log-likelihood for the events in datafram df.
258 | This likelihood does not sample the controls.
259 | Note that censored data (non events) does not have a partial log-likelihood.
260 |
261 | Arguments:
262 | input {tuple, np.ndarray, or torch.tensor} -- Input to net.
263 | target {tuple, np.ndarray, or torch.tensor} -- Target labels.
264 |
265 | Keyword Arguments:
266 | g_preds {np.array} -- Predictions from `model.predict` (default: {None})
267 | batch_size {int} -- Batch size (default: {8224})
268 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
269 | num_workers {int} -- Number of workers in created dataloader (default: {0})
270 |
271 | Returns:
272 | Partial log-likelihood.
273 | '''
274 | df = self.target_to_df(target)
275 | if g_preds is None:
276 | g_preds = self.predict(input, batch_size, True, eval_, num_workers=num_workers)
277 | return (df
278 | .assign(_g_preds=g_preds)
279 | .sort_values(self.duration_col, ascending=False)
280 | .assign(_cum_exp_g=(lambda x: x['_g_preds']
281 | .pipe(np.exp)
282 | .cumsum()
283 | .groupby(x[self.duration_col])
284 | .transform('max')))
285 | .loc[lambda x: x[self.event_col] == 1]
286 | .assign(pll=lambda x: x['_g_preds'] - np.log(x['_cum_exp_g'] + eps))
287 | ['pll'])
288 |
289 |
290 | class CoxPH(_CoxPHBase):
291 | """Cox proportional hazards model parameterized with a neural net.
292 | This is essentially the DeepSurv method [1].
293 |
294 | The loss function is not quite the partial log-likelihood, but close.
295 | The difference is that for tied events, we use a random order instead of
296 | including all individuals that had an event at that point in time.
297 |
298 | Arguments:
299 | net {torch.nn.Module} -- A pytorch net.
300 |
301 | Keyword Arguments:
302 | optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None})
303 | device {str, int, torch.device} -- Device to compute on. (default: {None})
304 | Preferably pass a torch.device object.
305 | If 'None': use default gpu if available, else use cpu.
306 | If 'int': used that gpu: torch.device('cuda:').
307 | If 'string': string is passed to torch.device('string').
308 |
309 | [1] Jared L. Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger.
310 | Deepsurv: personalized treatment recommender system using a Cox proportional hazards deep neural network.
311 | BMC Medical Research Methodology, 18(1), 2018.
312 | https://bmcmedresmethodol.biomedcentral.com/articles/10.1186/s12874-018-0482-1
313 | """
314 | def __init__(self, net, optimizer=None, device=None, loss=None):
315 | if loss is None:
316 | loss = models.loss.CoxPHLoss()
317 | super().__init__(net, loss, optimizer, device)
318 |
319 |
320 | class CoxPHSorted(_CoxPHBase):
321 | """Cox proportional hazards model parameterized with a neural net.
322 | This is essentially the DeepSurv method [1].
323 |
324 | The loss function is not quite the partial log-likelihood, but close.
325 | The difference is that for tied events, we use a random order instead of
326 | including all individuals that had an event at that point in time.
327 |
328 | Arguments:
329 | net {torch.nn.Module} -- A pytorch net.
330 |
331 | Keyword Arguments:
332 | optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None})
333 | device {str, int, torch.device} -- Device to compute on. (default: {None})
334 | Preferably pass a torch.device object.
335 | If 'None': use default gpu if available, else use cpu.
336 | If 'int': used that gpu: torch.device('cuda:').
337 | If 'string': string is passed to torch.device('string').
338 |
339 | [1] Jared L. Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger.
340 | Deepsurv: personalized treatment recommender system using a Cox proportional hazards deep neural network.
341 | BMC Medical Research Methodology, 18(1), 2018.
342 | https://bmcmedresmethodol.biomedcentral.com/articles/10.1186/s12874-018-0482-1
343 | """
344 | def __init__(self, net, optimizer=None, device=None, loss=None):
345 | warnings.warn('Use `CoxPH` instead. This will be removed', DeprecationWarning)
346 | if loss is None:
347 | loss = models.loss.CoxPHLossSorted()
348 | super().__init__(net, loss, optimizer, device)
349 |
350 | @staticmethod
351 | def make_dataloader(data, batch_size, shuffle, num_workers=0):
352 | dataloader = tt.make_dataloader(data, batch_size, shuffle, num_workers,
353 | make_dataset=models.data.DurationSortedDataset)
354 | return dataloader
355 |
356 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
357 | dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers)
358 | return dataloader
359 |
--------------------------------------------------------------------------------
/pycox/models/cox_cc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torchtuples as tt
3 | from pycox import models
4 |
5 |
6 | class _CoxCCBase(models.cox._CoxBase):
7 | make_dataset = NotImplementedError
8 |
9 | def __init__(self, net, optimizer=None, device=None, shrink=0., loss=None):
10 | if loss is None:
11 | loss = models.loss.CoxCCLoss(shrink)
12 | super().__init__(net, loss, optimizer, device)
13 |
14 | def fit(self, input, target, batch_size=256, epochs=1, callbacks=None, verbose=True,
15 | num_workers=0, shuffle=True, metrics=None, val_data=None, val_batch_size=8224,
16 | n_control=1, shrink=None, **kwargs):
17 | """Fit model with inputs and targets. Where 'input' is the covariates, and
18 | 'target' is a tuple with (durations, events).
19 |
20 | Arguments:
21 | input {np.array, tensor or tuple} -- Input x passed to net.
22 | target {np.array, tensor or tuple} -- Target [durations, events].
23 |
24 | Keyword Arguments:
25 | batch_size {int} -- Elements in each batch (default: {256})
26 | epochs {int} -- Number of epochs (default: {1})
27 | callbacks {list} -- list of callbacks (default: {None})
28 | verbose {bool} -- Print progress (default: {True})
29 | num_workers {int} -- Number of workers used in the dataloader (default: {0})
30 | shuffle {bool} -- If we should shuffle the order of the dataset (default: {True})
31 | n_control {int} -- Number of control samples.
32 | **kwargs are passed to 'make_dataloader' method.
33 |
34 | Returns:
35 | TrainingLogger -- Training log
36 | """
37 | input, target = self._sorted_input_target(input, target)
38 | if shrink is not None:
39 | self.loss.shrink = shrink
40 | return super().fit(input, target, batch_size, epochs, callbacks, verbose,
41 | num_workers, shuffle, metrics, val_data, val_batch_size,
42 | n_control=n_control, **kwargs)
43 |
44 | def compute_metrics(self, input, metrics):
45 | if (self.loss is None) and (self.loss in metrics.values()):
46 | raise RuntimeError(f"Need to specify a loss (self.loss). It's currently None")
47 | input = self._to_device(input)
48 | batch_size = input.lens().flatten().get_if_all_equal()
49 | if batch_size is None:
50 | raise RuntimeError("All elements in input does not have the same length.")
51 | case, control = input # both are TupleTree
52 | input_all = tt.TupleTree((case,) + control).cat()
53 | g_all = self.net(*input_all)
54 | g_all = tt.tuplefy(g_all).split(batch_size).flatten()
55 | g_case = g_all[0]
56 | g_control = g_all[1:]
57 | res = {name: metric(g_case, g_control) for name, metric in metrics.items()}
58 | return res
59 |
60 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
61 | """Dataloader for prediction. The input is either the regular input, or a tuple
62 | with input and label.
63 |
64 | Arguments:
65 | input {np.array, tensor, tuple} -- Input to net, or tuple with input and labels.
66 | batch_size {int} -- Batch size.
67 |
68 | Keyword Arguments:
69 | shuffle {bool} -- If we should shuffle in the dataloader. (default: {False})
70 | num_workers {int} -- Number of worker in dataloader. (default: {0})
71 |
72 | Returns:
73 | dataloader -- A dataloader.
74 | """
75 | dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers)
76 | return dataloader
77 |
78 | def make_dataloader(self, data, batch_size, shuffle=True, num_workers=0, n_control=1):
79 | """Dataloader for training. Data is on the form (input, target), where
80 | target is (durations, events).
81 |
82 | Arguments:
83 | data {tuple} -- Tuple containing (input, (durations, events)).
84 | batch_size {int} -- Batch size.
85 |
86 | Keyword Arguments:
87 | shuffle {bool} -- If shuffle in dataloader (default: {True})
88 | num_workers {int} -- Number of workers in dataloader. (default: {0})
89 | n_control {int} -- Number of control samples in dataloader (default: {1})
90 |
91 | Returns:
92 | dataloader -- Dataloader for training.
93 | """
94 | input, target = self._sorted_input_target(*data)
95 | durations, events = target
96 | dataset = self.make_dataset(input, durations, events, n_control)
97 | dataloader = tt.data.DataLoaderBatch(dataset, batch_size=batch_size,
98 | shuffle=shuffle, num_workers=num_workers)
99 | return dataloader
100 |
101 | @staticmethod
102 | def _sorted_input_target(input, target):
103 | input, target = tt.tuplefy(input, target).to_numpy()
104 | durations, _ = target
105 | idx_sort = np.argsort(durations)
106 | if (idx_sort == np.arange(0, len(idx_sort))).all():
107 | return input, target
108 | input = tt.tuplefy(input).iloc[idx_sort]
109 | target = tt.tuplefy(target).iloc[idx_sort]
110 | return input, target
111 |
112 |
113 | class CoxCC(_CoxCCBase, models.cox._CoxPHBase):
114 | """Cox proportional hazards model parameterized with a neural net and
115 | trained with case-control sampling [1].
116 | This is similar to DeepSurv, but use an approximation of the loss function.
117 |
118 | Arguments:
119 | net {torch.nn.Module} -- A PyTorch net.
120 |
121 | Keyword Arguments:
122 | optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None})
123 | device {str, int, torch.device} -- Device to compute on. (default: {None})
124 | Preferably pass a torch.device object.
125 | If 'None': use default gpu if available, else use cpu.
126 | If 'int': used that gpu: torch.device('cuda:').
127 | If 'string': string is passed to torch.device('string').
128 |
129 | References:
130 | [1] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel.
131 | Time-to-event prediction with neural networks and Cox regression.
132 | Journal of Machine Learning Research, 20(129):1–30, 2019.
133 | http://jmlr.org/papers/v20/18-424.html
134 | """
135 | make_dataset = models.data.CoxCCDataset
136 |
--------------------------------------------------------------------------------
/pycox/models/cox_time.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 | from torch import nn
5 | import torchtuples as tt
6 |
7 | from pycox import models
8 | from pycox.preprocessing.label_transforms import LabTransCoxTime
9 |
10 | class CoxTime(models.cox_cc._CoxCCBase):
11 | """The Cox-Time model from [1]. A relative risk model without proportional hazards, trained
12 | with case-control sampling.
13 |
14 | Arguments:
15 | net {torch.nn.Module} -- A PyTorch net.
16 |
17 | Keyword Arguments:
18 | optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None})
19 | device {str, int, torch.device} -- Device to compute on. (default: {None})
20 | Preferably pass a torch.device object.
21 | If 'None': use default gpu if available, else use cpu.
22 | If 'int': used that gpu: torch.device('cuda:').
23 | If 'string': string is passed to torch.device('string').
24 | shrink {float} -- Shrinkage that encourage the net got give g_case and g_control
25 | closer to zero (a regularizer in a sense). (default: {0.})
26 | labtrans {pycox.preprocessing.label_tranforms.LabTransCoxTime} -- A object for transforming
27 | durations. Useful for prediction as we can obtain durations on the original scale.
28 | (default: {None})
29 |
30 | References:
31 | [1] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel.
32 | Time-to-event prediction with neural networks and Cox regression.
33 | Journal of Machine Learning Research, 20(129):1–30, 2019.
34 | http://jmlr.org/papers/v20/18-424.html
35 | """
36 | make_dataset = models.data.CoxTimeDataset
37 | label_transform = LabTransCoxTime
38 |
39 | def __init__(self, net, optimizer=None, device=None, shrink=0., labtrans=None, loss=None):
40 | self.labtrans = labtrans
41 | super().__init__(net, optimizer, device, shrink, loss)
42 |
43 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
44 | input, durations = input
45 | input = tt.tuplefy(input)
46 | durations = tt.tuplefy(durations)
47 | new_input = input + durations
48 | dataloader = super().make_dataloader_predict(new_input, batch_size, shuffle, num_workers)
49 | return dataloader
50 |
51 | def predict_surv_df(self, input, max_duration=None, batch_size=8224, verbose=False, baseline_hazards_=None,
52 | eval_=True, num_workers=0):
53 | surv = super().predict_surv_df(input, max_duration, batch_size, verbose, baseline_hazards_,
54 | eval_, num_workers)
55 | if self.labtrans is not None:
56 | surv.index = self.labtrans.map_scaled_to_orig(surv.index)
57 | return surv
58 |
59 | def compute_baseline_hazards(self, input=None, target=None, max_duration=None, sample=None, batch_size=8224,
60 | set_hazards=True, eval_=True, num_workers=0):
61 | if (input is None) and (target is None):
62 | if not hasattr(self, 'training_data'):
63 | raise ValueError('Need to fit, or supply a input and target to this function.')
64 | input, target = self.training_data
65 | df = self.target_to_df(target)
66 | if sample is not None:
67 | if sample >= 1:
68 | df = df.sample(n=sample)
69 | else:
70 | df = df.sample(frac=sample)
71 | df = df.sort_values(self.duration_col)
72 | input = tt.tuplefy(input).to_numpy().iloc[df.index.values]
73 | base_haz = self._compute_baseline_hazards(input, df, max_duration, batch_size, eval_, num_workers)
74 | if set_hazards:
75 | self.compute_baseline_cumulative_hazards(set_hazards=True, baseline_hazards_=base_haz)
76 | return base_haz
77 |
78 | def _compute_baseline_hazards(self, input, df_train_target, max_duration, batch_size, eval_=True,
79 | num_workers=0):
80 | if max_duration is None:
81 | max_duration = np.inf
82 | def compute_expg_at_risk(ix, t):
83 | sub = input.iloc[ix:]
84 | n = sub.lens().flatten().get_if_all_equal()
85 | t = np.repeat(t, n).reshape(-1, 1).astype('float32')
86 | return np.exp(self.predict((sub, t), batch_size, True, eval_, num_workers=num_workers)).flatten().sum()
87 |
88 | if not df_train_target[self.duration_col].is_monotonic_increasing:
89 | raise RuntimeError(f"Need 'df_train_target' to be sorted by {self.duration_col}")
90 | input = tt.tuplefy(input)
91 | df = df_train_target.reset_index(drop=True)
92 | times = (df
93 | .loc[lambda x: x[self.event_col] != 0]
94 | [self.duration_col]
95 | .loc[lambda x: x <= max_duration]
96 | .drop_duplicates(keep='first'))
97 | at_risk_sum = (pd.Series([compute_expg_at_risk(ix, t) for ix, t in times.items()],
98 | index=times.values)
99 | .rename('at_risk_sum'))
100 | events = (df
101 | .groupby(self.duration_col)
102 | [[self.event_col]]
103 | .agg('sum')
104 | .loc[lambda x: x.index <= max_duration])
105 | base_haz = (events
106 | .join(at_risk_sum, how='left', sort=True)
107 | .pipe(lambda x: x[self.event_col] / x['at_risk_sum'])
108 | .fillna(0.)
109 | .rename('baseline_hazards'))
110 | return base_haz
111 |
112 | def _predict_cumulative_hazards(self, input, max_duration, batch_size, verbose, baseline_hazards_,
113 | eval_=True, num_workers=0):
114 | def expg_at_time(t):
115 | t = np.repeat(t, n_cols).reshape(-1, 1).astype('float32')
116 | if tt.tuplefy(input).type() is torch.Tensor:
117 | t = torch.from_numpy(t)
118 | return np.exp(self.predict((input, t), batch_size, True, eval_, num_workers=num_workers)).flatten()
119 |
120 | if tt.utils.is_dl(input):
121 | raise NotImplementedError(f"Prediction with a dataloader as input is not supported ")
122 | input = tt.tuplefy(input)
123 | max_duration = np.inf if max_duration is None else max_duration
124 | baseline_hazards_ = baseline_hazards_.loc[lambda x: x.index <= max_duration]
125 | n_rows, n_cols = baseline_hazards_.shape[0], input.lens().flatten().get_if_all_equal()
126 | hazards = np.empty((n_rows, n_cols))
127 | for idx, t in enumerate(baseline_hazards_.index):
128 | if verbose:
129 | print(idx, 'of', len(baseline_hazards_))
130 | hazards[idx, :] = expg_at_time(t)
131 | hazards[baseline_hazards_.values == 0] = 0. # in case hazards are inf here
132 | hazards *= baseline_hazards_.values.reshape(-1, 1)
133 | return pd.DataFrame(hazards, index=baseline_hazards_.index).cumsum()
134 |
135 | def partial_log_likelihood(self, input, target, batch_size=8224, eval_=True, num_workers=0):
136 | def expg_sum(t, i):
137 | sub = input_sorted.iloc[i:]
138 | n = sub.lens().flatten().get_if_all_equal()
139 | t = np.repeat(t, n).reshape(-1, 1).astype('float32')
140 | return np.exp(self.predict((sub, t), batch_size, True, eval_, num_workers=num_workers)).flatten().sum()
141 |
142 | durations, events = target
143 | df = pd.DataFrame({self.duration_col: durations, self.event_col: events})
144 | df = df.sort_values(self.duration_col)
145 | input = tt.tuplefy(input)
146 | input_sorted = input.iloc[df.index.values]
147 |
148 | times = (df
149 | .assign(_idx=np.arange(len(df)))
150 | .loc[lambda x: x[self.event_col] == True]
151 | .drop_duplicates(self.duration_col, keep='first')
152 | .assign(_expg_sum=lambda x: [expg_sum(t, i) for t, i in zip(x[self.duration_col], x['_idx'])])
153 | .drop([self.event_col, '_idx'], axis=1))
154 |
155 | idx_name_old = df.index.name
156 | idx_name = '__' + idx_name_old if idx_name_old else '__index'
157 | df.index.name = idx_name
158 |
159 | pll = df.loc[lambda x: x[self.event_col] == True]
160 | input_event = input.iloc[pll.index.values]
161 | durations_event = pll[self.duration_col].values.reshape(-1, 1)
162 | g_preds = self.predict((input_event, durations_event), batch_size, True, eval_, num_workers=num_workers).flatten()
163 | pll = (pll
164 | .assign(_g_preds=g_preds)
165 | .reset_index()
166 | .merge(times, on=self.duration_col)
167 | .set_index(idx_name)
168 | .assign(pll=lambda x: x['_g_preds'] - np.log(x['_expg_sum']))
169 | ['pll'])
170 |
171 | pll.index.name = idx_name_old
172 | return pll
173 |
174 |
175 | class MLPVanillaCoxTime(nn.Module):
176 | """A version of torchtuples.practical.MLPVanilla that works for CoxTime.
177 | The difference is that it takes `time` as an additional input and removes the output bias and
178 | output activation.
179 | """
180 | def __init__(self, in_features, num_nodes, batch_norm=True, dropout=None, activation=nn.ReLU,
181 | w_init_=lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu')):
182 | super().__init__()
183 | in_features += 1
184 | out_features = 1
185 | output_activation = None
186 | output_bias=False
187 | self.net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout,
188 | activation, output_activation, output_bias, w_init_)
189 |
190 | def forward(self, input, time):
191 | input = torch.cat([input, time], dim=1)
192 | return self.net(input)
193 |
194 |
195 | class MixedInputMLPCoxTime(nn.Module):
196 | """A version of torchtuples.practical.MixedInputMLP that works for CoxTime.
197 | The difference is that it takes `time` as an additional input and removes the output bias and
198 | output activation.
199 | """
200 | def __init__(self, in_features, num_embeddings, embedding_dims, num_nodes, batch_norm=True,
201 | dropout=None, activation=nn.ReLU, dropout_embedding=0.,
202 | w_init_=lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu')):
203 | super().__init__()
204 | in_features += 1
205 | out_features = 1
206 | output_activation = None
207 | output_bias=False
208 | self.net = tt.practical.MixedInputMLP(in_features, num_embeddings, embedding_dims, num_nodes,
209 | out_features, batch_norm, dropout, activation,
210 | dropout_embedding, output_activation, output_bias, w_init_)
211 |
212 | def forward(self, input_numeric, input_categoric, time):
213 | input_numeric = torch.cat([input_numeric, time], dim=1)
214 | return self.net(input_numeric, input_categoric)
215 |
--------------------------------------------------------------------------------
/pycox/models/data.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import pandas as pd
4 | import numba
5 | import torch
6 | import torchtuples as tt
7 |
8 |
9 | def sample_alive_from_dates(dates, at_risk_dict, n_control=1):
10 | '''Sample index from living at time given in dates.
11 | dates: np.array of times (or pd.Series).
12 | at_risk_dict: dict with at_risk_dict[time] = .
13 | n_control: number of samples.
14 | '''
15 | lengths = np.array([at_risk_dict[x].shape[0] for x in dates]) # Can be moved outside
16 | idx = (np.random.uniform(size=(n_control, dates.size)) * lengths).astype('int')
17 | samp = np.empty((dates.size, n_control), dtype=int)
18 |
19 | for it, time in enumerate(dates):
20 | samp[it, :] = at_risk_dict[time][idx[:, it]]
21 | return samp
22 |
23 | def make_at_risk_dict(durations):
24 | """Create dict(duration: indices) from sorted df.
25 | A dict mapping durations to indices.
26 | For each time => index of all individual alive.
27 |
28 | Arguments:
29 | durations {np.arrary} -- durations.
30 | """
31 | assert type(durations) is np.ndarray, 'Need durations to be a numpy array'
32 | durations = pd.Series(durations)
33 | assert durations.is_monotonic_increasing, 'Requires durations to be monotonic'
34 | allidx = durations.index.values
35 | keys = durations.drop_duplicates(keep='first')
36 | at_risk_dict = dict()
37 | for ix, t in keys.items():
38 | at_risk_dict[t] = allidx[ix:]
39 | return at_risk_dict
40 |
41 |
42 | class DurationSortedDataset(tt.data.DatasetTuple):
43 | """We assume the dataset contrain `(input, durations, events)`, and
44 | sort the batch based on descending `durations`.
45 |
46 | See `torchtuples.data.DatasetTuple`.
47 | """
48 | def __getitem__(self, index):
49 | batch = super().__getitem__(index)
50 | input, (duration, event) = batch
51 | idx_sort = duration.sort(descending=True)[1]
52 | event = event.float()
53 | batch = tt.tuplefy(input, event).iloc[idx_sort]
54 | return batch
55 |
56 |
57 | class CoxCCDataset(torch.utils.data.Dataset):
58 | def __init__(self, input, durations, events, n_control=1):
59 | df_train_target = pd.DataFrame(dict(duration=durations, event=events))
60 | self.durations = df_train_target.loc[lambda x: x['event'] == 1]['duration']
61 | self.at_risk_dict = make_at_risk_dict(durations)
62 |
63 | self.input = tt.tuplefy(input)
64 | assert type(self.durations) is pd.Series
65 | self.n_control = n_control
66 |
67 | def __getitem__(self, index):
68 | if (not hasattr(index, '__iter__')) and (type(index) is not slice):
69 | index = [index]
70 | fails = self.durations.iloc[index]
71 | x_case = self.input.iloc[fails.index]
72 | control_idx = sample_alive_from_dates(fails.values, self.at_risk_dict, self.n_control)
73 | x_control = tt.TupleTree(self.input.iloc[idx] for idx in control_idx.transpose())
74 | return tt.tuplefy(x_case, x_control).to_tensor()
75 |
76 | def __len__(self):
77 | return len(self.durations)
78 |
79 |
80 | class CoxTimeDataset(CoxCCDataset):
81 | def __init__(self, input, durations, events, n_control=1):
82 | super().__init__(input, durations, events, n_control)
83 | self.durations_tensor = tt.tuplefy(self.durations.values.reshape(-1, 1)).to_tensor()
84 |
85 | def __getitem__(self, index):
86 | if not hasattr(index, '__iter__'):
87 | index = [index]
88 | durations = self.durations_tensor.iloc[index]
89 | case, control = super().__getitem__(index)
90 | case = case + durations
91 | control = control.apply_nrec(lambda x: x + durations)
92 | return tt.tuplefy(case, control)
93 |
94 | @numba.njit
95 | def _pair_rank_mat(mat, idx_durations, events, dtype='float32'):
96 | n = len(idx_durations)
97 | for i in range(n):
98 | dur_i = idx_durations[i]
99 | ev_i = events[i]
100 | if ev_i == 0:
101 | continue
102 | for j in range(n):
103 | dur_j = idx_durations[j]
104 | ev_j = events[j]
105 | if (dur_i < dur_j) or ((dur_i == dur_j) and (ev_j == 0)):
106 | mat[i, j] = 1
107 | return mat
108 |
109 | def pair_rank_mat(idx_durations, events, dtype='float32'):
110 | """Indicator matrix R with R_ij = 1{T_i < T_j and D_i = 1}.
111 | So it takes value 1 if we observe that i has an event before j and zero otherwise.
112 |
113 | Arguments:
114 | idx_durations {np.array} -- Array with durations.
115 | events {np.array} -- Array with event indicators.
116 |
117 | Keyword Arguments:
118 | dtype {str} -- dtype of array (default: {'float32'})
119 |
120 | Returns:
121 | np.array -- n x n matrix indicating if i has an observerd event before j.
122 | """
123 | idx_durations = idx_durations.reshape(-1)
124 | events = events.reshape(-1)
125 | n = len(idx_durations)
126 | mat = np.zeros((n, n), dtype=dtype)
127 | mat = _pair_rank_mat(mat, idx_durations, events, dtype)
128 | return mat
129 |
130 |
131 | class DeepHitDataset(tt.data.DatasetTuple):
132 | def __getitem__(self, index):
133 | input, target = super().__getitem__(index)
134 | target = target.to_numpy()
135 | rank_mat = pair_rank_mat(*target)
136 | target = tt.tuplefy(*target, rank_mat).to_tensor()
137 | return tt.tuplefy(input, target)
138 |
--------------------------------------------------------------------------------
/pycox/models/deephit.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 | import torchtuples as tt
5 |
6 | from pycox import models
7 | from pycox.models.utils import pad_col
8 |
9 | class DeepHitSingle(models.pmf.PMFBase):
10 | """The DeepHit methods by [1] but only for single event (not competing risks).
11 |
12 | Note that `alpha` is here defined differently than in [1], as `alpha` is weighting between
13 | the likelihood and rank loss (see Appendix D in [2])
14 | loss = alpha * nll + (1 - alpha) rank_loss(sigma).
15 |
16 | Also, unlike [1], this implementation allows for survival past the max durations, i.e., it
17 | does not assume all events happen within the defined duration grid. See [3] for details.
18 |
19 | Keyword Arguments:
20 | alpha {float} -- Weighting (0, 1) likelihood and rank loss (L2 in paper).
21 | 1 gives only likelihood, and 0 gives only rank loss. (default: {0.2})
22 | sigma {float} -- from eta in rank loss (L2 in paper) (default: {0.1})
23 |
24 | References:
25 | [1] Changhee Lee, William R Zame, Jinsung Yoon, and Mihaela van der Schaar. Deephit: A deep learning
26 | approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial
27 | Intelligence, 2018.
28 | http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit
29 |
30 | [2] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel.
31 | Time-to-event prediction with neural networks and Cox regression.
32 | Journal of Machine Learning Research, 20(129):1–30, 2019.
33 | http://jmlr.org/papers/v20/18-424.html
34 |
35 | [3] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
36 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
37 | https://arxiv.org/pdf/1910.06724.pdf
38 | """
39 | def __init__(self, net, optimizer=None, device=None, duration_index=None, alpha=0.2, sigma=0.1, loss=None):
40 | if loss is None:
41 | loss = models.loss.DeepHitSingleLoss(alpha, sigma)
42 | super().__init__(net, loss, optimizer, device, duration_index)
43 |
44 | def make_dataloader(self, data, batch_size, shuffle, num_workers=0):
45 | dataloader = super().make_dataloader(data, batch_size, shuffle, num_workers,
46 | make_dataset=models.data.DeepHitDataset)
47 | return dataloader
48 |
49 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
50 | dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers)
51 | return dataloader
52 |
53 |
54 | class DeepHit(tt.Model):
55 | """DeepHit for competing risks [1].
56 | For single risk (only one event type) use `DeepHitSingle` instead!
57 |
58 | Note that `alpha` is here defined differently than in [1], as `alpha` is weighting between
59 | the likelihood and rank loss (see Appendix D in [2])
60 | loss = alpha * nll + (1 - alpha) rank_loss(sigma).
61 |
62 | Also, unlike [1], this implementation allows for survival past the max durations, i.e., it
63 | does not assume all events happen within the defined duration grid. See [3] for details.
64 |
65 | Keyword Arguments:
66 | alpha {float} -- Weighting (0, 1) likelihood and rank loss (L2 in paper).
67 | 1 gives only likelihood, and 0 gives only rank loss. (default: {0.2})
68 | sigma {float} -- from eta in rank loss (L2 in paper) (default: {0.1})
69 |
70 | References:
71 | [1] Changhee Lee, William R Zame, Jinsung Yoon, and Mihaela van der Schaar. Deephit: A deep learning
72 | approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial
73 | Intelligence, 2018.
74 | http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit
75 |
76 | [2] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel.
77 | Time-to-event prediction with neural networks and Cox regression.
78 | Journal of Machine Learning Research, 20(129):1–30, 2019.
79 | http://jmlr.org/papers/v20/18-424.html
80 |
81 | [3] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
82 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
83 | https://arxiv.org/pdf/1910.06724.pdf
84 | """
85 | def __init__(self, net, optimizer=None, device=None, alpha=0.2, sigma=0.1, duration_index=None, loss=None):
86 | self.duration_index = duration_index
87 | if loss is None:
88 | loss = models.loss.DeepHitLoss(alpha, sigma)
89 | super().__init__(net, loss, optimizer, device)
90 |
91 | @property
92 | def duration_index(self):
93 | """
94 | Array of durations that defines the discrete times. This is used to set the index
95 | of the DataFrame in `predict_surv_df`.
96 |
97 | Returns:
98 | np.array -- Duration index.
99 | """
100 | return self._duration_index
101 |
102 | @duration_index.setter
103 | def duration_index(self, val):
104 | self._duration_index = val
105 |
106 | def make_dataloader(self, data, batch_size, shuffle, num_workers=0):
107 | dataloader = super().make_dataloader(data, batch_size, shuffle, num_workers,
108 | make_dataset=models.data.DeepHitDataset)
109 | return dataloader
110 |
111 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0):
112 | dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers)
113 | return dataloader
114 |
115 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):
116 | """Predict the survival function for `input`, i.e., survive all of the event types,
117 | and return as a pandas DataFrame.
118 | See `prediction_surv_df` to return a DataFrame instead.
119 |
120 | Arguments:
121 | input {tuple, np.ndarra, or torch.tensor} -- Input to net.
122 |
123 | Keyword Arguments:
124 | batch_size {int} -- Batch size (default: {8224})
125 | eval_ {bool} -- If 'True', use 'eval' modede on net. (default: {True})
126 | num_workers {int} -- Number of workes in created dataloader (default: {0})
127 |
128 | Returns:
129 | pd.DataFrame -- Predictions
130 | """
131 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers)
132 | return pd.DataFrame(surv, self.duration_index)
133 |
134 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True,
135 | to_cpu=False, num_workers=0):
136 | """Predict the survival function for `input`, i.e., survive all of the event types.
137 | See `prediction_surv_df` to return a DataFrame instead.
138 |
139 | Arguments:
140 | input {tuple, np.ndarra, or torch.tensor} -- Input to net.
141 |
142 | Keyword Arguments:
143 | batch_size {int} -- Batch size (default: {8224})
144 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
145 | (default: {None})
146 | eval_ {bool} -- If 'True', use 'eval' modede on net. (default: {True})
147 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
148 | (default: {False})
149 | num_workers {int} -- Number of workes in created dataloader (default: {0})
150 |
151 | Returns:
152 | [TupleTree, np.ndarray or tensor] -- Predictions
153 | """
154 | cif = self.predict_cif(input, batch_size, False, eval_, to_cpu, num_workers)
155 | surv = 1. - cif.sum(0)
156 | return tt.utils.array_or_tensor(surv, numpy, input)
157 |
158 | def predict_cif(self, input, batch_size=8224, numpy=None, eval_=True,
159 | to_cpu=False, num_workers=0):
160 | """Predict the cumulative incidence function (cif) for `input`.
161 |
162 | Arguments:
163 | input {tuple, np.ndarray, or torch.tensor} -- Input to net.
164 |
165 | Keyword Arguments:
166 | batch_size {int} -- Batch size (default: {8224})
167 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
168 | (default: {None})
169 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
170 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
171 | (default: {False})
172 | num_workers {int} -- Number of workers in created dataloader (default: {0})
173 |
174 | Returns:
175 | [np.ndarray or tensor] -- Predictions
176 | """
177 | pmf = self.predict_pmf(input, batch_size, False, eval_, to_cpu, num_workers)
178 | cif = pmf.cumsum(1)
179 | return tt.utils.array_or_tensor(cif, numpy, input)
180 |
181 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True,
182 | to_cpu=False, num_workers=0):
183 | """Predict the probability mass fuction (PMF) for `input`.
184 |
185 | Arguments:
186 | input {tuple, np.ndarray, or torch.tensor} -- Input to net.
187 |
188 | Keyword Arguments:
189 | batch_size {int} -- Batch size (default: {8224})
190 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
191 | (default: {None})
192 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
193 | grads {bool} -- If gradients should be computed (default: {False})
194 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
195 | (default: {False})
196 | num_workers {int} -- Number of workers in created dataloader (default: {0})
197 |
198 | Returns:
199 | [np.ndarray or tensor] -- Predictions
200 | """
201 | preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
202 | pmf = pad_col(preds.view(preds.size(0), -1)).softmax(1)[:, :-1]
203 | pmf = pmf.view(preds.shape).transpose(0, 1).transpose(1, 2)
204 | return tt.utils.array_or_tensor(pmf, numpy, input)
205 |
--------------------------------------------------------------------------------
/pycox/models/interpolation.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | import torchtuples as tt
4 | from pycox.models import utils
5 |
6 |
7 | class InterpolateDiscrete:
8 | """Interpolation of discrete models, for continuous predictions.
9 | There are two schemes:
10 | `const_hazard` and `exp_surv` which assumes pice-wise constant hazard in each interval (exponential survival).
11 | `const_pdf` and `lin_surv` which assumes pice-wise constant pmf in each interval (linear survival).
12 |
13 | Arguments:
14 | model {[type]} -- [description]
15 |
16 | Keyword Arguments:
17 | duration_index {np.array} -- Cuts used for discretization. Does not affect interpolation,
18 | only for setting index in `predict_surv_df` (default: {None})
19 | scheme {str} -- Type of interpolation {'const_hazard', 'const_pdf'} (default: {'const_pdf'})
20 | sub {int} -- Number of "sub" units in interpolation grid. If `sub` is 10 we have a grid with
21 | 10 times the number of grid points than the original `duration_index` (default: {10}).
22 |
23 | Keyword Arguments:
24 | """
25 | def __init__(self, model, scheme='const_pdf', duration_index=None, sub=10, epsilon=1e-7):
26 | self.model = model
27 | self.scheme = scheme
28 | self.duration_index = duration_index
29 | self.sub = sub
30 |
31 | @property
32 | def sub(self):
33 | return self._sub
34 |
35 | @sub.setter
36 | def sub(self, sub):
37 | if type(sub) is not int:
38 | raise ValueError(f"Need `sub` to have type `int`, got {type(sub)}")
39 | self._sub = sub
40 |
41 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
42 | raise NotImplementedError
43 |
44 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
45 | raise NotImplementedError
46 |
47 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
48 | """Predict the survival function for `input`.
49 | See `prediction_surv_df` to return a DataFrame instead.
50 |
51 | Arguments:
52 | input {tuple, np.ndarray, or torch.tensor} -- Input to net.
53 |
54 | Keyword Arguments:
55 | batch_size {int} -- Batch size (default: {8224})
56 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
57 | (default: {None})
58 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
59 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
60 | (default: {False})
61 | num_workers {int} -- Number of workers in created dataloader (default: {0})
62 |
63 | Returns:
64 | [np.ndarray or tensor] -- Predictions
65 | """
66 | return self._surv_const_pdf(input, batch_size, numpy, eval_, to_cpu, num_workers)
67 |
68 | def _surv_const_pdf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
69 | num_workers=0):
70 | """Basic method for constant PDF interpolation that use `self.model.predict_surv`.
71 |
72 | Arguments:
73 | input {tuple, np.ndarray, or torch.tensor} -- Input to net.
74 |
75 | Keyword Arguments:
76 | batch_size {int} -- Batch size (default: {8224})
77 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
78 | (default: {None})
79 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
80 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
81 | (default: {False})
82 | num_workers {int} -- Number of workers in created dataloader (default: {0})
83 |
84 | Returns:
85 | [np.ndarray or tensor] -- Predictions
86 | """
87 | s = self.model.predict_surv(input, batch_size, False, eval_, to_cpu, num_workers)
88 | n, m = s.shape
89 | device = s.device
90 | diff = (s[:, 1:] - s[:, :-1]).contiguous().view(-1, 1).repeat(1, self.sub).view(n, -1)
91 | rho = torch.linspace(0, 1, self.sub+1, device=device)[:-1].contiguous().repeat(n, m-1)
92 | s_prev = s[:, :-1].contiguous().view(-1, 1).repeat(1, self.sub).view(n, -1)
93 | surv = torch.zeros(n, int((m-1)*self.sub + 1))
94 | surv[:, :-1] = diff * rho + s_prev
95 | surv[:, -1] = s[:, -1]
96 | return tt.utils.array_or_tensor(surv, numpy, input)
97 |
98 | def predict_surv_df(self, input, batch_size=8224, eval_=True, to_cpu=False, num_workers=0):
99 | """Predict the survival function for `input` and return as a pandas DataFrame.
100 | See `predict_surv` to return tensor or np.array instead.
101 |
102 | Arguments:
103 | input {tuple, np.ndarray, or torch.tensor} -- Input to net.
104 |
105 | Keyword Arguments:
106 | batch_size {int} -- Batch size (default: {8224})
107 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
108 | num_workers {int} -- Number of workers in created dataloader (default: {0})
109 |
110 | Returns:
111 | pd.DataFrame -- Predictions
112 | """
113 | surv = self.predict_surv(input, batch_size, True, eval_, to_cpu, num_workers)
114 | index = None
115 | if self.duration_index is not None:
116 | index = utils.make_subgrid(self.duration_index, self.sub)
117 | return pd.DataFrame(surv.transpose(), index)
118 |
119 |
120 | class InterpolatePMF(InterpolateDiscrete):
121 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
122 | if not self.scheme in ['const_pdf', 'lin_surv']:
123 | raise NotImplementedError
124 | pmf = self.model.predict_pmf(input, batch_size, False, eval_, to_cpu, num_workers)
125 | n, m = pmf.shape
126 | pmf_cdi = pmf[:, 1:].contiguous().view(-1, 1).repeat(1, self.sub).div(self.sub).view(n, -1)
127 | pmf_cdi = utils.pad_col(pmf_cdi, where='start')
128 | pmf_cdi[:, 0] = pmf[:, 0]
129 | return tt.utils.array_or_tensor(pmf_cdi, numpy, input)
130 |
131 | def _surv_const_pdf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
132 | pmf = self.predict_pmf(input, batch_size, False, eval_, to_cpu, num_workers)
133 | surv = 1 - pmf.cumsum(1)
134 | return tt.utils.array_or_tensor(surv, numpy, input)
135 |
136 |
137 | class InterpolateLogisticHazard(InterpolateDiscrete):
138 | epsilon = 1e-7
139 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
140 | if self.scheme in ['const_hazard', 'exp_surv']:
141 | haz = self._hazard_const_haz(input, batch_size, numpy, eval_, to_cpu, num_workers)
142 | else:
143 | raise NotImplementedError
144 | return haz
145 |
146 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
147 | if self.scheme in ['const_hazard', 'exp_surv']:
148 | surv = self._surv_const_haz(input, batch_size, numpy, eval_, to_cpu, num_workers)
149 | elif self.scheme in ['const_pdf', 'lin_surv']:
150 | surv = self._surv_const_pdf(input, batch_size, numpy, eval_, to_cpu, num_workers)
151 | else:
152 | raise NotImplementedError
153 | return surv
154 |
155 | def _hazard_const_haz(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
156 | num_workers=0):
157 | """Computes the continuous-time constant hazard interpolation.
158 | Essentially we what the discrete survival estimates to match the continuous time at the knots.
159 | So essentially we want
160 | $$S(tau_j) = prod_{k=1}^j [1 - h_k] = prod_{k=1}{j} exp[-eta_k].$$
161 | where $h_k$ is the discrete hazard estimates and $eta_k$ continuous time hazards multiplied
162 | with the length of the duration interval as they are defined for the PC-Hazard method.
163 | Thus we get
164 | $$eta_k = - log[1 - h_k]$$
165 | which can be divided by the length of the time interval to get the continuous time hazards.
166 | """
167 | haz_orig = self.model.predict_hazard(input, batch_size, False, eval_, to_cpu, num_workers)
168 | haz = (1 - haz_orig).add(self.epsilon).log().mul(-1).relu()[:, 1:].contiguous()
169 | n = haz.shape[0]
170 | haz = haz.view(-1, 1).repeat(1, self.sub).view(n, -1).div(self.sub)
171 | haz = utils.pad_col(haz, where='start')
172 | haz[:, 0] = haz_orig[:, 0]
173 | return tt.utils.array_or_tensor(haz, numpy, input)
174 |
175 | def _surv_const_haz(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
176 | haz = self._hazard_const_haz(input, batch_size, False, eval_, to_cpu, num_workers)
177 | surv_0 = 1 - haz[:, :1]
178 | surv = utils.pad_col(haz[:, 1:], where='start').cumsum(1).mul(-1).exp().mul(surv_0)
179 | return tt.utils.array_or_tensor(surv, numpy, input)
180 |
--------------------------------------------------------------------------------
/pycox/models/logistic_hazard.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import pandas as pd
4 | import torch
5 | import torchtuples as tt
6 | from pycox import models
7 | from pycox.models.utils import pad_col, make_subgrid
8 | from pycox.preprocessing import label_transforms
9 | from pycox.models.interpolation import InterpolateLogisticHazard
10 |
11 | class LogisticHazard(models.base.SurvBase):
12 | """
13 | A discrete-time survival model that minimize the likelihood for right-censored data by
14 | parameterizing the hazard function. Also known as "Nnet-survival" [3].
15 |
16 | The Logistic-Hazard was first proposed by [2], but this implementation follows [1].
17 |
18 | Arguments:
19 | net {torch.nn.Module} -- A torch module.
20 |
21 | Keyword Arguments:
22 | optimizer {Optimizer} -- A torch optimizer or similar. Preferably use torchtuples.optim instead of
23 | torch.optim, as this allows for reinitialization, etc. If 'None' set to torchtuples.optim.AdamW.
24 | (default: {None})
25 | device {str, int, torch.device} -- Device to compute on. (default: {None})
26 | Preferably pass a torch.device object.
27 | If 'None': use default gpu if available, else use cpu.
28 | If 'int': used that gpu: torch.device('cuda:').
29 | If 'string': string is passed to torch.device('string').
30 | duration_index {list, np.array} -- Array of durations that defines the discrete times.
31 | This is used to set the index of the DataFrame in `predict_surv_df`.
32 |
33 | References:
34 | [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
35 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
36 | https://arxiv.org/pdf/1910.06724.pdf
37 |
38 | [2] Charles C. Brown. On the use of indicator variables for studying the time-dependence of parameters
39 | in a response-time model. Biometrics, 31(4):863–872, 1975.
40 | https://www.jstor.org/stable/2529811?seq=1#metadata_info_tab_contents
41 |
42 | [3] Michael F. Gensheimer and Balasubramanian Narasimhan. A scalable discrete-time survival model for
43 | neural networks. PeerJ, 7:e6257, 2019.
44 | https://peerj.com/articles/6257/
45 | """
46 | label_transform = label_transforms.LabTransDiscreteTime
47 |
48 | def __init__(self, net, optimizer=None, device=None, duration_index=None, loss=None):
49 | self.duration_index = duration_index
50 | if loss is None:
51 | loss = models.loss.NLLLogistiHazardLoss()
52 | super().__init__(net, loss, optimizer, device)
53 |
54 | @property
55 | def duration_index(self):
56 | """
57 | Array of durations that defines the discrete times. This is used to set the index
58 | of the DataFrame in `predict_surv_df`.
59 |
60 | Returns:
61 | np.array -- Duration index.
62 | """
63 | return self._duration_index
64 |
65 | @duration_index.setter
66 | def duration_index(self, val):
67 | self._duration_index = val
68 |
69 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):
70 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers)
71 | return pd.DataFrame(surv.transpose(), self.duration_index)
72 |
73 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
74 | num_workers=0, epsilon=1e-7):
75 | hazard = self.predict_hazard(input, batch_size, False, eval_, to_cpu, num_workers)
76 | surv = (1 - hazard).add(epsilon).log().cumsum(1).exp()
77 | return tt.utils.array_or_tensor(surv, numpy, input)
78 |
79 |
80 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
81 | num_workers=0):
82 | hazard = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers).sigmoid()
83 | return tt.utils.array_or_tensor(hazard, numpy, input)
84 |
85 | def interpolate(self, sub=10, scheme='const_pdf', duration_index=None):
86 | """Use interpolation for predictions.
87 | There are two schemes:
88 | `const_hazard` and `exp_surv` which assumes pice-wise constant hazard in each interval (exponential survival).
89 | `const_pdf` and `lin_surv` which assumes pice-wise constant PMF in each interval (linear survival).
90 |
91 | Keyword Arguments:
92 | sub {int} -- Number of "sub" units in interpolation grid. If `sub` is 10 we have a grid with
93 | 10 times the number of grid points than the original `duration_index` (default: {10}).
94 | scheme {str} -- Type of interpolation {'const_hazard', 'const_pdf'}.
95 | See `InterpolateDiscrete` (default: {'const_pdf'})
96 | duration_index {np.array} -- Cuts used for discretization. Does not affect interpolation,
97 | only for setting index in `predict_surv_df` (default: {None})
98 |
99 | Returns:
100 | [InterpolateLogisticHazard] -- Object for prediction with interpolation.
101 | """
102 | if duration_index is None:
103 | duration_index = self.duration_index
104 | return InterpolateLogisticHazard(self, scheme, duration_index, sub)
105 |
--------------------------------------------------------------------------------
/pycox/models/mtlr.py:
--------------------------------------------------------------------------------
1 |
2 | from pycox import models
3 | import torchtuples as tt
4 | from pycox.models import utils
5 |
6 | class MTLR(models.pmf.PMFBase):
7 | """
8 | The (Neural) Multi-Task Logistic Regression, MTLR [1] and N-MTLR [2].
9 | A discrete-time survival model that minimize the likelihood for right-censored data.
10 |
11 | This is essentially a PMF parametrization with an extra cumulative sum, as explained in [3].
12 |
13 | Arguments:
14 | net {torch.nn.Module} -- A torch module.
15 |
16 | Keyword Arguments:
17 | optimizer {Optimizer} -- A torch optimizer or similar. Preferably use torchtuples.optim instead of
18 | torch.optim, as this allows for reinitialization, etc. If 'None' set to torchtuples.optim.AdamW.
19 | (default: {None})
20 | device {str, int, torch.device} -- Device to compute on. (default: {None})
21 | Preferably pass a torch.device object.
22 | If 'None': use default gpu if available, else use cpu.
23 | If 'int': used that gpu: torch.device('cuda:').
24 | If 'string': string is passed to torch.device('string').
25 | duration_index {list, np.array} -- Array of durations that defines the discrete times.
26 | This is used to set the index of the DataFrame in `predict_surv_df`.
27 |
28 | References:
29 | [1] Chun-Nam Yu, Russell Greiner, Hsiu-Chin Lin, and Vickie Baracos.
30 | Learning patient- specific cancer survival distributions as a sequence of dependent regressors.
31 | In Advances in Neural Information Processing Systems 24, pages 1845–1853.
32 | Curran Associates, Inc., 2011.
33 | https://papers.nips.cc/paper/4210-learning-patient-specific-cancer-survival-distributions-as-a-sequence-of-dependent-regressors.pdf
34 |
35 | [2] Stephane Fotso. Deep neural networks for survival analysis based on a multi-task framework.
36 | arXiv preprint arXiv:1801.05512, 2018.
37 | https://arxiv.org/pdf/1801.05512.pdf
38 |
39 | [3] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
40 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
41 | https://arxiv.org/pdf/1910.06724.pdf
42 | """
43 | def __init__(self, net, optimizer=None, device=None, duration_index=None, loss=None):
44 | if loss is None:
45 | loss = models.loss.NLLMTLRLoss()
46 | super().__init__(net, loss, optimizer, device, duration_index)
47 |
48 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
49 | preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
50 | preds = utils.cumsum_reverse(preds, dim=1)
51 | pmf = utils.pad_col(preds).softmax(1)[:, :-1]
52 | return tt.utils.array_or_tensor(pmf, numpy, input)
53 |
--------------------------------------------------------------------------------
/pycox/models/pc_hazard.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import pandas as pd
3 | import torch
4 | import torch.nn.functional as F
5 | import torchtuples as tt
6 | from pycox import models
7 | from pycox.models.utils import pad_col, make_subgrid
8 | from pycox.preprocessing import label_transforms
9 |
10 | class PCHazard(models.base.SurvBase):
11 | """The PC-Hazard (piecewise constant hazard) method from [1].
12 | The Piecewise Constant Hazard (PC-Hazard) model from [1] which assumes that the continuous-time
13 | hazard function is constant in a set of predefined intervals. It is similar to the Piecewise
14 | Exponential Models [2] but with a softplus activation instead of the exponential function.
15 |
16 | Note that the label_transform is slightly different than that of the LogistcHazard and PMF methods.
17 | This typically results in one less output node.
18 |
19 | References:
20 | [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
21 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
22 | https://arxiv.org/pdf/1910.06724.pdf
23 |
24 | [2] Michael Friedman. Piecewise exponential models for survival data with covariates.
25 | The Annals of Statistics, 10(1):101–113, 1982.
26 | https://projecteuclid.org/euclid.aos/1176345693
27 | """
28 | label_transform = label_transforms.LabTransPCHazard
29 |
30 | def __init__(self, net, optimizer=None, device=None, duration_index=None, sub=1, loss=None):
31 | self.duration_index = duration_index
32 | self.sub = sub
33 | if loss is None:
34 | loss = models.loss.NLLPCHazardLoss()
35 | super().__init__(net, loss, optimizer, device)
36 | if self.duration_index is not None:
37 | self._check_out_features()
38 |
39 | @property
40 | def sub(self):
41 | return self._sub
42 |
43 | @sub.setter
44 | def sub(self, sub):
45 | if type(sub) is not int:
46 | raise ValueError(f"Need `sub` to have type `int`, got {type(sub)}")
47 | self._sub = sub
48 |
49 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
50 | hazard = self.predict_hazard(input, batch_size, False, eval_, to_cpu, num_workers)
51 | surv = hazard.cumsum(1).mul(-1).exp()
52 | return tt.utils.array_or_tensor(surv, numpy, input)
53 |
54 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0):
55 | """Predict the hazard function for `input`.
56 |
57 | Arguments:
58 | input {tuple, np.ndarra, or torch.tensor} -- Input to net.
59 |
60 | Keyword Arguments:
61 | batch_size {int} -- Batch size (default: {8224})
62 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input
63 | (default: {None})
64 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True})
65 | to_cpu {bool} -- For larger data sets we need to move the results to cpu
66 | (default: {False})
67 | num_workers {int} -- Number of workers in created dataloader (default: {0})
68 |
69 | Returns:
70 | [np.ndarray or tensor] -- Predicted hazards
71 | """
72 | preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
73 | n = preds.shape[0]
74 | hazard = F.softplus(preds).view(-1, 1).repeat(1, self.sub).view(n, -1).div(self.sub)
75 | hazard = pad_col(hazard, where='start')
76 | return tt.utils.array_or_tensor(hazard, numpy, input)
77 |
78 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):
79 | self._check_out_features()
80 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers)
81 | index = None
82 | if self.duration_index is not None:
83 | index = make_subgrid(self.duration_index, self.sub)
84 | return pd.DataFrame(surv.transpose(), index)
85 |
86 | def fit(self, input, target, batch_size=256, epochs=1, callbacks=None, verbose=True,
87 | num_workers=0, shuffle=True, metrics=None, val_data=None, val_batch_size=8224,
88 | check_out_features=True, **kwargs):
89 | if check_out_features:
90 | self._check_out_features(target)
91 | return super().fit(input, target, batch_size, epochs, callbacks, verbose, num_workers,
92 | shuffle, metrics, val_data, val_batch_size, **kwargs)
93 |
94 | def fit_dataloader(self, dataloader, epochs=1, callbacks=None, verbose=True, metrics=None,
95 | val_dataloader=None, check_out_features=True):
96 | if check_out_features:
97 | self._check_out_features()
98 | return super().fit_dataloader(dataloader, epochs, callbacks, verbose, metrics, val_dataloader)
99 |
100 | def _check_out_features(self, target=None):
101 | last = list(self.net.modules())[-1]
102 | if hasattr(last, 'out_features'):
103 | m_output = last.out_features
104 | if self.duration_index is not None:
105 | n_grid = len(self.duration_index)
106 | if n_grid == m_output:
107 | raise ValueError("Output of `net` is one too large. Should have length "+
108 | f"{len(self.duration_index)-1}")
109 | if n_grid != (m_output + 1):
110 | raise ValueError(f"Output of `net` does not correspond with `duration_index`")
111 | if target is not None:
112 | max_idx = tt.tuplefy(target).to_numpy()[0].max()
113 | if m_output != (max_idx + 1):
114 | raise ValueError(f"Output of `net` is {m_output}, but data only trains {max_idx + 1} indices. "+
115 | f"Output of `net` should be {max_idx + 1}."+
116 | "Set `check_out_feature=False` to suppress this Error.")
117 |
--------------------------------------------------------------------------------
/pycox/models/pmf.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torchtuples as tt
3 | from pycox import models
4 | from pycox.models.utils import pad_col
5 | from pycox.preprocessing import label_transforms
6 | from pycox.models.interpolation import InterpolatePMF
7 |
8 |
9 | class PMFBase(models.base.SurvBase):
10 | """Base class for PMF methods.
11 | """
12 | label_transform = label_transforms.LabTransDiscreteTime
13 |
14 | def __init__(self, net, loss=None, optimizer=None, device=None, duration_index=None):
15 | self.duration_index = duration_index
16 | super().__init__(net, loss, optimizer, device)
17 |
18 | @property
19 | def duration_index(self):
20 | """
21 | Array of durations that defines the discrete times. This is used to set the index
22 | of the DataFrame in `predict_surv_df`.
23 |
24 | Returns:
25 | np.array -- Duration index.
26 | """
27 | return self._duration_index
28 |
29 | @duration_index.setter
30 | def duration_index(self, val):
31 | self._duration_index = val
32 |
33 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
34 | num_workers=0):
35 | pmf = self.predict_pmf(input, batch_size, False, eval_, to_cpu, num_workers)
36 | surv = 1 - pmf.cumsum(1)
37 | return tt.utils.array_or_tensor(surv, numpy, input)
38 |
39 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False,
40 | num_workers=0):
41 | preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers)
42 | pmf = pad_col(preds).softmax(1)[:, :-1]
43 | return tt.utils.array_or_tensor(pmf, numpy, input)
44 |
45 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0):
46 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers)
47 | return pd.DataFrame(surv.transpose(), self.duration_index)
48 |
49 | def interpolate(self, sub=10, scheme='const_pdf', duration_index=None):
50 | """Use interpolation for predictions.
51 | There are only one scheme:
52 | `const_pdf` and `lin_surv` which assumes pice-wise constant pmf in each interval (linear survival).
53 |
54 | Keyword Arguments:
55 | sub {int} -- Number of "sub" units in interpolation grid. If `sub` is 10 we have a grid with
56 | 10 times the number of grid points than the original `duration_index` (default: {10}).
57 | scheme {str} -- Type of interpolation {'const_hazard', 'const_pdf'}.
58 | See `InterpolateDiscrete` (default: {'const_pdf'})
59 | duration_index {np.array} -- Cuts used for discretization. Does not affect interpolation,
60 | only for setting index in `predict_surv_df` (default: {None})
61 |
62 | Returns:
63 | [InterpolationPMF] -- Object for prediction with interpolation.
64 | """
65 | if duration_index is None:
66 | duration_index = self.duration_index
67 | return InterpolatePMF(self, scheme, duration_index, sub)
68 |
69 |
70 | class PMF(PMFBase):
71 | """
72 | The PMF is a discrete-time survival model that parametrize the probability mass function (PMF)
73 | and optimizer the survival likelihood. It is the foundation of methods such as DeepHit and MTLR.
74 | See [1] for details.
75 |
76 | Arguments:
77 | net {torch.nn.Module} -- A torch module.
78 |
79 | Keyword Arguments:
80 | optimizer {Optimizer} -- A torch optimizer or similar. Preferably use torchtuples.optim instead of
81 | torch.optim, as this allows for reinitialization, etc. If 'None' set to torchtuples.optim.AdamW.
82 | (default: {None})
83 | device {str, int, torch.device} -- Device to compute on. (default: {None})
84 | Preferably pass a torch.device object.
85 | If 'None': use default gpu if available, else use cpu.
86 | If 'int': used that gpu: torch.device('cuda:').
87 | If 'string': string is passed to torch.device('string').
88 | duration_index {list, np.array} -- Array of durations that defines the discrete times.
89 | This is used to set the index of the DataFrame in `predict_surv_df`.
90 |
91 | References:
92 | [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
93 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
94 | https://arxiv.org/pdf/1910.06724.pdf
95 | """
96 | def __init__(self, net, optimizer=None, device=None, duration_index=None, loss=None):
97 | if loss is None:
98 | loss = models.loss.NLLPMFLoss()
99 | super().__init__(net, loss, optimizer, device, duration_index)
100 |
101 |
--------------------------------------------------------------------------------
/pycox/models/utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 | import torchtuples as tt
6 |
7 | def pad_col(input, val=0, where='end'):
8 | """Addes a column of `val` at the start of end of `input`."""
9 | if len(input.shape) != 2:
10 | raise ValueError(f"Only works for `phi` tensor that is 2-D.")
11 | pad = torch.zeros_like(input[:, :1])
12 | if val != 0:
13 | pad = pad + val
14 | if where == 'end':
15 | return torch.cat([input, pad], dim=1)
16 | elif where == 'start':
17 | return torch.cat([pad, input], dim=1)
18 | raise ValueError(f"Need `where` to be 'start' or 'end', got {where}")
19 |
20 | def array_or_tensor(tensor, numpy, input):
21 | warnings.warn('Use `torchtuples.utils.array_or_tensor` instead', DeprecationWarning)
22 | return tt.utils.array_or_tensor(tensor, numpy, input)
23 |
24 | def make_subgrid(grid, sub=1):
25 | """When calling `predict_surv` with sub != 1 this can help with
26 | creating the duration index of the survival estimates.
27 |
28 | E.g.
29 | sub = 5
30 | surv = model.predict_surv(test_input, sub=sub)
31 | grid = model.make_subgrid(cuts, sub)
32 | surv = pd.DataFrame(surv, index=grid)
33 | """
34 | subgrid = tt.TupleTree(np.linspace(start, end, num=sub+1)[:-1]
35 | for start, end in zip(grid[:-1], grid[1:]))
36 | subgrid = subgrid.apply(lambda x: tt.TupleTree(x)).flatten() + (grid[-1],)
37 | return subgrid
38 |
39 | def log_softplus(input, threshold=-15.):
40 | """Equivalent to 'F.softplus(input).log()', but for 'input < threshold',
41 | we return 'input', as this is approximately the same.
42 |
43 | Arguments:
44 | input {torch.tensor} -- Input tensor
45 |
46 | Keyword Arguments:
47 | threshold {float} -- Treshold for when to just return input (default: {-15.})
48 |
49 | Returns:
50 | torch.tensor -- return log(softplus(input)).
51 | """
52 | output = input.clone()
53 | above = input >= threshold
54 | output[above] = F.softplus(input[above]).log()
55 | return output
56 |
57 | def cumsum_reverse(input: torch.Tensor, dim: int = 1) -> torch.Tensor:
58 | if dim != 1:
59 | raise NotImplementedError
60 | input = input.sum(1, keepdim=True) - pad_col(input, where='start').cumsum(1)
61 | return input[:, :-1]
62 |
--------------------------------------------------------------------------------
/pycox/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from pycox.preprocessing import feature_transforms, label_transforms, discretization
2 |
--------------------------------------------------------------------------------
/pycox/preprocessing/discretization.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numpy as np
3 | import pandas as pd
4 | from pycox import utils
5 |
6 |
7 | def make_cuts(n_cuts, scheme, durations, events, min_=0., dtype='float64'):
8 | if scheme == 'equidistant':
9 | cuts = cuts_equidistant(durations.max(), n_cuts, min_, dtype)
10 | elif scheme == 'quantiles':
11 | cuts = cuts_quantiles(durations, events, n_cuts, min_, dtype)
12 | else:
13 | raise ValueError(f"Got invalid `scheme` {scheme}.")
14 | if (np.diff(cuts) == 0).any():
15 | raise ValueError("cuts are not unique.")
16 | return cuts
17 |
18 | def _values_if_series(x):
19 | if type(x) is pd.Series:
20 | return x.values
21 | return x
22 |
23 | def cuts_equidistant(max_, num, min_=0., dtype='float64'):
24 | return np.linspace(min_, max_, num, dtype=dtype)
25 |
26 | def cuts_quantiles(durations, events, num, min_=0., dtype='float64'):
27 | """
28 | If min_ = None, we will use durations.min() for the first cut.
29 | """
30 | km = utils.kaplan_meier(durations, events)
31 | surv_est, surv_durations = km.values, km.index.values
32 | s_cuts = np.linspace(km.values.min(), km.values.max(), num)
33 | cuts_idx = np.searchsorted(surv_est[::-1], s_cuts)[::-1]
34 | cuts = surv_durations[::-1][cuts_idx]
35 | cuts = np.unique(cuts)
36 | if len(cuts) != num:
37 | warnings.warn(f"cuts are not unique, continue with {len(cuts)} cuts instead of {num}")
38 | cuts[0] = durations.min() if min_ is None else min_
39 | assert cuts[-1] == durations.max(), 'something wrong...'
40 | return cuts.astype(dtype)
41 |
42 | def _is_monotonic_increasing(x):
43 | assert len(x.shape) == 1, 'Only works for 1d'
44 | return (x[1:] >= x[:-1]).all()
45 |
46 | def bin_numerical(x, right_cuts, error_on_larger=False):
47 | """
48 | Discretize x into bins defined by right_cuts (needs to be sorted).
49 | If right_cuts = [1, 2], we have bins (-inf, 1], (1, 2], (2, inf).
50 | error_on_larger results in a ValueError if x contains larger
51 | values than right_cuts.
52 |
53 | Returns index of bins.
54 | To optaine values do righ_cuts[bin_numerica(x, right_cuts)].
55 | """
56 | assert _is_monotonic_increasing(right_cuts), 'Need `right_cuts` to be sorted.'
57 | bins = np.searchsorted(right_cuts, x, side='left')
58 | if bins.max() == right_cuts.size:
59 | if error_on_larger:
60 | raise ValueError('x contrains larger values than right_cuts.')
61 | return bins
62 |
63 | def discretize(x, cuts, side='right', error_on_larger=False):
64 | """Discretize x to cuts.
65 |
66 | Arguments:
67 | x {np.array} -- Array of times.
68 | cuts {np.array} -- Sortet array of discrete times.
69 |
70 | Keyword Arguments:
71 | side {str} -- If we shold round down or up (left, right) (default: {'right'})
72 | error_on_larger {bool} -- If we shold return an error if we pass higher values
73 | than cuts (default: {False})
74 |
75 | Returns:
76 | np.array -- Discretized values.
77 | """
78 | if side not in ['right', 'left']:
79 | raise ValueError('side argument needs to be right or left.')
80 | bins = bin_numerical(x, cuts, error_on_larger)
81 | if side == 'right':
82 | cuts = np.concatenate((cuts, np.array([np.inf])))
83 | return cuts[bins]
84 | bins_cut = bins.copy()
85 | bins_cut[bins_cut == cuts.size] = -1
86 | exact = cuts[bins_cut] == x
87 | left_bins = bins - 1 + exact
88 | vals = cuts[left_bins]
89 | vals[left_bins == -1] = - np.inf
90 | return vals
91 |
92 |
93 | class _OnlyTransform:
94 | """Abstract class for sklearn preprocessing methods.
95 | Only implements fit and fit_transform.
96 | """
97 | def fit(self, *args):
98 | return self
99 |
100 | def transform(self, *args):
101 | raise NotImplementedError
102 |
103 | def fit_transform(self, *args):
104 | return self.fit(*args).transform(*args)
105 |
106 |
107 | class DiscretizeUnknownC(_OnlyTransform):
108 | """Implementation of scheme 2.
109 |
110 | cuts should be [t0, t1, ..., t_m], where t_m is right sensored value.
111 | """
112 | def __init__(self, cuts, right_censor=False, censor_side='left'):
113 | self.cuts = cuts
114 | self.right_censor = right_censor
115 | self.censor_side = censor_side
116 |
117 | def transform(self, duration, event):
118 | dtype_event = event.dtype
119 | event = event.astype('bool')
120 | if self.right_censor:
121 | duration = duration.copy()
122 | censor = duration > self.cuts.max()
123 | duration[censor] = self.cuts.max()
124 | event[censor] = False
125 | if duration.max() > self.cuts.max():
126 | raise ValueError("`duration` contains larger values than cuts. Set `right_censor`=True to censor these")
127 | td = np.zeros_like(duration)
128 | c = event == False
129 | td[event] = discretize(duration[event], self.cuts, side='right', error_on_larger=True)
130 | if c.any():
131 | td[c] = discretize(duration[c], self.cuts, side=self.censor_side, error_on_larger=True)
132 | return td, event.astype(dtype_event)
133 |
134 |
135 | def duration_idx_map(duration):
136 | duration = np.unique(duration)
137 | duration = np.sort(duration)
138 | idx = np.arange(duration.shape[0])
139 | return {d: i for i, d in zip(idx, duration)}
140 |
141 |
142 | class Duration2Idx(_OnlyTransform):
143 | def __init__(self, durations=None):
144 | self.durations = durations
145 | if durations is None:
146 | raise NotImplementedError()
147 | if self.durations is not None:
148 | self.duration_to_idx = self._make_map(self.durations)
149 |
150 | @staticmethod
151 | def _make_map(durations):
152 | return np.vectorize(duration_idx_map(durations).get)
153 |
154 | def transform(self, duration, y=None):
155 | if duration.dtype is not self.durations.dtype:
156 | raise ValueError('Need `time` to have same type as `self.durations`.')
157 | idx = self.duration_to_idx(duration)
158 | if np.isnan(idx).any():
159 | raise ValueError('Encountered `nans` in transformed indexes.')
160 | return idx
161 |
162 |
163 | class IdxDiscUnknownC:
164 | """Get indexed for discrete data using cuts.
165 |
166 | Arguments:
167 | cuts {np.array} -- Array or right cuts.
168 |
169 | Keyword Arguments:
170 | label_cols {tuple} -- Name of label columns in dataframe (default: {None}).
171 | """
172 | def __init__(self, cuts, label_cols=None, censor_side='left'):
173 | self.cuts = cuts
174 | self.duc = DiscretizeUnknownC(cuts, right_censor=True, censor_side=censor_side)
175 | self.di = Duration2Idx(cuts)
176 | self.label_cols= label_cols
177 |
178 | def transform(self, time, d):
179 | time, d = self.duc.transform(time, d)
180 | idx = self.di.transform(time)
181 | return idx, d
182 |
183 | def transform_df(self, df):
184 | if self.label_cols is None:
185 | raise RuntimeError("Need to set 'label_cols' to use this. Use 'transform instead'")
186 | col_duration, col_event = self.label_cols
187 | time = df[col_duration].values
188 | d = df[col_event].values
189 | return self.transform(time, d)
190 |
--------------------------------------------------------------------------------
/pycox/preprocessing/feature_transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | class OrderedCategoricalLong:
5 | """Transform pandas series or numpy array to categorical, and get (long) values,
6 | i.e. index of category. Useful for entity embeddings.
7 | Zero is reserved for unknown categories or nans.
8 |
9 | Keyword Arguments:
10 | min_per_category {int} -- Number of instances required to not be set to nan (default: {20})
11 | return_series {bool} -- If return a array or pd.Series (default: {False})
12 |
13 | Returns:
14 | [pd.Series] -- Series with long values reffering to categories.
15 | """
16 | def __init__(self, min_per_category=20, return_series=False):
17 |
18 | self.min_per_category = min_per_category
19 | self.return_series = return_series
20 |
21 | def fit(self, series, y=None):
22 | series = pd.Series(series).copy()
23 | smaller = series.value_counts() < self.min_per_category
24 | values = smaller[smaller].index.values
25 | for v in values:
26 | series[series == v] = np.nan
27 | self.categories = series.astype('category').cat.categories
28 | return self
29 |
30 | def transform(self, series, y=None):
31 | series = pd.Series(series).copy()
32 | transformed = pd.Categorical(series, categories=self.categories, ordered=True)
33 | transformed = pd.Series(transformed, index=series.index)
34 | transformed = transformed.cat.codes.astype('int64') + 1
35 | return transformed if self.return_series else transformed.values
36 |
37 | def fit_transform(self, series, y=None):
38 | return self.fit(series, y).transform(series, y)
39 |
--------------------------------------------------------------------------------
/pycox/preprocessing/label_transforms.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import numpy as np
3 | from sklearn.preprocessing import StandardScaler
4 | from pycox.preprocessing.discretization import (make_cuts, IdxDiscUnknownC, _values_if_series,
5 | DiscretizeUnknownC, Duration2Idx)
6 |
7 |
8 | class LabTransCoxTime:
9 | """
10 | Label transforms useful for CoxTime models. It can log-transform and standardize the durations.
11 |
12 | It also creates `map_scaled_to_orig` which is the inverse transform of the durations data,
13 | enabling us to set the correct time scale for predictions.
14 | This can be done by passing the object to the CoxTime init:
15 | model = CoxTime(net, labrans=labtrans)
16 | which gives the correct time scale of survival predictions
17 | surv = model.predict_surv_df(x)
18 |
19 | Keyword Arguments:
20 | log_duration {bool} -- Log-transform durations, i.e. 'log(1+x)'. (default: {False})
21 | with_mean {bool} -- Center the duration before scaling.
22 | Passed to `sklearn.preprocessing.StandardScaler` (default: {True})
23 | with_std {bool} -- Scale duration to unit variance.
24 | Passed to `sklearn.preprocessing.StandardScaler` (default: {True})
25 | """
26 | def __init__(self, log_duration=False, with_mean=True, with_std=True):
27 | self.log_duration = log_duration
28 | self.duration_scaler = StandardScaler(copy=True, with_mean=with_mean, with_std=with_std)
29 |
30 | @property
31 | def map_scaled_to_orig(self):
32 | """Map from transformed durations back to the original durations, i.e. inverse transform.
33 |
34 | Use it to e.g. set index of survival predictions:
35 | surv = model.predict_surv_df(x_test)
36 | surv.index = labtrans.map_scaled_to_orig(surv.index)
37 | """
38 | if not hasattr(self, '_inverse_duration_map'):
39 | raise ValueError('Need to fit the models before you can call this method')
40 | return self._inverse_duration_map
41 |
42 | def fit(self, durations, events):
43 | self.fit_transform(durations, events)
44 | return self
45 |
46 | def fit_transform(self, durations, events):
47 | train_durations = durations
48 | durations = durations.astype('float32')
49 | events = events.astype('float32')
50 | if self.log_duration:
51 | durations = np.log1p(durations)
52 | durations = self.duration_scaler.fit_transform(durations.reshape(-1, 1)).flatten()
53 | self._inverse_duration_map = {scaled: orig for orig, scaled in zip(train_durations, durations)}
54 | self._inverse_duration_map = np.vectorize(self._inverse_duration_map.get)
55 | return durations, events
56 |
57 | def transform(self, durations, events):
58 | durations = durations.astype('float32')
59 | events = events.astype('float32')
60 | if self.log_duration:
61 | durations = np.log1p(durations)
62 | durations = self.duration_scaler.transform(durations.reshape(-1, 1)).flatten()
63 | return durations, events
64 |
65 | @property
66 | def out_features(self):
67 | """Returns the number of output features that should be used in the torch model.
68 | This always returns 1, and is just included for api design purposes.
69 |
70 | Returns:
71 | [int] -- Number of output features.
72 | """
73 | return 1
74 |
75 |
76 | class LabTransDiscreteTime:
77 | """
78 | Discretize continuous (duration, event) pairs based on a set of cut points.
79 | One can either determine the cut points in form of passing an array to this class,
80 | or one can obtain cut points based on the training data.
81 |
82 | The discretization learned from fitting to data will move censorings to the left cut point,
83 | and events to right cut point.
84 |
85 | Arguments:
86 | cuts {int, array} -- Defining cut points, either the number of cuts, or the actual cut points.
87 |
88 | Keyword Arguments:
89 | scheme {str} -- Scheme used for discretization. Either 'equidistant' or 'quantiles'
90 | (default: {'equidistant})
91 | min_ {float} -- Starting duration (default: {0.})
92 | dtype {str, dtype} -- dtype of discretization.
93 | """
94 | def __init__(self, cuts, scheme='equidistant', min_=0., dtype=None):
95 | self._cuts = cuts
96 | self._scheme = scheme
97 | self._min = min_
98 | self._dtype_init = dtype
99 | self._predefined_cuts = False
100 | self.cuts = None
101 | if hasattr(cuts, '__iter__'):
102 | if type(cuts) is list:
103 | cuts = np.array(cuts)
104 | self.cuts = cuts
105 | self.idu = IdxDiscUnknownC(self.cuts)
106 | assert dtype is None, "Need `dtype` to be `None` for specified cuts"
107 | self._dtype = type(self.cuts[0])
108 | self._dtype_init = self._dtype
109 | self._predefined_cuts = True
110 |
111 | def fit(self, durations, events):
112 | if self._predefined_cuts:
113 | warnings.warn("Calling fit method, when 'cuts' are already defined. Leaving cuts unchanged.")
114 | return self
115 | self._dtype = self._dtype_init
116 | if self._dtype is None:
117 | if isinstance(durations[0], np.floating):
118 | self._dtype = durations.dtype
119 | else:
120 | self._dtype = np.dtype('float64')
121 | durations = durations.astype(self._dtype)
122 | self.cuts = make_cuts(self._cuts, self._scheme, durations, events, self._min, self._dtype)
123 | self.idu = IdxDiscUnknownC(self.cuts)
124 | return self
125 |
126 | def fit_transform(self, durations, events):
127 | self.fit(durations, events)
128 | idx_durations, events = self.transform(durations, events)
129 | return idx_durations, events
130 |
131 | def transform(self, durations, events):
132 | durations = _values_if_series(durations)
133 | durations = durations.astype(self._dtype)
134 | events = _values_if_series(events)
135 | idx_durations, events = self.idu.transform(durations, events)
136 | return idx_durations.astype('int64'), events.astype('float32')
137 |
138 | @property
139 | def out_features(self):
140 | """Returns the number of output features that should be used in the torch model.
141 |
142 | Returns:
143 | [int] -- Number of output features.
144 | """
145 | if self.cuts is None:
146 | raise ValueError("Need to call `fit` before this is accessible.")
147 | return len(self.cuts)
148 |
149 |
150 | class LabTransPCHazard:
151 | """
152 | Defining time intervals (`cuts`) needed for the `PCHazard` method [1].
153 | One can either determine the cut points in form of passing an array to this class,
154 | or one can obtain cut points based on the training data.
155 |
156 | Arguments:
157 | cuts {int, array} -- Defining cut points, either the number of cuts, or the actual cut points.
158 |
159 | Keyword Arguments:
160 | scheme {str} -- Scheme used for discretization. Either 'equidistant' or 'quantiles'
161 | (default: {'equidistant})
162 | min_ {float} -- Starting duration (default: {0.})
163 | dtype {str, dtype} -- dtype of discretization.
164 |
165 | References:
166 | [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction
167 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019.
168 | https://arxiv.org/pdf/1910.06724.pdf
169 | """
170 | def __init__(self, cuts, scheme='equidistant', min_=0., dtype=None):
171 | self._cuts = cuts
172 | self._scheme = scheme
173 | self._min = min_
174 | self._dtype_init = dtype
175 | self._predefined_cuts = False
176 | self.cuts = None
177 | if hasattr(cuts, '__iter__'):
178 | if type(cuts) is list:
179 | cuts = np.array(cuts)
180 | self.cuts = cuts
181 | self.idu = IdxDiscUnknownC(self.cuts)
182 | assert dtype is None, "Need `dtype` to be `None` for specified cuts"
183 | self._dtype = type(self.cuts[0])
184 | self._dtype_init = self._dtype
185 | self._predefined_cuts = True
186 | else:
187 | self._cuts += 1
188 |
189 | def fit(self, durations, events):
190 | if self._predefined_cuts:
191 | warnings.warn("Calling fit method, when 'cuts' are already defined. Leaving cuts unchanged.")
192 | return self
193 | self._dtype = self._dtype_init
194 | if self._dtype is None:
195 | if isinstance(durations[0], np.floating):
196 | self._dtype = durations.dtype
197 | else:
198 | self._dtype = np.dtype('float64')
199 | durations = durations.astype(self._dtype)
200 | self.cuts = make_cuts(self._cuts, self._scheme, durations, events, self._min, self._dtype)
201 | self.duc = DiscretizeUnknownC(self.cuts, right_censor=True, censor_side='right')
202 | self.di = Duration2Idx(self.cuts)
203 | return self
204 |
205 | def fit_transform(self, durations, events):
206 | self.fit(durations, events)
207 | return self.transform(durations, events)
208 |
209 | def transform(self, durations, events):
210 | durations = _values_if_series(durations)
211 | durations = durations.astype(self._dtype)
212 | events = _values_if_series(events)
213 | dur_disc, events = self.duc.transform(durations, events)
214 | idx_durations = self.di.transform(dur_disc)
215 | cut_diff = np.diff(self.cuts)
216 | assert (cut_diff > 0).all(), 'Cuts are not unique.'
217 | t_frac = 1. - (dur_disc - durations) / cut_diff[idx_durations-1]
218 | if idx_durations.min() == 0:
219 | warnings.warn("""Got event/censoring at start time. Should be removed! It is set s.t. it has no contribution to loss.""")
220 | t_frac[idx_durations == 0] = 0
221 | events[idx_durations == 0] = 0
222 | idx_durations = idx_durations - 1
223 | return idx_durations.astype('int64'), events.astype('float32'), t_frac.astype('float32')
224 |
225 | @property
226 | def out_features(self):
227 | """Returns the number of output features that should be used in the torch model.
228 |
229 | Returns:
230 | [int] -- Number of output features.
231 | """
232 | if self.cuts is None:
233 | raise ValueError("Need to call `fit` before this is accessible.")
234 | return len(self.cuts) - 1
235 |
--------------------------------------------------------------------------------
/pycox/simulations/__init__.py:
--------------------------------------------------------------------------------
1 | from pycox.simulations import relative_risk, discrete_logit_hazard
2 | from pycox.simulations.relative_risk import (SimStudyLinearPH, SimStudyNonLinearPH,
3 | SimStudyNonLinearNonPH)
4 | from pycox.simulations.discrete_logit_hazard import (SimStudySACCensorConst, SimStudySACAdmin,
5 | SimStudySingleSurvUniformAdmin)
6 |
--------------------------------------------------------------------------------
/pycox/simulations/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | def dict2df(data, add_true=True, add_censor_covs=False):
5 | """Make a pd.DataFrame from the dict obtained when simulating.
6 |
7 | Arguments:
8 | data {dict} -- Dict from simulation.
9 |
10 | Keyword Arguments:
11 | add_true {bool} -- If we should include the true duration and censoring times
12 | (default: {True})
13 | add_censor_covs {bool} -- If we should include the censor covariates as covariates.
14 | (default: {False})
15 |
16 | Returns:
17 | pd.DataFrame -- A DataFrame
18 | """
19 | covs = data['covs']
20 | if add_censor_covs:
21 | covs = np.concatenate([covs, data['censor_covs']], axis=1)
22 | df = (pd.DataFrame(covs, columns=[f"x{i}" for i in range(covs.shape[1])])
23 | .assign(duration=data['durations'].astype('float32'),
24 | event=data['events'].astype('float32')))
25 | if add_true:
26 | df = df.assign(duration_true=data['durations_true'].astype('float32'),
27 | event_true=data['events_true'].astype('float32'),
28 | censoring_true=data['censor_durations'].astype('float32'))
29 | return df
30 |
31 |
32 | class _SimBase:
33 | def simulate(self, n, surv_df=False):
34 | """Simulate dataset of size `n`.
35 |
36 | Arguments:
37 | n {int} -- Number of simulations
38 |
39 | Keyword Arguments:
40 | surv_df {bool} -- If a dataframe containing the survival function should be returned.
41 | (default: {False})
42 |
43 | Returns:
44 | [dict] -- A dictionary with the results.
45 | """
46 | raise NotImplementedError
47 |
48 | def surv_df(self, *args):
49 | """Returns a data frame containing the survival function.
50 | """
51 | raise NotImplementedError
52 |
--------------------------------------------------------------------------------
/pycox/simulations/relative_risk.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 | from pycox.simulations import base
5 |
6 |
7 | class _SimStudyRelativeRisk(base._SimBase):
8 | '''Abstract class for simulation relative risk survival data,
9 | with constant baseline, and constant censoring distribution
10 |
11 | Parameters:
12 | h0: Is baseline constant.
13 | right_c: Time for right censoring.
14 | c0: Constant censoring distribution
15 | '''
16 | def __init__(self, h0, right_c=30., c0=30., surv_grid=None):
17 | self.h0 = h0
18 | self.right_c = right_c
19 | self.c0 = c0
20 | self.surv_grid = surv_grid
21 |
22 | def simulate(self, n, surv_df=False):
23 | covs = self.sample_covs(n).astype('float32')
24 | v = np.random.exponential(size=n)
25 | t = self.inv_cum_hazard(v, covs)
26 | c = self.c0 * np.random.exponential(size=n)
27 | tt = t.copy()
28 | tt[c < t] = c[c < t]
29 | tt[tt > self.right_c] = self.right_c
30 | d = tt == t
31 | surv_df = self.surv_df(covs, self.surv_grid) if surv_df else None
32 | # censor_surv_df = NotImplemented if censor_df else None
33 | return dict(covs=covs, durations=tt, events=d, surv_df=surv_df, durations_true=t,
34 | events_true=np.ones_like(t), censor_durations=c,
35 | censor_events=np.ones_like(c))
36 |
37 | @staticmethod
38 | def sample_covs(n):
39 | raise NotImplementedError
40 |
41 | def inv_cum_hazard(self, v, covs):
42 | '''The inverse of the cumulative hazard.'''
43 | raise NotImplementedError
44 |
45 | def cum_hazard(self, t, covs):
46 | '''The the cumulative hazard function.'''
47 | raise NotImplementedError
48 |
49 | def survival_func(self, t, covs):
50 | '''Returns the survival function.'''
51 | return np.exp(-self.cum_hazard(t, covs))
52 |
53 | def survival_grid_single(self, covs, t=None):
54 | covs = covs.reshape(1, -1)
55 | if t is None:
56 | t = np.arange(0, 31, 0.5)
57 | return pd.Series(self.survival_func(t, covs), index=t)
58 |
59 | def surv_df(self, covs, t=None):
60 | if t is None:
61 | t = np.linspace(0, 30, 100)
62 | s = [self.survival_grid_single(xx, t) for xx in covs]
63 | return pd.concat(s, axis=1)
64 |
65 | @staticmethod
66 | def dict2df(data, add_true=True):
67 | """Make a pd.DataFrame from the dict obtained when simulating.
68 |
69 | Arguments:
70 | data {dict} -- Dict from simulation.
71 |
72 | Keyword Arguments:
73 | add_true {bool} -- If we should include the true duration and censoring times
74 | (default: {True})
75 |
76 | Returns:
77 | pd.DataFrame -- A DataFrame
78 | """
79 | return base.dict2df(data, add_true)
80 |
81 |
82 | class SimStudyLinearPH(_SimStudyRelativeRisk):
83 | '''Survival simulations study for linear prop. hazard model
84 | h(t | x) = h0 exp[g(x)], where g(x) is linear.
85 |
86 | Parameters:
87 | h0: Is baseline constant.
88 | right_c: Time for right censoring.
89 | '''
90 | def __init__(self, h0=0.1, right_c=30., c0=30., surv_grid=None):
91 | super().__init__(h0, right_c, c0, surv_grid)
92 |
93 | @staticmethod
94 | def sample_covs(n):
95 | return np.random.uniform(-1, 1, size=(n, 3))
96 |
97 | @staticmethod
98 | def g(covs):
99 | x = covs
100 | x0, x1, x2 = x[:, 0], x[:, 1], x[:, 2]
101 | return 0.44 * x0 + 0.66 * x1 + 0.88 * x2
102 |
103 | def inv_cum_hazard(self, v, covs):
104 | '''The inverse of the cumulative hazard.'''
105 | return v / (self.h0 * np.exp(self.g(covs)))
106 |
107 | def cum_hazard(self, t, covs):
108 | '''The the cumulative hazard function.'''
109 | return self.h0 * t * np.exp(self.g(covs))
110 |
111 |
112 | class SimStudyNonLinearPH(SimStudyLinearPH):
113 | '''Survival simulations study for non-linear prop. hazard model
114 | h(t | x) = h0 exp[g(x)], where g(x) is non-linear.
115 |
116 | Parameters:
117 | h0: Is baseline constant.
118 | right_c: Time for right censoring.
119 | '''
120 | @staticmethod
121 | def g(covs):
122 | x = covs
123 | x0, x1, x2 = x[:, 0], x[:, 1], x[:, 2]
124 | beta = 2/3
125 | linear = SimStudyLinearPH.g(x)
126 | nonlinear = beta * (x0**2 + x2**2 + x0*x1 + x1*x2 + x1*x2)
127 | return linear + nonlinear
128 |
129 |
130 | class SimStudyNonLinearNonPH(SimStudyNonLinearPH):
131 | '''Survival simulations study for non-linear non-prop. hazard model.
132 | h(t | x) = h0 * exp[g(t, x)],
133 | with constant h_0, and g(t, x) = a(x) + b(x)*t.
134 |
135 | Cumulative hazard:
136 | H(t | x) = h0 / b(x) * exp[a(x)] * (exp[b(x) * t] - 1)
137 | Inverse:
138 | H^{-1}(v, x) = 1/b(x) log{1 + v * b(x) / h0 exp[-a(x)]}
139 |
140 | Parameters:
141 | h0: Is baseline constant.
142 | right_c: Time for right censoring.
143 | '''
144 | def __init__(self, h0=0.02, right_c=30., c0=30., surv_grid=None):
145 | super().__init__(h0, right_c, c0, surv_grid)
146 |
147 | @staticmethod
148 | def a(x):
149 | _, _, x2 = x[:, 0], x[:, 1], x[:, 2]
150 | return np.sign(x2) + SimStudyNonLinearPH.g(x)
151 |
152 | @staticmethod
153 | def b(x):
154 | x0, x1, _ = x[:, 0], x[:, 1], x[:, 2]
155 | return np.abs(0.2 * (x0 + x1) + 0.5 * x0 * x1)
156 |
157 | @staticmethod
158 | def g(t, covs):
159 | x = covs
160 | return SimStudyNonLinearNonPH.a(x) + SimStudyNonLinearNonPH.b(x) * t
161 |
162 | def inv_cum_hazard(self, v, covs):
163 | x = covs
164 | return 1 / self.b(x) * np.log(1 + v * self.b(x) / self.h0 * np.exp(-self.a(x)))
165 |
166 | def cum_hazard(self, t, covs):
167 | x = covs
168 | return self.h0 / self.b(x) * np.exp(self.a(x)) * (np.exp(self.b(x)*t) - 1)
169 |
--------------------------------------------------------------------------------
/pycox/utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import pandas as pd
3 | import numpy as np
4 | import numba
5 |
6 | def idx_at_times(index_surv, times, steps='pre', assert_sorted=True):
7 | """Gives index of `index_surv` corresponding to `time`, i.e.
8 | `index_surv[idx_at_times(index_surv, times)]` give the values of `index_surv`
9 | closet to `times`.
10 |
11 | Arguments:
12 | index_surv {np.array} -- Durations of survival estimates
13 | times {np.array} -- Values one want to match to `index_surv`
14 |
15 | Keyword Arguments:
16 | steps {str} -- Round 'pre' (closest value higher) or 'post'
17 | (closest value lower) (default: {'pre'})
18 | assert_sorted {bool} -- Assert that index_surv is monotone (default: {True})
19 |
20 | Returns:
21 | np.array -- Index of `index_surv` that is closest to `times`
22 | """
23 | if assert_sorted:
24 | assert pd.Series(index_surv).is_monotonic_increasing, "Need 'index_surv' to be monotonic increasing"
25 | if steps == 'pre':
26 | idx = np.searchsorted(index_surv, times)
27 | elif steps == 'post':
28 | idx = np.searchsorted(index_surv, times, side='right') - 1
29 | return idx.clip(0, len(index_surv)-1)
30 |
31 | @numba.njit
32 | def _group_loop(n, surv_idx, durations, events, di, ni):
33 | idx = 0
34 | for i in range(n):
35 | idx += durations[i] != surv_idx[idx]
36 | di[idx] += events[i]
37 | ni[idx] += 1
38 | return di, ni
39 |
40 | def kaplan_meier(durations, events, start_duration=0):
41 | """A very simple Kaplan-Meier fitter. For a more complete implementation
42 | see `lifelines`.
43 |
44 | Arguments:
45 | durations {np.array} -- durations array
46 | events {np.arrray} -- events array 0/1
47 |
48 | Keyword Arguments:
49 | start_duration {int} -- Time start as `start_duration`. (default: {0})
50 |
51 | Returns:
52 | pd.Series -- Kaplan-Meier estimates.
53 | """
54 | n = len(durations)
55 | assert n == len(events)
56 | if start_duration > durations.min():
57 | warnings.warn(f"start_duration {start_duration} is larger than minimum duration {durations.min()}. "
58 | "If intentional, consider changing start_duration when calling kaplan_meier.")
59 | order = np.argsort(durations)
60 | durations = durations[order]
61 | events = events[order]
62 | surv_idx = np.unique(durations)
63 | ni = np.zeros(len(surv_idx), dtype='int')
64 | di = np.zeros_like(ni)
65 | di, ni = _group_loop(n, surv_idx, durations, events, di, ni)
66 | ni = n - ni.cumsum()
67 | ni[1:] = ni[:-1]
68 | ni[0] = n
69 | survive = 1 - di / ni
70 | zero_survive = survive == 0
71 | if zero_survive.any():
72 | i = np.argmax(zero_survive)
73 | surv = np.zeros_like(survive)
74 | surv[:i] = np.exp(np.log(survive[:i]).cumsum())
75 | # surv[i:] = surv[i-1]
76 | surv[i:] = 0.
77 | else:
78 | surv = np.exp(np.log(1 - di / ni).cumsum())
79 | if start_duration < surv_idx.min():
80 | tmp = np.ones(len(surv)+ 1, dtype=surv.dtype)
81 | tmp[1:] = surv
82 | surv = tmp
83 | tmp = np.zeros(len(surv_idx)+ 1, dtype=surv_idx.dtype)
84 | tmp[1:] = surv_idx
85 | surv_idx = tmp
86 | surv = pd.Series(surv, surv_idx)
87 | return surv
88 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pytest>=4.0.2
2 | lifelines>=0.22.8
3 | sklearn-pandas>=1.8.0
4 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.3.0
3 | commit = True
4 | tag = False
5 |
6 | [metadata]
7 | license_files = LICENSE
8 |
9 | [bumpversion:file:setup.py]
10 | search = version='{current_version}'
11 | replace = version='{new_version}'
12 |
13 | [bumpversion:file:pycox/__init__.py]
14 | search = __version__ = '{current_version}'
15 | replace = __version__ = '{new_version}'
16 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | """The setup script."""
5 |
6 | from setuptools import setup, find_packages
7 |
8 |
9 | long_description = """
10 | **pycox** is a python package for survival analysis and time-to-event prediction with [PyTorch](https://pytorch.org/).
11 | It is built on the [torchtuples](https://github.com/havakv/torchtuples) package for training [PyTorch](https://pytorch.org/) models.
12 |
13 | Read the documentation at: https://github.com/havakv/pycox
14 |
15 | The package contains
16 |
17 | - survival models: (Logistic-Hazard, DeepHit, DeepSurv, Cox-Time, MTLR, etc.)
18 | - evaluation criteria (concordance, Brier score, Binomial log-likelihood, etc.)
19 | - event-time datasets (SUPPORT, METABRIC, KKBox, etc)
20 | - simulation studies
21 | - illustrative examples
22 | """
23 |
24 | requirements = [
25 | 'torchtuples>=0.2.0',
26 | 'feather-format>=0.4.0',
27 | 'h5py>=2.9.0',
28 | 'numba>=0.44',
29 | 'scikit-learn>=0.21.2',
30 | 'requests>=2.22.0',
31 | 'py7zr>=0.11.3',
32 | ]
33 |
34 | setup(
35 | name='pycox',
36 | version='0.3.0',
37 | description="Survival analysis with PyTorch",
38 | long_description=long_description,
39 | long_description_content_type='text/markdown',
40 | author="Haavard Kvamme",
41 | author_email='haavard.kvamme@gmail.com',
42 | url='https://github.com/havakv/pycox',
43 | packages=find_packages(),
44 | include_package_data=True,
45 | install_requires=requirements,
46 | license="BSD license",
47 | zip_safe=False,
48 | keywords='pycox',
49 | classifiers=[
50 | 'Development Status :: 2 - Pre-Alpha',
51 | 'Intended Audience :: Developers',
52 | 'License :: OSI Approved :: BSD License',
53 | 'Natural Language :: English',
54 | ],
55 | python_requires='>=3.8'
56 | )
57 |
--------------------------------------------------------------------------------
/tests/evaluation/test_admin.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from pycox.evaluation import admin
4 | from pycox.evaluation import EvalSurv
5 |
6 |
7 | def test_brier_score_no_censor():
8 | n = 4
9 | durations = np.ones(n) * 50
10 | durations_c = np.ones_like(durations) * 100
11 | events = durations <= durations_c
12 | m = 5
13 | index_surv = np.array([0, 25., 50., 75., 100.])
14 |
15 | surv_ones = np.ones((m, n))
16 | time_grid = np.array([5., 40., 60., 100.])
17 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv_ones, index_surv)
18 | assert (bs == np.array([0., 0., 1., 1.])).all()
19 | surv_zeros = surv_ones * 0
20 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv_zeros, index_surv)
21 | assert (bs == np.array([1., 1., 0., 0.])).all()
22 | surv_05 = surv_ones * 0.5
23 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv_05, index_surv)
24 | assert (bs == np.array([0.25, 0.25, 0.25, 0.25])).all()
25 | time_grid = np.array([110.])
26 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv_05, index_surv)
27 | assert np.isnan(bs).all()
28 |
29 | def test_brier_score_censor():
30 | n = 4
31 | durations = np.ones(n) * 50
32 | durations_c = np.array([25, 50, 60, 100])
33 | events = durations <= durations_c
34 | durations[~events] = durations_c[~events]
35 | m = 5
36 | index_surv = np.array([0, 25., 50., 75., 100.])
37 |
38 | surv = np.ones((m, n))
39 | surv[:, 0] = 0
40 | time_grid = np.array([5., 25., 40., 60., 100.])
41 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv, index_surv)
42 | assert (bs == np.array([0.25, 0.25, 0., 1., 1.])).all()
43 |
44 | def test_brier_score_evalsurv():
45 | n = 4
46 | durations = np.ones(n) * 50
47 | durations_c = np.array([25, 50, 60, 100])
48 | events = durations <= durations_c
49 | durations[~events] = durations_c[~events]
50 | m = 5
51 | index_surv = np.array([0, 25., 50., 75., 100.])
52 |
53 | surv = np.ones((m, n))
54 | surv[:, 0] = 0
55 | surv = pd.DataFrame(surv, index_surv)
56 | time_grid = np.array([5., 25., 40., 60., 100.])
57 | ev = EvalSurv(surv, durations, events, censor_durations=durations_c)
58 | bs = ev.brier_score_admin(time_grid)
59 | assert (bs.values == np.array([0.25, 0.25, 0., 1., 1.])).all()
60 |
61 |
62 | def test_binoial_log_likelihood_no_censor():
63 | n = 4
64 | durations = np.ones(n) * 50
65 | durations_c = np.ones_like(durations) * 100
66 | events = durations <= durations_c
67 | m = 5
68 | index_surv = np.array([0, 25., 50., 75., 100.])
69 |
70 | surv_ones = np.ones((m, n))
71 | time_grid = np.array([5., 40., 60., 100.])
72 | bll = admin.binomial_log_likelihood(time_grid, durations, durations_c, events, surv_ones, index_surv)
73 | eps = 1e-7
74 | assert abs(bll - np.log([1-eps, 1-eps, eps, eps])).max() < 1e-7
75 | surv_zeros = surv_ones * 0
76 | bll = admin.binomial_log_likelihood(time_grid, durations, durations_c, events, surv_zeros, index_surv)
77 | assert abs(bll - np.log([eps, eps, 1-eps, 1-eps])).max() < 1e-7
78 | surv_05 = surv_ones * 0.5
79 | bll = admin.binomial_log_likelihood(time_grid, durations, durations_c, events, surv_05, index_surv)
80 | assert abs(bll - np.log([0.5, 0.5, 1-0.5, 1-0.5])).max() < 1e-7
81 | time_grid = np.array([110.])
82 | bll = admin.binomial_log_likelihood(time_grid, durations, durations_c, events, surv_05, index_surv)
83 | assert np.isnan(bll).all()
84 |
--------------------------------------------------------------------------------
/tests/models/test_bce_surv.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pycox.models import BCESurv
3 | import torchtuples as tt
4 |
5 | from utils_model_testing import make_dataset, fit_model, assert_survs
6 |
7 |
8 | @pytest.mark.parametrize('numpy', [True, False])
9 | @pytest.mark.parametrize('num_durations', [2, 5])
10 | def test_pmf_runs(numpy, num_durations):
11 | data = make_dataset(True)
12 | input, target = data
13 | labtrans = BCESurv.label_transform(num_durations)
14 | target = labtrans.fit_transform(*target)
15 | data = tt.tuplefy(input, target)
16 | if not numpy:
17 | data = data.to_tensor()
18 | net = tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features)
19 | model = BCESurv(net)
20 | fit_model(data, model)
21 | assert_survs(input, model)
22 | model.duration_index = labtrans.cuts
23 | assert_survs(input, model)
24 | cdi = model.interpolate(3, 'const_pdf')
25 | assert_survs(input, cdi)
26 |
--------------------------------------------------------------------------------
/tests/models/test_cox.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torchtuples as tt
3 | from pycox.models import CoxPH
4 | from pycox.models.cox_time import MLPVanillaCoxTime
5 |
6 | from utils_model_testing import make_dataset, fit_model, assert_survs
7 |
8 |
9 | @pytest.mark.parametrize('numpy', [True, False])
10 | def test_cox_cc_runs(numpy):
11 | data = make_dataset(False).apply(lambda x: x.float()).to_numpy()
12 | if not numpy:
13 | data = data.to_tensor()
14 | net = tt.practical.MLPVanilla(data[0].shape[1], [4], 1, False, output_bias=False)
15 | model = CoxPH(net)
16 | fit_model(data, model)
17 | model.compute_baseline_hazards()
18 | assert_survs(data[0], model)
19 |
--------------------------------------------------------------------------------
/tests/models/test_cox_cc.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torchtuples as tt
3 | from pycox.models import CoxCC
4 | from pycox.models.cox_time import MLPVanillaCoxTime
5 |
6 | from utils_model_testing import make_dataset, fit_model, assert_survs
7 |
8 |
9 | @pytest.mark.parametrize('numpy', [True, False])
10 | def test_cox_cc_runs(numpy):
11 | data = make_dataset(False).apply(lambda x: x.float()).to_numpy()
12 | if not numpy:
13 | data = data.to_tensor()
14 | net = tt.practical.MLPVanilla(data[0].shape[1], [4], 1, False, output_bias=False)
15 | model = CoxCC(net)
16 | fit_model(data, model)
17 | model.compute_baseline_hazards()
18 | assert_survs(data[0], model)
19 |
--------------------------------------------------------------------------------
/tests/models/test_cox_time.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torchtuples as tt
3 | from pycox.models import CoxTime
4 | from pycox.models.cox_time import MLPVanillaCoxTime
5 |
6 | from utils_model_testing import make_dataset, fit_model, assert_survs
7 |
8 |
9 | @pytest.mark.parametrize('numpy', [True, False])
10 | def test_cox_time_runs(numpy):
11 | input, target = make_dataset(False).apply(lambda x: x.float()).to_numpy()
12 | labtrans = CoxTime.label_transform()
13 | target = labtrans.fit_transform(*target)
14 | data = tt.tuplefy(input, target)
15 | if not numpy:
16 | data = data.to_tensor()
17 | net = MLPVanillaCoxTime(data[0].shape[1], [4], False)
18 | model = CoxTime(net)
19 | fit_model(data, model)
20 | model.compute_baseline_hazards()
21 | assert_survs(data[0], model, with_dl=False)
22 |
--------------------------------------------------------------------------------
/tests/models/test_deephit.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pycox.models import DeepHitSingle
3 | import torchtuples as tt
4 |
5 | from utils_model_testing import make_dataset, fit_model, assert_survs
6 |
7 |
8 | @pytest.mark.parametrize('numpy', [True, False])
9 | @pytest.mark.parametrize('num_durations', [2, 5])
10 | def test_deep_hit_single_runs(numpy, num_durations):
11 | data = make_dataset(True)
12 | input, target = data
13 | labtrans = DeepHitSingle.label_transform(num_durations)
14 | target = labtrans.fit_transform(*target)
15 | data = tt.tuplefy(input, target)
16 | if not numpy:
17 | data = data.to_tensor()
18 | net = tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features)
19 | model = DeepHitSingle(net)
20 | fit_model(data, model)
21 | assert_survs(input, model)
22 | model.duration_index = labtrans.cuts
23 | assert_survs(input, model)
24 | cdi = model.interpolate(3, 'const_pdf')
25 | assert_survs(input, cdi)
26 |
--------------------------------------------------------------------------------
/tests/models/test_interpolation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from pycox import models
4 | from pycox.models.interpolation import InterpolateDiscrete
5 |
6 |
7 | class MockPMF(models.PMF):
8 | def __init__(self, duration_index=None):
9 | self.duration_index = duration_index
10 |
11 | def predict(self, input, *args, **kwargs):
12 | return input
13 |
14 | class MockLogisticHazard(models.LogisticHazard):
15 | def __init__(self, duration_index=None):
16 | self.duration_index = duration_index
17 |
18 | def predict(self, input, *args, **kwargs):
19 | return input
20 |
21 | class MockMTLR(models.MTLR):
22 | def __init__(self, duration_index=None):
23 | self.duration_index = duration_index
24 |
25 | def predict(self, input, *args, **kwargs):
26 | return input
27 |
28 |
29 | @pytest.mark.parametrize('m', [2, 5, 10])
30 | @pytest.mark.parametrize('sub', [2, 5])
31 | def test_pmf_cdi_equals_base(m, sub):
32 | torch.manual_seed(12345)
33 | n = 20
34 | idx = torch.randn(m).abs().sort()[0].numpy()
35 | input = torch.randn(n, m)
36 | model = MockPMF(idx)
37 | surv_pmf = model.interpolate(sub).predict_surv_df(input)
38 | surv_base = InterpolateDiscrete(model, duration_index=idx, sub=sub).predict_surv_df(input)
39 | assert (surv_pmf.index == surv_base.index).all()
40 | assert (surv_pmf - surv_base).abs().max().max() < 1e-7
41 |
42 |
43 | @pytest.mark.parametrize('m', [2, 5, 10])
44 | @pytest.mark.parametrize('sub', [2, 5])
45 | def test_base_values_at_knots(m, sub):
46 | torch.manual_seed(12345)
47 | n = 20
48 | idx = torch.randn(m).abs().sort()[0].numpy()
49 | input = torch.randn(n, m)
50 | model = MockPMF(idx)
51 | surv_cdi = InterpolateDiscrete(model, duration_index=idx, sub=sub).predict_surv_df(input)
52 | surv = model.predict_surv_df(input)
53 | diff = (surv - surv_cdi).dropna()
54 | assert diff.shape == surv.shape
55 | assert (diff == 0).all().all()
56 |
57 | @pytest.mark.parametrize('m', [2, 5, 10])
58 | @pytest.mark.parametrize('sub', [2, 5])
59 | def test_pmf_values_at_knots(m, sub):
60 | torch.manual_seed(12345)
61 | n = 20
62 | idx = torch.randn(m).abs().sort()[0].numpy()
63 | input = torch.randn(n, m)
64 | model = MockPMF(idx)
65 | surv = model.predict_surv_df(input)
66 | surv_cdi = model.interpolate(sub, 'const_pdf').predict_surv_df(input)
67 | diff = (surv - surv_cdi).dropna()
68 | assert diff.shape == surv.shape
69 | assert diff.max().max() < 1e-7
70 |
71 | @pytest.mark.parametrize('m', [2, 5, 10])
72 | @pytest.mark.parametrize('sub', [2, 5])
73 | def test_logistic_hazard_values_at_knots(m, sub):
74 | torch.manual_seed(12345)
75 | n = 20
76 | idx = torch.randn(m).abs().sort()[0].numpy()
77 | input = torch.randn(n, m)
78 | model = MockLogisticHazard(idx)
79 | surv = model.predict_surv_df(input)
80 | surv_cdi = model.interpolate(sub, 'const_pdf').predict_surv_df(input)
81 | diff = (surv - surv_cdi).dropna()
82 | assert diff.shape == surv.shape
83 | assert (diff == 0).all().all()
84 | surv_chi = model.interpolate(sub, 'const_hazard').predict_surv_df(input)
85 | diff = (surv - surv_chi).dropna()
86 | assert diff.shape == surv.shape
87 | assert (diff.index == surv.index).all()
88 | assert diff.max().max() < 1e-6
89 |
90 | @pytest.mark.parametrize('m', [2, 5, 10])
91 | @pytest.mark.parametrize('sub', [2, 5])
92 | def test_mtlr_values_at_knots(m, sub):
93 | torch.manual_seed(12345)
94 | n = 20
95 | idx = torch.randn(m).abs().sort()[0].numpy()
96 | input = torch.randn(n, m)
97 | model = MockMTLR(idx)
98 | surv = model.predict_surv_df(input)
99 | surv_cdi = model.interpolate(sub, 'const_pdf').predict_surv_df(input)
100 | diff = (surv - surv_cdi).dropna()
101 | assert diff.shape == surv.shape
102 | assert diff.max().max() < 1e-7
103 |
--------------------------------------------------------------------------------
/tests/models/test_logistic_hazard.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pycox.models import LogisticHazard
3 | import torchtuples as tt
4 |
5 | from utils_model_testing import make_dataset, fit_model, assert_survs
6 |
7 |
8 | @pytest.mark.parametrize('numpy', [True, False])
9 | @pytest.mark.parametrize('num_durations', [2, 5])
10 | def test_logistic_hazard_runs(numpy, num_durations):
11 | data = make_dataset(True)
12 | input, target = data
13 | labtrans = LogisticHazard.label_transform(num_durations)
14 | target = labtrans.fit_transform(*target)
15 | data = tt.tuplefy(input, target)
16 | if not numpy:
17 | data = data.to_tensor()
18 | net = tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features)
19 | model = LogisticHazard(net)
20 | fit_model(data, model)
21 | assert_survs(input, model)
22 | model.duration_index = labtrans.cuts
23 | assert_survs(input, model)
24 | cdi = model.interpolate(3, 'const_pdf')
25 | assert_survs(input, cdi)
26 |
--------------------------------------------------------------------------------
/tests/models/test_loss.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import torch
4 | import torchtuples as tt
5 |
6 | from pycox.models.data import pair_rank_mat
7 | from pycox.models import loss
8 |
9 | @pytest.mark.parametrize('seed', [0, 1, 2])
10 | @pytest.mark.parametrize('m', [1, 5, 8])
11 | def test_nll_pmf_cr_equals_nll_pmf(seed, m):
12 | torch.manual_seed(seed)
13 | # m = 5
14 | n_risk = 1
15 | rep = 7
16 | batch = m * (n_risk + 1) * rep
17 | phi = torch.randn(batch, n_risk, m)
18 | idx_duration = torch.arange(m).repeat(rep * (n_risk + 1))
19 | events = torch.arange(n_risk + 1).repeat(m * rep)
20 | r1 = loss.nll_pmf_cr(phi, idx_duration, events)
21 | r2 = loss.nll_pmf(phi.view(batch * n_risk, -1), idx_duration, events.float())
22 | assert (r1 - r2).abs() < 1e-5
23 |
24 | @pytest.mark.parametrize('seed', [0, 1, 2])
25 | @pytest.mark.parametrize('m', [1, 5, 8])
26 | @pytest.mark.parametrize('sigma', [0.1, 0.2, 1.])
27 | def test_rank_loss_deephit_cr_equals_single(seed, m, sigma):
28 | torch.manual_seed(seed)
29 | n_risk = 1
30 | rep = 7
31 | batch = m * (n_risk + 1) * rep
32 | phi = torch.randn(batch, n_risk, m)
33 | idx_duration = torch.arange(m).repeat(rep * (n_risk + 1))
34 | events = torch.arange(n_risk + 1).repeat(m * rep)
35 | rank_mat = pair_rank_mat(idx_duration.numpy(), events.numpy())
36 | rank_mat = torch.tensor(rank_mat)
37 | r1 = loss.rank_loss_deephit_cr(phi, idx_duration, events, rank_mat, sigma)
38 | r2 = loss.rank_loss_deephit_single(phi.view(batch, -1), idx_duration, events.float(),
39 | rank_mat, sigma)
40 | assert (r1 - r2).abs() < 1e-6
41 |
42 |
43 | @pytest.mark.parametrize('seed', [0, 1])
44 | @pytest.mark.parametrize('m', [1, 8])
45 | @pytest.mark.parametrize('sigma', [0.1, 1.])
46 | @pytest.mark.parametrize('alpha', [1, 0.5, 0.])
47 | def test_loss_deephit_cr_equals_single(seed, m, sigma, alpha):
48 | torch.manual_seed(seed)
49 | n_risk = 1
50 | rep = 7
51 | batch = m * (n_risk + 1) * rep
52 | phi = torch.randn(batch, n_risk, m)
53 | idx_duration = torch.arange(m).repeat(rep * (n_risk + 1))
54 | events = torch.arange(n_risk + 1).repeat(m * rep)
55 | rank_mat = pair_rank_mat(idx_duration.numpy(), events.numpy())
56 | rank_mat = torch.tensor(rank_mat)
57 | loss_cr = loss.DeepHitLoss(alpha, sigma)
58 | loss_single = loss.DeepHitSingleLoss(alpha, sigma)
59 | r1 = loss_cr(phi, idx_duration, events, rank_mat)
60 | r2 = loss_single(phi.view(batch, -1), idx_duration, events.float(), rank_mat)
61 | assert (r1 - r2).abs() < 1e-5
62 |
63 | @pytest.mark.parametrize('seed', [0, 1])
64 | @pytest.mark.parametrize('shrink', [0, 0.01, 1.])
65 | def test_cox_cc_loss_single_ctrl(seed, shrink):
66 | np.random.seed(seed)
67 | n = 100
68 | case = np.random.uniform(-1, 1, n)
69 | ctrl = np.random.uniform(-1, 1, n)
70 | case, ctrl = tt.tuplefy(case, ctrl).to_tensor()
71 | loss_1 = loss.cox_cc_loss(case, (ctrl,), shrink)
72 | loss_2 = loss.cox_cc_loss_single_ctrl(case, ctrl, shrink)
73 | assert (loss_1 - loss_2).abs() < 1e-6
74 |
75 | @pytest.mark.parametrize('shrink', [0, 0.01, 1.])
76 | def test_cox_cc_loss_single_ctrl_zero(shrink):
77 | n = 10
78 | case = ctrl = torch.zeros(n)
79 | loss_1 = loss.cox_cc_loss_single_ctrl(case, ctrl, shrink)
80 | val = torch.tensor(2.).log()
81 | assert (loss_1 - val).abs() == 0
82 |
83 | @pytest.mark.parametrize('shrink', [0, 0.01, 1.])
84 | def test_cox_cc_loss_zero(shrink):
85 | n = 10
86 | case = ctrl = torch.zeros(n)
87 | loss_1 = loss.cox_cc_loss(case, (ctrl,), shrink)
88 | val = torch.tensor(2.).log()
89 | assert (loss_1 - val).abs() == 0
90 |
91 | def test_nll_mtlr_zero():
92 | n_frac = 4
93 | m = 5
94 | n = m * n_frac
95 | phi = torch.zeros(n, m)
96 | idx_durations = torch.arange(m).repeat(n_frac)
97 | events = torch.ones_like(idx_durations).float()
98 | loss_pmf = loss.nll_pmf(phi, idx_durations, events)
99 | loss_mtlr = loss.nll_mtlr(phi, idx_durations, events)
100 | assert (loss_pmf - loss_mtlr).abs() == 0
101 | events = torch.zeros_like(events)
102 | loss_pmf = loss.nll_pmf(phi, idx_durations, events)
103 | loss_mtlr = loss.nll_mtlr(phi, idx_durations, events)
104 | assert (loss_pmf - loss_mtlr).abs() == 0
105 |
--------------------------------------------------------------------------------
/tests/models/test_models_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | import torch
4 | from pycox.models.utils import pad_col, make_subgrid, cumsum_reverse
5 |
6 | @pytest.mark.parametrize('val', [0, 1, 5])
7 | def test_pad_col_start(val):
8 | x = torch.ones((2, 3))
9 | x_pad = pad_col(x, val, where='start')
10 | pad = torch.ones(2, 1) * val
11 | assert (x_pad == torch.cat([pad, x], dim=1)).all()
12 |
13 | @pytest.mark.parametrize('val', [0, 1, 5])
14 | def test_pad_col_end(val):
15 | x = torch.ones((2, 3))
16 | x_pad = pad_col(x, val)
17 | pad = torch.ones(2, 1) * val
18 | assert (x_pad == torch.cat([x, pad], dim=1)).all()
19 |
20 | @pytest.mark.parametrize('n', [2, 13, 40])
21 | def test_make_subgrid_1(n):
22 | grid = np.random.uniform(0, 100, n)
23 | grid = np.sort(grid)
24 | new_grid = make_subgrid(grid, 1)
25 | assert len(new_grid) == len(grid)
26 | assert (new_grid == grid).all()
27 |
28 | @pytest.mark.parametrize('sub', [2, 10, 20])
29 | @pytest.mark.parametrize('start', [0, 2])
30 | @pytest.mark.parametrize('stop', [4, 100])
31 | @pytest.mark.parametrize('n', [5, 10])
32 | def test_make_subgrid(sub, start, stop, n):
33 | grid = np.linspace(start, stop, n)
34 | new_grid = make_subgrid(grid, sub)
35 | true_new = np.linspace(start, stop, n*sub - (sub-1))
36 | assert len(new_grid) == len(true_new)
37 | assert np.abs(true_new - new_grid).max() < 1e-13
38 |
39 | def test_cumsum_reverse_error_dim():
40 | x = torch.randn((5, 3))
41 | with pytest.raises(NotImplementedError):
42 | cumsum_reverse(x, dim=0)
43 | with pytest.raises(NotImplementedError):
44 | cumsum_reverse(x, dim=2)
45 |
46 | def test_cumsum_reverse_dim_1():
47 | torch.manual_seed(1234)
48 | x = torch.randn(5, 16)
49 | res_np = x.numpy()[:, ::-1].cumsum(1)[:, ::-1]
50 | res = cumsum_reverse(x, dim=1)
51 | assert np.abs(res.numpy() - res_np).max() < 1e-6
52 |
--------------------------------------------------------------------------------
/tests/models/test_mtlr.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pycox.models import MTLR
3 | import torchtuples as tt
4 |
5 | from utils_model_testing import make_dataset, fit_model, assert_survs
6 |
7 |
8 | @pytest.mark.parametrize('numpy', [True, False])
9 | @pytest.mark.parametrize('num_durations', [2, 5])
10 | def test_mtlr_runs(numpy, num_durations):
11 | data = make_dataset(True)
12 | input, target = data
13 | labtrans = MTLR.label_transform(num_durations)
14 | target = labtrans.fit_transform(*target)
15 | data = tt.tuplefy(input, target)
16 | if not numpy:
17 | data = data.to_tensor()
18 | net = tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features)
19 | model = MTLR(net)
20 | fit_model(data, model)
21 | assert_survs(input, model)
22 | model.duration_index = labtrans.cuts
23 | assert_survs(input, model)
24 | cdi = model.interpolate(3, 'const_pdf')
25 | assert_survs(input, cdi)
26 |
--------------------------------------------------------------------------------
/tests/models/test_pc_hazard.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import numpy as np
4 | import torchtuples as tt
5 | from pycox.models import PCHazard
6 |
7 | from utils_model_testing import make_dataset, fit_model, assert_survs
8 |
9 |
10 | def _make_dataset(n, m):
11 | np.random.seed(0)
12 | x = np.random.normal(0, 1, (n, 4)).astype('float32')
13 | duration_index = np.arange(m+1).astype('int64')
14 | durations = np.repeat(duration_index, np.ceil(n / m))[:n]
15 | events = np.random.uniform(0, 1, n).round().astype('float32')
16 | fracs = np.random.uniform(0, 1, n).astype('float32')
17 | return x, (durations, events, fracs), duration_index
18 |
19 | @pytest.mark.parametrize('m', [5, 10])
20 | @pytest.mark.parametrize('n_mul', [2, 3])
21 | @pytest.mark.parametrize('mp', [1, 2, -1])
22 | def test_wrong_net_output(m, n_mul, mp):
23 | n = m * n_mul
24 | inp, tar, dur_index = _make_dataset(n, m)
25 | net = torch.nn.Linear(inp.shape[1], m+1)
26 | with pytest.raises(ValueError):
27 | model = PCHazard(net, duration_index=dur_index)
28 |
29 | model = PCHazard(net)
30 | with pytest.raises(ValueError):
31 | model.fit(inp, tar)
32 |
33 | model.duration_index = dur_index
34 | with pytest.raises(ValueError):
35 | model.predict_surv_df(inp)
36 |
37 | model.duration_index = dur_index
38 | dl = model.make_dataloader((inp, tar), 5, True)
39 | with pytest.raises(ValueError):
40 | model.fit_dataloader(dl)
41 |
42 | @pytest.mark.parametrize('m', [5, 10])
43 | @pytest.mark.parametrize('n_mul', [2, 3])
44 | def test_right_net_output(m, n_mul):
45 | n = m * n_mul
46 | inp, tar, dur_index = _make_dataset(n, m)
47 | net = torch.nn.Linear(inp.shape[1], m)
48 | model = PCHazard(net)
49 | model = PCHazard(net, duration_index=dur_index)
50 | model.fit(inp, tar, verbose=False)
51 | model.predict_surv_df(inp)
52 | dl = model.make_dataloader((inp, tar), 5, True)
53 | model.fit_dataloader(dl)
54 | assert True
55 |
56 | @pytest.mark.parametrize('numpy', [True, False])
57 | @pytest.mark.parametrize('num_durations', [3, 8])
58 | def test_pc_hazard_runs(numpy, num_durations):
59 | data = make_dataset(True)
60 | input, (durations, events) = data
61 | durations += 1
62 | target = (durations, events)
63 | labtrans = PCHazard.label_transform(num_durations)
64 | target = labtrans.fit_transform(*target)
65 | data = tt.tuplefy(input, target)
66 | if not numpy:
67 | data = data.to_tensor()
68 | net = tt.practical.MLPVanilla(input.shape[1], [4], num_durations)
69 | model = PCHazard(net)
70 | fit_model(data, model)
71 | assert_survs(input, model)
72 | model.duration_index = labtrans.cuts
73 | assert_survs(input, model)
74 |
--------------------------------------------------------------------------------
/tests/models/test_pmf.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pycox.models import PMF
3 | import torchtuples as tt
4 |
5 | from utils_model_testing import make_dataset, fit_model, assert_survs
6 |
7 |
8 | @pytest.mark.parametrize('numpy', [True, False])
9 | @pytest.mark.parametrize('num_durations', [2, 5])
10 | def test_pmf_runs(numpy, num_durations):
11 | data = make_dataset(True)
12 | input, target = data
13 | labtrans = PMF.label_transform(num_durations)
14 | target = labtrans.fit_transform(*target)
15 | data = tt.tuplefy(input, target)
16 | if not numpy:
17 | data = data.to_tensor()
18 | net = tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features)
19 | model = PMF(net)
20 | fit_model(data, model)
21 | assert_survs(input, model)
22 | model.duration_index = labtrans.cuts
23 | assert_survs(input, model)
24 | cdi = model.interpolate(3, 'const_pdf')
25 | assert_survs(input, cdi)
26 |
--------------------------------------------------------------------------------
/tests/models/utils_model_testing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import torch
4 | import torchtuples as tt
5 |
6 | def make_dataset(numpy):
7 | n_events = 2
8 | n_frac = 4
9 | m = 10
10 | n = m * n_frac * n_events
11 | p = 5
12 | input = torch.randn((n, p))
13 | durations = torch.arange(m).repeat(int(n / m))
14 | events = torch.arange(n_events).repeat(int(n / n_events)).float()
15 | target = (durations, events)
16 | data = tt.tuplefy(input, target)
17 | if numpy:
18 | data = data.to_numpy()
19 | return data
20 |
21 | def fit_model(data, model):
22 | model.fit(*data, epochs=1, verbose=False, val_data=data)
23 | return model
24 |
25 | def assert_survs(input, model, with_dl=True):
26 | preds = model.predict_surv(input)
27 | assert type(preds) is type(input)
28 | assert preds.shape[0] == input.shape[0]
29 | surv_df = model.predict_surv_df(input)
30 | assert type(surv_df) is pd.DataFrame
31 | assert type(surv_df.values) is np.ndarray
32 | assert preds.shape[0] == surv_df.shape[1]
33 | assert preds.shape[1] == surv_df.shape[0]
34 | np_input = tt.tuplefy(input).to_numpy()[0]
35 | torch_input = tt.tuplefy(input).to_tensor()[0]
36 | np_preds = model.predict_surv(np_input)
37 | torch_preds = model.predict_surv(torch_input)
38 | assert (np_preds == torch_preds.cpu().numpy()).all()
39 | if with_dl:
40 | dl_input = tt.tuplefy(input).make_dataloader(512, False)
41 | dl_preds = model.predict_surv(dl_input)
42 | assert type(np_preds) is type(dl_preds), f"got {type(np_preds)}, and, {type(dl_preds)}"
43 | assert (np_preds == dl_preds).all()
44 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | from pycox import utils
4 |
5 | def test_kaplan_meier():
6 | durations = np.array([1., 1., 2., 3.])
7 | events = np.array([1, 1, 1, 0])
8 | surv = utils.kaplan_meier(durations, events)
9 | assert (surv.index.values == np.arange(4, dtype=float)).all()
10 | assert (surv.values == np.array([1., 0.5, 0.25, 0.25])).all()
11 |
12 | @pytest.mark.parametrize('n', [10, 85, 259])
13 | @pytest.mark.parametrize('p_cens', [0, 0.3, 0.8])
14 | def test_kaplan_meier_vs_lifelines(n, p_cens):
15 | from lifelines import KaplanMeierFitter
16 | np.random.seed(0)
17 | durations = np.random.uniform(0, 100, n)
18 | events = np.random.binomial(1, 1 - p_cens, n).astype('float')
19 | km = utils.kaplan_meier(durations, events)
20 | kmf = KaplanMeierFitter().fit(durations, events).survival_function_['KM_estimate']
21 | assert km.shape == kmf.shape
22 | assert (km - kmf).abs().max() < 1e-14
23 | assert (km.index == kmf.index).all()
24 |
--------------------------------------------------------------------------------