├── .github └── workflows │ ├── publish_pypi.yml │ └── pythonpackage.yml ├── .gitignore ├── LICENSE ├── README.md ├── examples ├── 01_introduction.ipynb ├── 02_introduction.ipynb ├── 03_network_architectures.ipynb ├── 04_mnist_dataloaders_cnn.ipynb ├── README.md ├── administrative_brier_score.ipynb ├── cox-cc.ipynb ├── cox-ph.ipynb ├── cox-time.ipynb ├── deephit.ipynb ├── deephit_competing_risks.ipynb ├── mtlr.ipynb ├── pc-hazard.ipynb └── pmf.ipynb ├── figures └── logo.svg ├── pycox ├── __init__.py ├── datasets │ ├── __init__.py │ ├── _dataset_loader.py │ ├── from_deepsurv.py │ ├── from_kkbox.py │ ├── from_rdatasets.py │ └── from_simulations.py ├── evaluation │ ├── __init__.py │ ├── admin.py │ ├── concordance.py │ ├── eval_surv.py │ ├── ipcw.py │ └── metrics.py ├── models │ ├── __init__.py │ ├── base.py │ ├── bce_surv.py │ ├── cox.py │ ├── cox_cc.py │ ├── cox_time.py │ ├── data.py │ ├── deephit.py │ ├── interpolation.py │ ├── logistic_hazard.py │ ├── loss.py │ ├── mtlr.py │ ├── pc_hazard.py │ ├── pmf.py │ └── utils.py ├── preprocessing │ ├── __init__.py │ ├── discretization.py │ ├── feature_transforms.py │ └── label_transforms.py ├── simulations │ ├── __init__.py │ ├── base.py │ ├── discrete_logit_hazard.py │ └── relative_risk.py └── utils.py ├── requirements-dev.txt ├── setup.cfg ├── setup.py └── tests ├── evaluation └── test_admin.py ├── models ├── test_bce_surv.py ├── test_cox.py ├── test_cox_cc.py ├── test_cox_time.py ├── test_deephit.py ├── test_interpolation.py ├── test_logistic_hazard.py ├── test_loss.py ├── test_models_utils.py ├── test_mtlr.py ├── test_pc_hazard.py ├── test_pmf.py └── utils_model_testing.py └── test_utils.py /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package to PyPI 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: 3.9 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build 21 | run: | 22 | python setup.py sdist bdist_wheel 23 | - name: Publish to PyPI 24 | uses: pypa/gh-action-pypi-publish@master 25 | with: 26 | password: ${{ secrets.PYPI_PASSWORD }} 27 | -------------------------------------------------------------------------------- /.github/workflows/pythonpackage.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | jobs: 10 | build_test: 11 | name: Test on ${{ matrix.config.os }} with Python ${{ matrix.python-version }} 12 | runs-on: ${{ matrix.config.os }} 13 | strategy: 14 | max-parallel: 4 15 | fail-fast: false 16 | matrix: 17 | python-version: ['3.8', '3.9', '3.10'] 18 | config: 19 | - { os: ubuntu-latest, torch-version: "torch --index-url https://download.pytorch.org/whl/cpu"} 20 | - { os: windows-latest, torch-version: "torch"} 21 | - { os: macOS-latest, torch-version: "torch"} 22 | steps: 23 | - uses: actions/checkout@v4 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v5 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install ${{ matrix.config.torch-version }} 32 | # python setup.py install 33 | pip 
install . 34 | pip install -r requirements-dev.txt 35 | - name: Lint with flake8 36 | run: | 37 | pip install flake8 38 | # stop the build if there are Python syntax errors or undefined names 39 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 40 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 41 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 42 | - name: Test with pytest 43 | run: | 44 | pytest 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # datasets 104 | /datasets/ 105 | 106 | # vscode 107 | .vscode/ 108 | 109 | # torch files 110 | *.torch 111 | 112 | # data files 113 | /pycox/datasets/data/ 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2018, Haavard Kvamme 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | The notebooks in this directory describe how the methods can be used. 4 | 5 | -------------------------------------------------------------------------------- /figures/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 |
[SVG markup not preserved in this export; the recoverable content is the word "pycox" rendered in a 112px grey (#666666) font]
-------------------------------------------------------------------------------- /pycox/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for pycox.""" 4 | 5 | __author__ = """Haavard Kvamme""" 6 | __email__ = 'haavard.kvamme@gmail.com' 7 | __version__ = '0.3.0' 8 | 9 | import pycox.datasets 10 | import pycox.evaluation 11 | import pycox.preprocessing 12 | import pycox.simulations 13 | import pycox.utils 14 | import pycox.models 15 | -------------------------------------------------------------------------------- /pycox/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from pycox.datasets import from_deepsurv 2 | from pycox.datasets import from_rdatasets 3 | from pycox.datasets import from_kkbox 4 | from pycox.datasets import from_simulations 5 | 6 | 7 | support = from_deepsurv._Support() 8 | metabric = from_deepsurv._Metabric() 9 | gbsg = from_deepsurv._Gbsg() 10 | flchain = from_rdatasets._Flchain() 11 | nwtco = from_rdatasets._Nwtco() 12 | kkbox_v1 = from_kkbox._DatasetKKBoxChurn() 13 | kkbox = from_kkbox._DatasetKKBoxAdmin() 14 | sac3 = from_simulations._SAC3() 15 | rr_nl_nhp = from_simulations._RRNLNPH() 16 | sac_admin5 = from_simulations._SACAdmin5() 17 | -------------------------------------------------------------------------------- /pycox/datasets/_dataset_loader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import pandas as pd 3 | import pycox 4 | import os 5 | 6 | _DATA_OVERRIDE = os.environ.get('PYCOX_DATA_DIR', None) 7 | if _DATA_OVERRIDE: 8 | _PATH_DATA = Path(_DATA_OVERRIDE) 9 | else: 10 | _PATH_ROOT = Path(pycox.__file__).parent 11 | _PATH_DATA = _PATH_ROOT / 'datasets' / 'data' 12 | _PATH_DATA.mkdir(parents=True, exist_ok=True) 13 | 14 | class _DatasetLoader: 15 | """Abstract class for loading data sets. 16 | """ 17 | name = NotImplemented 18 | _checksum = None 19 | 20 | def __init__(self): 21 | self.path = _PATH_DATA / f"{self.name}.feather" 22 | 23 | def read_df(self): 24 | if not self.path.exists(): 25 | print(f"Dataset '{self.name}' not locally available. Downloading...") 26 | self._download() 27 | print(f"Done") 28 | df = pd.read_feather(self.path) 29 | df = self._label_cols_at_end(df) 30 | return df 31 | 32 | def _download(self): 33 | raise NotImplementedError 34 | 35 | def delete_local_copy(self): 36 | if not self.path.exists(): 37 | raise RuntimeError("File does not exists.") 38 | self.path.unlink() 39 | 40 | def _label_cols_at_end(self, df): 41 | if hasattr(self, 'col_duration') and hasattr(self, 'col_event'): 42 | col_label = [self.col_duration, self.col_event] 43 | df = df[list(df.columns.drop(col_label)) + col_label] 44 | return df 45 | 46 | def checksum(self): 47 | """Checks that the dataset is correct. 48 | 49 | Returns: 50 | bool -- If the check passed. 
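        Example (editor's sketch, not part of the original docstring; assumes the
        data set defines a reference `_checksum` and can be downloaded):
            >>> from pycox import datasets
            >>> datasets.metabric.checksum()
            True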
51 | """ 52 | if self._checksum is None: 53 | raise NotImplementedError("No available comparison for this dataset.") 54 | df = self.read_df() 55 | return self._checksum_df(df) 56 | 57 | def _checksum_df(self, df): 58 | if self._checksum is None: 59 | raise NotImplementedError("No available comparison for this dataset.") 60 | import hashlib 61 | val = get_checksum(df) 62 | return val == self._checksum 63 | 64 | 65 | def get_checksum(df): 66 | import hashlib 67 | val = hashlib.sha256(df.to_csv().encode()).hexdigest() 68 | return val 69 | -------------------------------------------------------------------------------- /pycox/datasets/from_deepsurv.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import requests 3 | import h5py 4 | import pandas as pd 5 | from pycox.datasets._dataset_loader import _DatasetLoader 6 | 7 | 8 | class _DatasetDeepSurv(_DatasetLoader): 9 | _dataset_url = "https://raw.githubusercontent.com/jaredleekatzman/DeepSurv/master/experiments/data/" 10 | _datasets = { 11 | 'support': "support/support_train_test.h5", 12 | 'metabric': "metabric/metabric_IHC4_clinical_train_test.h5", 13 | 'gbsg': "gbsg/gbsg_cancer_train_test.h5", 14 | } 15 | col_duration = 'duration' 16 | col_event = 'event' 17 | def _download(self): 18 | url = self._dataset_url + self._datasets[self.name] 19 | path = self.path.parent / f"{self.name}.h5" 20 | with requests.Session() as s: 21 | r = s.get(url) 22 | with open(path, 'wb') as f: 23 | f.write(r.content) 24 | 25 | data = defaultdict(dict) 26 | with h5py.File(path) as f: 27 | for ds in f: 28 | for array in f[ds]: 29 | data[ds][array] = f[ds][array][:] 30 | 31 | path.unlink() 32 | train = _make_df(data['train']) 33 | test = _make_df(data['test']) 34 | df = pd.concat([train, test]).reset_index(drop=True) 35 | df.to_feather(self.path) 36 | 37 | 38 | def _make_df(data): 39 | x = data['x'] 40 | t = data['t'] 41 | d = data['e'] 42 | 43 | colnames = ['x'+str(i) for i in range(x.shape[1])] 44 | df = (pd.DataFrame(x, columns=colnames) 45 | .assign(duration=t) 46 | .assign(event=d)) 47 | return df 48 | 49 | 50 | class _Support(_DatasetDeepSurv): 51 | """Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT). 52 | 53 | A study of survival for seriously ill hospitalized adults. 54 | 55 | This is the processed data set used in the DeepSurv paper (Katzman et al. 2018), and details 56 | can be found at https://doi.org/10.1186/s12874-018-0482-1 57 | 58 | See https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data 59 | for original data. 60 | 61 | Variables: 62 | x0, ..., x13: 63 | numerical covariates. 64 | duration: (duration) 65 | the right-censored event-times. 66 | event: (event) 67 | event indicator {1: event, 0: censoring}. 68 | """ 69 | name = 'support' 70 | _checksum = 'b07a9d216bf04501e832084e5b7955cb84dfef834810037c548dee82ea251f8d' 71 | 72 | 73 | class _Metabric(_DatasetDeepSurv): 74 | """The Molecular Taxonomy of Breast Cancer International Consortium (METABRIC). 75 | 76 | Gene and protein expression profiles to determine new breast cancer subgroups in 77 | order to help physicians provide better treatment recommendations. 78 | 79 | This is the processed data set used in the DeepSurv paper (Katzman et al. 2018), and details 80 | can be found at https://doi.org/10.1186/s12874-018-0482-1 81 | 82 | See https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data 83 | for original data. 
84 | 85 | Variables: 86 | x0, ..., x8: 87 | numerical covariates. 88 | duration: (duration) 89 | the right-censored event-times. 90 | event: (event) 91 | event indicator {1: event, 0: censoring}. 92 | """ 93 | name = 'metabric' 94 | _checksum = '310b74b97cc37c9eddd29f253ae3c06015dc63a17a71e4a68ff339dbe265f417' 95 | 96 | 97 | class _Gbsg(_DatasetDeepSurv): 98 | """ Rotterdam & German Breast Cancer Study Group (GBSG) 99 | 100 | A combination of the Rotterdam tumor bank and the German Breast Cancer Study Group. 101 | 102 | This is the processed data set used in the DeepSurv paper (Katzman et al. 2018), and details 103 | can be found at https://doi.org/10.1186/s12874-018-0482-1 104 | 105 | See https://github.com/jaredleekatzman/DeepSurv/tree/master/experiments/data 106 | for original data. 107 | 108 | Variables: 109 | x0, ..., x6: 110 | numerical covariates. 111 | duration: (duration) 112 | the right-censored event-times. 113 | event: (event) 114 | event indicator {1: event, 0: censoring}. 115 | """ 116 | name = 'gbsg' 117 | _checksum = 'de2359bee62bf36b9e3f901fea4a9fbef2d145e26e9384617d0d3f75892fe5ce' 118 | -------------------------------------------------------------------------------- /pycox/datasets/from_rdatasets.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pycox.datasets._dataset_loader import _DatasetLoader 3 | 4 | def download_from_rdatasets(package, name): 5 | datasets = (pd.read_csv("https://raw.githubusercontent.com/vincentarelbundock/Rdatasets/master/datasets.csv") 6 | .loc[lambda x: x['Package'] == package].set_index('Item')) 7 | if not name in datasets.index: 8 | raise ValueError(f"Dataset {name} not found.") 9 | info = datasets.loc[name] 10 | url = info.CSV 11 | return pd.read_csv(url), info 12 | 13 | 14 | class _DatasetRdatasetsSurvival(_DatasetLoader): 15 | """Data sets from Rdataset survival. 16 | """ 17 | def _download(self): 18 | df, info = download_from_rdatasets('survival', self.name) 19 | self.info = info 20 | df.to_feather(self.path) 21 | 22 | 23 | class _Flchain(_DatasetRdatasetsSurvival): 24 | """Assay of serum free light chain (FLCHAIN). 25 | Obtained from Rdatasets (https://github.com/vincentarelbundock/Rdatasets). 26 | 27 | A study of the relationship between serum free light chain (FLC) and mortality. 28 | The original sample contains samples on approximately 2/3 of the residents of Olmsted 29 | County aged 50 or greater. 30 | 31 | For details see http://vincentarelbundock.github.io/Rdatasets/doc/survival/flchain.html 32 | 33 | Variables: 34 | age: 35 | age in years. 36 | sex: 37 | F=female, M=male. 38 | sample.yr: 39 | the calendar year in which a blood sample was obtained. 40 | kappa: 41 | serum free light chain, kappa portion. 42 | lambda: 43 | serum free light chain, lambda portion. 44 | flc.grp: 45 | the FLC group for the subject, as used in the original analysis. 46 | creatinine: 47 | serum creatinine. 48 | mgus: 49 | 1 if the subject had been diagnosed with monoclonal gammapothy (MGUS). 50 | futime: (duration) 51 | days from enrollment until death. Note that there are 3 subjects whose sample 52 | was obtained on their death date. 53 | death: (event) 54 | 0=alive at last contact date, 1=dead. 55 | chapter: 56 | for those who died, a grouping of their primary cause of death by chapter headings 57 | of the International Code of Diseases ICD-9. 
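    Example (editor's sketch, not part of the original docstring; the raw table is
    downloaded from Rdatasets on the first call):
        >>> from pycox import datasets
        >>> df = datasets.flchain.read_df()           # processed covariates and labels
        >>> df_raw = datasets.flchain.read_df(False)  # unprocessed Rdatasets table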
58 | 59 | """ 60 | name = 'flchain' 61 | col_duration = 'futime' 62 | col_event = 'death' 63 | _checksum = 'ec12748a1aa5790457c09793387337bb03b1dc45a22a2d58a8c2b9ad1f2648dd' 64 | 65 | def read_df(self, processed=True): 66 | """Get dataset. 67 | 68 | If 'processed' is False, return the raw data set. 69 | See the code for processing. 70 | 71 | Keyword Arguments: 72 | processed {bool} -- If 'False' get raw data, else get processed (see '??flchain.read_df'). 73 | (default: {True}) 74 | """ 75 | df = super().read_df() 76 | if processed: 77 | df = (df 78 | .drop(['chapter', 'Unnamed: 0'], axis=1) 79 | .loc[lambda x: x['creatinine'].isna() == False] 80 | .reset_index(drop=True) 81 | .assign(sex=lambda x: (x['sex'] == 'M'))) 82 | 83 | categorical = ['sample.yr', 'flc.grp'] 84 | for col in categorical: 85 | df[col] = df[col].astype('category') 86 | for col in df.columns.drop(categorical): 87 | df[col] = df[col].astype('float32') 88 | return df 89 | 90 | 91 | class _Nwtco(_DatasetRdatasetsSurvival): 92 | """Data from the National Wilm's Tumor Study (NWTCO) 93 | Obtained from Rdatasets (https://github.com/vincentarelbundock/Rdatasets). 94 | 95 | Measurement error example. Tumor histology predicts survival, but prediction is stronger 96 | with central lab histology than with the local institution determination. 97 | 98 | For details see http://vincentarelbundock.github.io/Rdatasets/doc/survival/nwtco.html 99 | 100 | Variables: 101 | seqno: 102 | id number 103 | instit: 104 | histology from local institution 105 | histol: 106 | histology from central lab 107 | stage: 108 | disease stage 109 | study: 110 | study 111 | rel: (event) 112 | indicator for relapse 113 | edrel: (duration) 114 | time to relapse 115 | age: 116 | age in months 117 | in.subcohort: 118 | included in the subcohort for the example in the paper 119 | 120 | References 121 | NE Breslow and N Chatterjee (1999), Design and analysis of two-phase studies with binary 122 | outcome applied to Wilms tumor prognosis. Applied Statistics 48, 457–68. 123 | """ 124 | name = 'nwtco' 125 | col_duration = 'edrel' 126 | col_event = 'rel' 127 | _checksum = '5aa3de698dadb60154dd59196796e382739ff56dc6cbd39cfc2fda50d69d118e' 128 | 129 | def read_df(self, processed=True): 130 | """Get dataset. 131 | 132 | If 'processed' is False, return the raw data set. 133 | See the code for processing. 134 | 135 | Keyword Arguments: 136 | processed {bool} -- If 'False' get raw data, else get processed (see '??nwtco.read_df'). 137 | (default: {True}) 138 | """ 139 | df = super().read_df() 140 | if processed: 141 | df = (df 142 | .assign(instit_2=df['instit'] - 1, 143 | histol_2=df['histol'] - 1, 144 | study_4=df['study'] - 3, 145 | stage=df['stage'].astype('category')) 146 | .drop(['Unnamed: 0', 'seqno', 'instit', 'histol', 'study'], axis=1)) 147 | for col in df.columns.drop('stage'): 148 | df[col] = df[col].astype('float32') 149 | df = self._label_cols_at_end(df) 150 | return df 151 | -------------------------------------------------------------------------------- /pycox/datasets/from_simulations.py: -------------------------------------------------------------------------------- 1 | """Make dataset from the simulations, so we don't have to compute over again. 
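The simulated data sets are exposed through `pycox.datasets` (e.g. `pycox.datasets.sac3`
and `pycox.datasets.sac_admin5`); the first call to `read_df()` runs the simulation with a
fixed seed and caches the result as a feather file, so later calls simply read it back.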
2 | """ 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from pycox import simulations 7 | from pycox.datasets._dataset_loader import _DatasetLoader 8 | 9 | class _SimDataset(_DatasetLoader): 10 | col_duration = 'duration' 11 | col_event = 'event' 12 | cols_true = ['duration_true', 'censoring_true'] 13 | 14 | def read_df(self, add_true=True): 15 | if not self.path.exists(): 16 | print(f"Dataset '{self.name}' not created yet. Making dataset...") 17 | self._simulate_data() 18 | print(f"Done") 19 | df = super().read_df() 20 | if add_true is False: 21 | df = self._drop_true(df) 22 | return df 23 | 24 | def _simulate_data(self): 25 | raise NotImplementedError 26 | 27 | def _download(self): 28 | raise NotImplementedError("There is no `_download` for simulated data.") 29 | 30 | def _drop_true(self, df): 31 | return df.drop(columns=self.cols_true) 32 | 33 | 34 | class _RRNLNPH(_SimDataset): 35 | """Dataset from simulation study in "Time-to-Event Prediction with Neural 36 | Networks and Cox Regression" [1]. 37 | 38 | This is a continuous-time simulation study with event times drawn from a 39 | relative risk non-linear non-proportional hazards model (RRNLNPH). 40 | The full details are given in the paper [1]. 41 | 42 | The dataset is created with `pycox.simulations.SimStudyNonLinearNonPH` (see 43 | `rr_nl_nph._simulate_data`). 44 | 45 | Variables: 46 | x0, x1, x2: 47 | numerical covariates. 48 | duration: (duration) 49 | the right-censored event-times. 50 | event: (event) 51 | event indicator {1: event, 0: censoring}. 52 | duration_true: 53 | the uncensored event times. 54 | event_true: 55 | if `duration_true` is an event. 56 | censoring_true: 57 | the censoring times. 58 | 59 | To generate more data: 60 | >>> from pycox.simulations import SimStudyNonLinearNonPH 61 | >>> n = 10000 62 | >>> sim = SimStudyNonLinearNonPH() 63 | >>> data = sim.simulate(n) 64 | >>> df = sim.dict2df(data, True) 65 | 66 | References: 67 | [1] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel. 68 | Time-to-event prediction with neural networks and Cox regression. 69 | Journal of Machine Learning Research, 20(129):1–30, 2019. 70 | http://jmlr.org/papers/v20/18-424.html 71 | """ 72 | name = 'rr_nl_nph' 73 | _checksum = '4952a8712403f7222d1bec58e36cdbfcd46aa31ddf87c5fb2c455565fc3f7068' 74 | 75 | def _simulate_data(self): 76 | np.random.seed(1234) 77 | sim = simulations.SimStudyNonLinearNonPH() 78 | data = sim.simulate(25000) 79 | df = sim.dict2df(data, True) 80 | df.to_feather(self.path) 81 | 82 | 83 | class _SAC3(_SimDataset): 84 | """Dataset from simulation study in "Continuous and Discrete-Time Survival Prediction 85 | with Neural Networks" [1]. 86 | 87 | The dataset is created with `pycox.simulations.SimStudySACConstCensor` 88 | (see `sac3._simulate_data`). 89 | 90 | The full details are given in Appendix A.1 in [1]. 91 | 92 | Variables: 93 | x0, ..., x44: 94 | numerical covariates. 95 | duration: (duration) 96 | the right-censored event-times. 97 | event: (event) 98 | event indicator {1: event, 0: censoring}. 99 | duration_true: 100 | the uncensored event times (only censored at max-time 100.) 101 | event_true: 102 | if `duration_true` is an event. 103 | censoring_true: 104 | the censoring times. 105 | 106 | To generate more data: 107 | >>> from pycox.simulations import SimStudySACCensorConst 108 | >>> n = 10000 109 | >>> sim = SimStudySACCensorConst() 110 | >>> data = sim.simulate(n) 111 | >>> df = sim.dict2df(data, True, False) 112 | 113 | References: 114 | [1] Håvard Kvamme and Ørnulf Borgan. 
Continuous and Discrete-Time Survival Prediction 115 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. 116 | https://arxiv.org/pdf/1910.06724.pdf 117 | """ 118 | name = 'sac3' 119 | _checksum = '2941d46baf0fbae949933565dc88663adbf1d8f5a58f989baf915d6586641fea' 120 | 121 | def _simulate_data(self): 122 | np.random.seed(1234) 123 | sim = simulations.SimStudySACCensorConst() 124 | data = sim.simulate(100000) 125 | df = sim.dict2df(data, True, False) 126 | df.to_feather(self.path) 127 | 128 | 129 | class _SACAdmin5(_SimDataset): 130 | """Dataset from simulation study in [1]. 131 | The survival function is the same as in sac3, but the censoring is administrative 132 | and determined by five covariates. 133 | 134 | Variables: 135 | x0, ..., x22: 136 | numerical covariates. 137 | duration: (duration) 138 | the right-censored event-times. 139 | event: (event) 140 | event indicator {1: event, 0: censoring}. 141 | duration_true: 142 | the uncensored event times (only censored at max-time 100.) 143 | event_true: 144 | if `duration_true` is an event or right-censored at time 100. 145 | censoring_true: 146 | the censoring times. 147 | 148 | To generate more data: 149 | >>> from pycox.simulations import SimStudySACAdmin 150 | >>> n = 10000 151 | >>> sim = SimStudySACAdmin() 152 | >>> data = sim.simulate(n) 153 | >>> df = sim.dict2df(data, True, True) 154 | 155 | References: 156 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems 157 | and Solutions. arXiv preprint arXiv:1912.08581, 2019. 158 | https://arxiv.org/pdf/1912.08581.pdf 159 | """ 160 | name = 'sac_admin5' 161 | _checksum = '9882bc8651315bcd80cba20b5f11040d71e4a84865898d7c2ca7b82ccba56683' 162 | 163 | def _simulate_data(self): 164 | np.random.seed(1234) 165 | sim = simulations.SimStudySACAdmin() 166 | data = sim.simulate(50000) 167 | df = sim.dict2df(data, True, True) 168 | df.to_feather(self.path) 169 | -------------------------------------------------------------------------------- /pycox/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from pycox.evaluation.eval_surv import EvalSurv 3 | # from pycox.evaluation import binomial_log_likelihood, brier_score,\ 4 | # integrated_binomial_log_likelihood, integrated_brier_score 5 | # from pycox.evaluation.concordance import concordance_td 6 | # from pycox.evaluation.km_inverce_censor_weight import binomial_log_likelihood_km, brier_score_km,\ 7 | # integrated_binomial_log_likelihood_km_numpy, integrated_brier_score_km_numpy 8 | -------------------------------------------------------------------------------- /pycox/evaluation/admin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import numba 4 | from pycox.utils import idx_at_times 5 | 6 | 7 | def administrative_scores(func): 8 | if not func.__class__.__module__.startswith('numba'): 9 | raise ValueError("Need to provide numba compiled function") 10 | def metric(time_grid, durations, durations_c, events, surv, index_surv, reduce=True, steps_surv='post'): 11 | if not hasattr(time_grid, '__iter__'): 12 | time_grid = np.array([time_grid]) 13 | assert (type(time_grid) is type(durations) is type(events) is type(surv) is 14 | type(index_surv) is type(durations_c) is np.ndarray), 'Need all input to be np.ndarrays' 15 | assert (durations[events == 0] == durations_c[events == 0]).all(), 'Censored observations need same `durations` and `durations_c`' 16 | assert 
(durations[events == 1] <= durations_c[events == 1]).all(), '`durations` cannot be larger than `durations_c`' 17 | idx_ts_surv = idx_at_times(index_surv, time_grid, steps_surv, assert_sorted=True) 18 | scores, norm = _admin_scores(func, time_grid, durations, durations_c, events, surv, idx_ts_surv) 19 | if reduce is True: 20 | return scores.sum(axis=1) / norm 21 | return scores, norm.reshape(-1, 1) 22 | return metric 23 | 24 | @numba.njit(parallel=True) 25 | def _admin_scores(func, time_grid, durations, durations_c, events, surv, idx_ts_surv): 26 | def _single(func, ts, durations, durations_c, events, surv, idx_ts_surv_i, 27 | scores, n_indiv): 28 | for i in range(n_indiv): 29 | tt = durations[i] 30 | tc = durations_c[i] 31 | d = events[i] 32 | s = surv[idx_ts_surv_i, i] 33 | scores[i] = func(ts, tt, tc, d, s) 34 | 35 | n_times = len(time_grid) 36 | n_indiv = len(durations) 37 | scores = np.empty((n_times, n_indiv)) 38 | scores.fill(np.nan) 39 | normalizer = np.empty(n_times) 40 | normalizer.fill(np.nan) 41 | for i in numba.prange(n_times): 42 | ts = time_grid[i] 43 | idx_ts_surv_i = idx_ts_surv[i] 44 | scores_i = scores[i] 45 | normalizer[i] = (durations_c >= ts).sum() 46 | _single(func, ts, durations, durations_c, events, surv, idx_ts_surv_i, scores_i, n_indiv) 47 | return scores, normalizer 48 | 49 | @numba.njit 50 | def _brier_score(ts, tt, tc, d, s): 51 | if (tt <= ts) and (d == 1) and (tc >= ts): 52 | return np.power(s, 2) 53 | if tt >= ts: 54 | return np.power(1 - s, 2) 55 | return 0. 56 | 57 | @numba.njit 58 | def _binomial_log_likelihood(ts, tt, tc, d, s, eps=1e-7): 59 | if s < eps: 60 | s = eps 61 | elif s > (1 - eps): 62 | s = 1 - eps 63 | if (tt <= ts) and (d == 1) and (tc >= ts): 64 | return np.log(1 - s) 65 | if tt >= ts: 66 | return np.log(s) 67 | return 0. 68 | 69 | brier_score = administrative_scores(_brier_score) 70 | binomial_log_likelihood = administrative_scores(_binomial_log_likelihood) 71 | 72 | 73 | def _integrated_admin_metric(func): 74 | def metric(time_grid, durations, durations_c, events, surv, index_surv, steps_surv='post'): 75 | scores = func(time_grid, durations, durations_c, events, surv, index_surv, True, steps_surv) 76 | integral = scipy.integrate.simps(scores, time_grid) 77 | return integral / (time_grid[-1] - time_grid[0]) 78 | return metric 79 | 80 | integrated_brier_score = _integrated_admin_metric(brier_score) 81 | integrated_binomial_log_likelihood = _integrated_admin_metric(binomial_log_likelihood) 82 | 83 | -------------------------------------------------------------------------------- /pycox/evaluation/concordance.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pandas as pd 4 | import numba 5 | 6 | 7 | @numba.jit(nopython=True) 8 | def _is_comparable(t_i, t_j, d_i, d_j): 9 | return ((t_i < t_j) & d_i) | ((t_i == t_j) & (d_i | d_j)) 10 | 11 | @numba.jit(nopython=True) 12 | def _is_comparable_antolini(t_i, t_j, d_i, d_j): 13 | return ((t_i < t_j) & d_i) | ((t_i == t_j) & d_i & (d_j == 0)) 14 | 15 | @numba.jit(nopython=True) 16 | def _is_concordant(s_i, s_j, t_i, t_j, d_i, d_j): 17 | conc = 0. 18 | if t_i < t_j: 19 | conc = (s_i < s_j) + (s_i == s_j) * 0.5 20 | elif t_i == t_j: 21 | if d_i & d_j: 22 | conc = 1. - (s_i != s_j) * 0.5 23 | elif d_i: 24 | conc = (s_i < s_j) + (s_i == s_j) * 0.5 # different from RSF paper. 25 | elif d_j: 26 | conc = (s_i > s_j) + (s_i == s_j) * 0.5 # different from RSF paper. 
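    # Editor's note (worked example, not in the original source): if individual i
    # fails first (t_i=2 < t_j=5, d_i=1) the pair is comparable, and it counts as
    # fully concordant (conc = 1) when the predicted survival is lower for i
    # (s_i=0.3 < s_j=0.7); tied predictions contribute 0.5 and s_i > s_j contributes 0.
    # The multiplication below zeroes out pairs that are not comparable at all.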
27 | return conc * _is_comparable(t_i, t_j, d_i, d_j) 28 | 29 | @numba.jit(nopython=True) 30 | def _is_concordant_antolini(s_i, s_j, t_i, t_j, d_i, d_j): 31 | return (s_i < s_j) & _is_comparable_antolini(t_i, t_j, d_i, d_j) 32 | 33 | @numba.jit(nopython=True, parallel=True) 34 | def _sum_comparable(t, d, is_comparable_func): 35 | n = t.shape[0] 36 | count = 0. 37 | for i in numba.prange(n): 38 | for j in range(n): 39 | if j != i: 40 | count += is_comparable_func(t[i], t[j], d[i], d[j]) 41 | return count 42 | 43 | @numba.jit(nopython=True, parallel=True) 44 | def _sum_concordant(s, t, d): 45 | n = len(t) 46 | count = 0. 47 | for i in numba.prange(n): 48 | for j in range(n): 49 | if j != i: 50 | count += _is_concordant(s[i, i], s[i, j], t[i], t[j], d[i], d[j]) 51 | return count 52 | 53 | @numba.jit(nopython=True, parallel=True) 54 | def _sum_concordant_disc(s, t, d, s_idx, is_concordant_func): 55 | n = len(t) 56 | count = 0 57 | for i in numba.prange(n): 58 | idx = s_idx[i] 59 | for j in range(n): 60 | if j != i: 61 | count += is_concordant_func(s[idx, i], s[idx, j], t[i], t[j], d[i], d[j]) 62 | return count 63 | 64 | def concordance_td(durations, events, surv, surv_idx, method='adj_antolini'): 65 | """Time dependent concorance index from 66 | Antolini, L.; Boracchi, P.; and Biganzoli, E. 2005. A timedependent discrimination 67 | index for survival data. Statistics in Medicine 24:3927–3944. 68 | 69 | If 'method' is 'antolini', the concordance from Antolini et al. is computed. 70 | 71 | If 'method' is 'adj_antolini' (default) we have made a small modifications 72 | for ties in predictions and event times. 73 | We have followed step 3. in Sec 5.1. in Random Survial Forests paper, except for the last 74 | point with "T_i = T_j, but not both are deaths", as that doesn't make much sense. 75 | See '_is_concordant'. 76 | 77 | Arguments: 78 | durations {np.array[n]} -- Event times (or censoring times.) 79 | events {np.array[n]} -- Event indicators (0 is censoring). 80 | surv {np.array[n_times, n]} -- Survival function (each row is a duraratoin, and each col 81 | is an individual). 82 | surv_idx {np.array[n_test]} -- Mapping of survival_func s.t. 'surv_idx[i]' gives index in 83 | 'surv' corresponding to the event time of individual 'i'. 84 | 85 | Keyword Arguments: 86 | method {str} -- Type of c-index 'antolini' or 'adj_antolini' (default {'adj_antolini'}). 87 | 88 | Returns: 89 | float -- Time dependent concordance index. 90 | """ 91 | if np.isfortran(surv): 92 | surv = np.array(surv, order='C') 93 | assert durations.shape[0] == surv.shape[1] == surv_idx.shape[0] == events.shape[0] 94 | assert type(durations) is type(events) is type(surv) is type(surv_idx) is np.ndarray 95 | if events.dtype in ('float', 'float32'): 96 | events = events.astype('int32') 97 | if method == 'adj_antolini': 98 | is_concordant = _is_concordant 99 | is_comparable = _is_comparable 100 | return (_sum_concordant_disc(surv, durations, events, surv_idx, is_concordant) / 101 | _sum_comparable(durations, events, is_comparable)) 102 | elif method == 'antolini': 103 | is_concordant = _is_concordant_antolini 104 | is_comparable = _is_comparable_antolini 105 | return (_sum_concordant_disc(surv, durations, events, surv_idx, is_concordant) / 106 | _sum_comparable(durations, events, is_comparable)) 107 | return ValueError(f"Need 'method' to be e.g. 
'antolini', got '{method}'.") 108 | 109 | -------------------------------------------------------------------------------- /pycox/evaluation/eval_surv.py: -------------------------------------------------------------------------------- 1 | 2 | import warnings 3 | import numpy as np 4 | import pandas as pd 5 | from pycox.evaluation.concordance import concordance_td 6 | from pycox.evaluation import ipcw, admin 7 | from pycox import utils 8 | 9 | 10 | class EvalSurv: 11 | """Class for evaluating predictions. 12 | 13 | Arguments: 14 | surv {pd.DataFrame} -- Survival predictions. 15 | durations {np.array} -- Durations of test set. 16 | events {np.array} -- Events of test set. 17 | 18 | Keyword Arguments: 19 | censor_surv {str, pd.DataFrame, EvalSurv} -- Censoring distribution. 20 | If provided data frame (survival function for censoring) or EvalSurv object, 21 | this will be used. 22 | If 'km', we will fit a Kaplan-Meier to the dataset. 23 | (default: {None}) 24 | censor_durations {np.array}: -- Administrative censoring times. (default: {None}) 25 | steps {str} -- For durations between values of `surv.index` choose the higher index 'pre' 26 | or lower index 'post'. For a visualization see `help(EvalSurv.steps)`. (default: {'post'}) 27 | """ 28 | def __init__(self, surv, durations, events, censor_surv=None, censor_durations=None, steps='post'): 29 | assert (type(durations) == type(events) == np.ndarray), 'Need `durations` and `events` to be arrays' 30 | self.surv = surv 31 | self.durations = durations 32 | self.events = events 33 | self.censor_surv = censor_surv 34 | self.censor_durations = censor_durations 35 | self.steps = steps 36 | assert pd.Series(self.index_surv).is_monotonic_increasing 37 | 38 | @property 39 | def censor_surv(self): 40 | """Estimated survival for censorings. 41 | Also an EvalSurv object. 42 | """ 43 | return self._censor_surv 44 | 45 | @censor_surv.setter 46 | def censor_surv(self, censor_surv): 47 | if isinstance(censor_surv, EvalSurv): 48 | self._censor_surv = censor_surv 49 | elif type(censor_surv) is str: 50 | if censor_surv == 'km': 51 | self.add_km_censor() 52 | else: 53 | raise ValueError(f"censor_surv cannot be {censor_surv}. Use e.g. 'km'") 54 | elif censor_surv is not None: 55 | self.add_censor_est(censor_surv) 56 | else: 57 | self._censor_surv = None 58 | 59 | @property 60 | def index_surv(self): 61 | return self.surv.index.values 62 | 63 | @property 64 | def steps(self): 65 | """How to handle predictions that are between two indexes in `index_surv`. 66 | 67 | For a visualization, run the following: 68 | ev = EvalSurv(pd.DataFrame(np.linspace(1, 0, 7)), np.empty(7), np.ones(7), steps='pre') 69 | ax = ev[0].plot_surv() 70 | ev.steps = 'post' 71 | ev[0].plot_surv(ax=ax, style='--') 72 | ax.legend(['pre', 'post']) 73 | """ 74 | return self._steps 75 | 76 | @steps.setter 77 | def steps(self, steps): 78 | vals = ['post', 'pre'] 79 | if steps not in vals: 80 | raise ValueError(f"`steps` needs to be {vals}, got {steps}") 81 | self._steps = steps 82 | 83 | def add_censor_est(self, censor_surv, steps='post'): 84 | """Add censoring estimates so one can use inverse censoring weighting. 85 | `censor_surv` are the survival estimates trained on (durations, 1-events), 86 | 87 | Arguments: 88 | censor_surv {pd.DataFrame} -- Censor survival curves. 89 | 90 | Keyword Arguments: 91 | round {str} -- For durations between values of `surv.index` choose the higher index 'pre' 92 | or lower index 'post'. 
If `None` use `self.steps` (default: {None}) 93 | """ 94 | if not isinstance(censor_surv, EvalSurv): 95 | censor_surv = self._constructor(censor_surv, self.durations, 1-self.events, None, 96 | steps=steps) 97 | self.censor_surv = censor_surv 98 | return self 99 | 100 | def add_km_censor(self, steps='post'): 101 | """Add censoring estimates obtained by Kaplan-Meier on the test set 102 | (durations, 1-events). 103 | """ 104 | km = utils.kaplan_meier(self.durations, 1-self.events) 105 | surv = pd.DataFrame(np.repeat(km.values.reshape(-1, 1), len(self.durations), axis=1), 106 | index=km.index) 107 | return self.add_censor_est(surv, steps) 108 | 109 | @property 110 | def censor_durations(self): 111 | """Administrative censoring times.""" 112 | return self._censor_durations 113 | 114 | @censor_durations.setter 115 | def censor_durations(self, val): 116 | if val is not None: 117 | assert (self.durations[self.events == 0] == val[self.events == 0]).all(),\ 118 | 'Censored observations need same `durations` and `censor_durations`' 119 | assert (self.durations[self.events == 1] <= val[self.events == 1]).all(),\ 120 | '`durations` cannot be larger than `censor_durations`' 121 | if (self.durations == val).all(): 122 | warnings.warn("`censor_durations` are equal to `durations`." + 123 | " `censor_durations` are likely wrong!") 124 | self._censor_durations = val 125 | else: 126 | self._censor_durations = val 127 | 128 | @property 129 | def _constructor(self): 130 | return EvalSurv 131 | 132 | def __getitem__(self, index): 133 | if not (hasattr(index, '__iter__') or type(index) is slice) : 134 | index = [index] 135 | surv = self.surv.iloc[:, index] 136 | durations = self.durations[index] 137 | events = self.events[index] 138 | new = self._constructor(surv, durations, events, None, steps=self.steps) 139 | if self.censor_surv is not None: 140 | new.censor_surv = self.censor_surv[index] 141 | return new 142 | 143 | def plot_surv(self, **kwargs): 144 | """Plot survival estimates. 145 | kwargs are passed to `self.surv.plot`. 146 | """ 147 | if len(self.durations) > 50: 148 | raise RuntimeError("We don't allow to plot more than 50 lines. Use e.g. `ev[1:5].plot()`") 149 | if 'drawstyle' in kwargs: 150 | raise RuntimeError(f"`drawstyle` is set by `self.steps`. Remove from **kwargs") 151 | return self.surv.plot(drawstyle=f"steps-{self.steps}", **kwargs) 152 | 153 | def idx_at_times(self, times): 154 | """Get the index (iloc) of the `surv.index` closest to `times`. 155 | I.e. surv.loc[tims] (almost)= surv.iloc[idx_at_times(times)]. 156 | 157 | Useful for finding predictions at given durations. 158 | """ 159 | return utils.idx_at_times(self.index_surv, times, self.steps) 160 | 161 | def _duration_idx(self): 162 | return self.idx_at_times(self.durations) 163 | 164 | def surv_at_times(self, times): 165 | idx = self.idx_at_times(times) 166 | return self.surv.iloc[idx] 167 | 168 | # def prob_alive(self, time_grid): 169 | # return self.surv_at_times(time_grid).values 170 | 171 | def concordance_td(self, method='adj_antolini'): 172 | """Time dependent concorance index from 173 | Antolini, L.; Boracchi, P.; and Biganzoli, E. 2005. A time-dependent discrimination 174 | index for survival data. Statistics in Medicine 24:3927–3944. 175 | 176 | If 'method' is 'antolini', the concordance from Antolini et al. is computed. 177 | 178 | If 'method' is 'adj_antolini' (default) we have made a small modifications 179 | for ties in predictions and event times. 180 | We have followed step 3. in Sec 5.1. 
in Random Survival Forests paper, except for the last 181 | point with "T_i = T_j, but not both are deaths", as that doesn't make much sense. 182 | See 'metrics._is_concordant'. 183 | 184 | Keyword Arguments: 185 | method {str} -- Type of c-index 'antolini' or 'adj_antolini' (default {'adj_antolini'}). 186 | 187 | Returns: 188 | float -- Time dependent concordance index. 189 | """ 190 | return concordance_td(self.durations, self.events, self.surv.values, 191 | self._duration_idx(), method) 192 | 193 | def brier_score(self, time_grid, max_weight=np.inf): 194 | """Brier score weighted by the inverse censoring distribution. 195 | See Section 3.1.2 or [1] for details of the wighting scheme. 196 | 197 | Arguments: 198 | time_grid {np.array} -- Durations where the brier score should be calculated. 199 | 200 | Keyword Arguments: 201 | max_weight {float} -- Max weight value (max number of individuals an individual 202 | can represent (default {np.inf}). 203 | 204 | References: 205 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems 206 | and Solutions. arXiv preprint arXiv:1912.08581, 2019. 207 | https://arxiv.org/pdf/1912.08581.pdf 208 | """ 209 | if self.censor_surv is None: 210 | raise ValueError("""Need to add censor_surv to compute Brier score. Use 'add_censor_est' 211 | or 'add_km_censor' for Kaplan-Meier""") 212 | bs = ipcw.brier_score(time_grid, self.durations, self.events, self.surv.values, 213 | self.censor_surv.surv.values, self.index_surv, 214 | self.censor_surv.index_surv, max_weight, True, self.steps, 215 | self.censor_surv.steps) 216 | return pd.Series(bs, index=time_grid).rename('brier_score') 217 | 218 | def nbll(self, time_grid, max_weight=np.inf): 219 | """Negative binomial log-likelihood weighted by the inverse censoring distribution. 220 | See Section 3.1.2 or [1] for details of the wighting scheme. 221 | 222 | Arguments: 223 | time_grid {np.array} -- Durations where the brier score should be calculated. 224 | 225 | Keyword Arguments: 226 | max_weight {float} -- Max weight value (max number of individuals an individual 227 | can represent (default {np.inf}). 228 | 229 | References: 230 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems 231 | and Solutions. arXiv preprint arXiv:1912.08581, 2019. 232 | https://arxiv.org/pdf/1912.08581.pdf 233 | """ 234 | if self.censor_surv is None: 235 | raise ValueError("""Need to add censor_surv to compute the score. Use 'add_censor_est' 236 | or 'add_km_censor' for Kaplan-Meier""") 237 | bll = ipcw.binomial_log_likelihood(time_grid, self.durations, self.events, self.surv.values, 238 | self.censor_surv.surv.values, self.index_surv, 239 | self.censor_surv.index_surv, max_weight, True, self.steps, 240 | self.censor_surv.steps) 241 | return pd.Series(-bll, index=time_grid).rename('nbll') 242 | 243 | def integrated_brier_score(self, time_grid, max_weight=np.inf): 244 | """Integrated Brier score weighted by the inverse censoring distribution. 245 | Essentially an integral over values obtained from `brier_score(time_grid, max_weight)`. 246 | 247 | Arguments: 248 | time_grid {np.array} -- Durations where the brier score should be calculated. 249 | 250 | Keyword Arguments: 251 | max_weight {float} -- Max weight value (max number of individuals an individual 252 | can represent (default {np.inf}). 253 | """ 254 | if self.censor_surv is None: 255 | raise ValueError("Need to add censor_surv to compute briser score. 
Use 'add_censor_est'") 256 | return ipcw.integrated_brier_score(time_grid, self.durations, self.events, self.surv.values, 257 | self.censor_surv.surv.values, self.index_surv, 258 | self.censor_surv.index_surv, max_weight, self.steps, 259 | self.censor_surv.steps) 260 | 261 | def integrated_nbll(self, time_grid, max_weight=np.inf): 262 | """Integrated negative binomial log-likelihood weighted by the inverse censoring distribution. 263 | Essentially an integral over values obtained from `nbll(time_grid, max_weight)`. 264 | 265 | Arguments: 266 | time_grid {np.array} -- Durations where the brier score should be calculated. 267 | 268 | Keyword Arguments: 269 | max_weight {float} -- Max weight value (max number of individuals an individual 270 | can represent (default {np.inf}). 271 | """ 272 | if self.censor_surv is None: 273 | raise ValueError("Need to add censor_surv to compute the score. Use 'add_censor_est'") 274 | ibll = ipcw.integrated_binomial_log_likelihood(time_grid, self.durations, self.events, self.surv.values, 275 | self.censor_surv.surv.values, self.index_surv, 276 | self.censor_surv.index_surv, max_weight, self.steps, 277 | self.censor_surv.steps) 278 | return -ibll 279 | 280 | def brier_score_admin(self, time_grid): 281 | """The Administrative Brier score proposed by [1]. 282 | Removes individuals as they are administratively censored, event if they have experienced an 283 | event. 284 | 285 | Arguments: 286 | time_grid {np.array} -- Durations where the brier score should be calculated. 287 | 288 | References: 289 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems 290 | and Solutions. arXiv preprint arXiv:1912.08581, 2019. 291 | https://arxiv.org/pdf/1912.08581.pdf 292 | """ 293 | if self.censor_durations is None: 294 | raise ValueError("Need to provide `censor_durations` (censoring durations) to use this method") 295 | bs = admin.brier_score(time_grid, self.durations, self.censor_durations, self.events, 296 | self.surv.values, self.index_surv, True, self.steps) 297 | return pd.Series(bs, index=time_grid).rename('brier_score') 298 | 299 | def integrated_brier_score_admin(self, time_grid): 300 | """The Integrated administrative Brier score proposed by [1]. 301 | Removes individuals as they are administratively censored, event if they have experienced an 302 | event. 303 | 304 | Arguments: 305 | time_grid {np.array} -- Durations where the brier score should be calculated. 306 | 307 | References: 308 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems 309 | and Solutions. arXiv preprint arXiv:1912.08581, 2019. 310 | https://arxiv.org/pdf/1912.08581.pdf 311 | """ 312 | if self.censor_durations is None: 313 | raise ValueError("Need to provide `censor_durations` (censoring durations) to use this method") 314 | ibs = admin.integrated_brier_score(time_grid, self.durations, self.censor_durations, self.events, 315 | self.surv.values, self.index_surv, self.steps) 316 | return ibs 317 | 318 | def nbll_admin(self, time_grid): 319 | """The negative administrative binomial log-likelihood proposed by [1]. 320 | Removes individuals as they are administratively censored, event if they have experienced an 321 | event. 322 | 323 | Arguments: 324 | time_grid {np.array} -- Durations where the brier score should be calculated. 325 | 326 | References: 327 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems 328 | and Solutions. arXiv preprint arXiv:1912.08581, 2019. 
329 | https://arxiv.org/pdf/1912.08581.pdf 330 | """ 331 | if self.censor_durations is None: 332 | raise ValueError("Need to provide `censor_durations` (censoring durations) to use this method") 333 | bll = admin.binomial_log_likelihood(time_grid, self.durations, self.censor_durations, self.events, 334 | self.surv.values, self.index_surv, True, self.steps) 335 | return pd.Series(-bll, index=time_grid).rename('nbll') 336 | 337 | def integrated_nbll_admin(self, time_grid): 338 | """The Integrated negative administrative binomial log-likelihood score proposed by [1]. 339 | Removes individuals as they are administratively censored, event if they have experienced an 340 | event. 341 | 342 | Arguments: 343 | time_grid {np.array} -- Durations where the brier score should be calculated. 344 | 345 | References: 346 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems 347 | and Solutions. arXiv preprint arXiv:1912.08581, 2019. 348 | https://arxiv.org/pdf/1912.08581.pdf 349 | """ 350 | if self.censor_durations is None: 351 | raise ValueError("Need to provide `censor_durations` (censoring durations) to use this method") 352 | ibll = admin.integrated_binomial_log_likelihood(time_grid, self.durations, self.censor_durations, 353 | self.events, self.surv.values, self.index_surv, 354 | self.steps) 355 | return -ibll 356 | -------------------------------------------------------------------------------- /pycox/evaluation/ipcw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy 3 | import numba 4 | from pycox import utils 5 | 6 | @numba.njit(parallel=True) 7 | def _inv_cens_scores(func, time_grid, durations, events, surv, censor_surv, idx_ts_surv, idx_ts_censor, 8 | idx_tt_censor, scores, weights, n_times, n_indiv, max_weight): 9 | def _inv_cens_score_single(func, ts, durations, events, surv, censor_surv, idx_ts_surv_i, 10 | idx_ts_censor_i, idx_tt_censor, scores, weights, n_indiv, max_weight): 11 | min_g = 1./max_weight 12 | for i in range(n_indiv): 13 | tt = durations[i] 14 | d = events[i] 15 | s = surv[idx_ts_surv_i, i] 16 | g_ts = censor_surv[idx_ts_censor_i, i] 17 | g_tt = censor_surv[idx_tt_censor[i], i] 18 | g_ts = max(g_ts, min_g) 19 | g_tt = max(g_tt, min_g) 20 | score, w = func(ts, tt, s, g_ts, g_tt, d) 21 | #w = min(w, max_weight) 22 | scores[i] = score * w 23 | weights[i] = w 24 | 25 | for i in numba.prange(n_times): 26 | ts = time_grid[i] 27 | idx_ts_surv_i = idx_ts_surv[i] 28 | idx_ts_censor_i = idx_ts_censor[i] 29 | scores_i = scores[i] 30 | weights_i = weights[i] 31 | _inv_cens_score_single(func, ts, durations, events, surv, censor_surv, idx_ts_surv_i, 32 | idx_ts_censor_i, idx_tt_censor, scores_i, weights_i, n_indiv, max_weight) 33 | 34 | def _inverse_censoring_weighted_metric(func): 35 | if not func.__class__.__module__.startswith('numba'): 36 | raise ValueError("Need to provide numba compiled function") 37 | def metric(time_grid, durations, events, surv, censor_surv, index_surv, index_censor, max_weight=np.inf, 38 | reduce=True, steps_surv='post', steps_censor='post'): 39 | if not hasattr(time_grid, '__iter__'): 40 | time_grid = np.array([time_grid]) 41 | assert (type(time_grid) is type(durations) is type(events) is type(surv) is type(censor_surv) is 42 | type(index_surv) is type(index_censor) is np.ndarray), 'Need all input to be np.ndarrays' 43 | n_times = len(time_grid) 44 | n_indiv = len(durations) 45 | scores = np.zeros((n_times, n_indiv)) 46 | weights = 
np.zeros((n_times, n_indiv)) 47 | idx_ts_surv = utils.idx_at_times(index_surv, time_grid, steps_surv, assert_sorted=True) 48 | idx_ts_censor = utils.idx_at_times(index_censor, time_grid, steps_censor, assert_sorted=True) 49 | idx_tt_censor = utils.idx_at_times(index_censor, durations, 'pre', assert_sorted=True) 50 | if steps_censor == 'post': 51 | idx_tt_censor = (idx_tt_censor - 1).clip(0) 52 | # This ensures that we get G(tt-) 53 | _inv_cens_scores(func, time_grid, durations, events, surv, censor_surv, idx_ts_surv, idx_ts_censor, 54 | idx_tt_censor, scores, weights, n_times, n_indiv, max_weight) 55 | if reduce is True: 56 | return np.sum(scores, axis=1) / np.sum(weights, axis=1) 57 | return scores, weights 58 | return metric 59 | 60 | @numba.njit() 61 | def _brier_score(ts, tt, s, g_ts, g_tt, d): 62 | if (tt <= ts) and d == 1: 63 | return np.power(s, 2), 1./g_tt 64 | if tt > ts: 65 | return np.power(1 - s, 2), 1./g_ts 66 | return 0., 0. 67 | 68 | @numba.njit() 69 | def _binomial_log_likelihood(ts, tt, s, g_ts, g_tt, d, eps=1e-7): 70 | s = eps if s < eps else s 71 | s = (1-eps) if s > (1 - eps) else s 72 | if (tt <= ts) and d == 1: 73 | return np.log(1 - s), 1./g_tt 74 | if tt > ts: 75 | return np.log(s), 1./g_ts 76 | return 0., 0. 77 | 78 | brier_score = _inverse_censoring_weighted_metric(_brier_score) 79 | binomial_log_likelihood = _inverse_censoring_weighted_metric(_binomial_log_likelihood) 80 | 81 | def _integrated_inverce_censoring_weighed_metric(func): 82 | def metric(time_grid, durations, events, surv, censor_surv, index_surv, index_censor, 83 | max_weight=np.inf, steps_surv='post', steps_censor='post'): 84 | scores = func(time_grid, durations, events, surv, censor_surv, index_surv, index_censor, 85 | max_weight, True, steps_surv, steps_censor) 86 | integral = scipy.integrate.simps(scores, time_grid) 87 | return integral / (time_grid[-1] - time_grid[0]) 88 | return metric 89 | 90 | integrated_brier_score = _integrated_inverce_censoring_weighed_metric(brier_score) 91 | integrated_binomial_log_likelihood = _integrated_inverce_censoring_weighed_metric(binomial_log_likelihood) 92 | -------------------------------------------------------------------------------- /pycox/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Some relevant metrics 3 | ''' 4 | import numpy as np 5 | import pandas as pd 6 | 7 | def partial_log_likelihood_ph(log_partial_hazards, durations, events, mean=True): 8 | """Partial log-likelihood for PH models. 9 | 10 | Arguments: 11 | log_partial_hazards {np.array} -- Log partial hazards (e.g. x^T beta). 12 | durations {np.array} -- Durations. 13 | events {np.array} -- Events. 14 | 15 | Keyword Arguments: 16 | mean {bool} -- Return the mean. (default: {True}) 17 | 18 | Returns: 19 | pd.Series or float -- partial log-likelihood or mean. 
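    Example (editor's sketch, not part of the original docstring; the log partial
    hazards would normally come from a fitted proportional-hazards model):
        >>> import numpy as np
        >>> lph = np.array([0.5, -0.2, 0.1])
        >>> durations = np.array([2., 4., 6.])
        >>> events = np.array([1, 0, 1])
        >>> partial_log_likelihood_ph(lph, durations, events)              # mean
        >>> partial_log_likelihood_ph(lph, durations, events, mean=False)  # per-event pd.Series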
20 | """ 21 | 22 | df = pd.DataFrame(dict(duration=durations, event=events, lph=log_partial_hazards)) 23 | pll = (df 24 | .sort_values('duration', ascending=False) 25 | .assign(cum_ph=(lambda x: x['lph'] 26 | .pipe(np.exp) 27 | .cumsum() 28 | .groupby(x['duration']) 29 | .transform('max'))) 30 | .loc[lambda x: x['event'] == 1] 31 | .assign(pll=lambda x: x['lph'] - np.log(x['cum_ph'])) 32 | ['pll']) 33 | if mean: 34 | return pll.mean() 35 | return pll 36 | -------------------------------------------------------------------------------- /pycox/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from pycox.models import base, loss, utils, pmf, data 3 | from pycox.models.cox import CoxPH 4 | from pycox.models.cox_cc import CoxCC 5 | from pycox.models.cox_time import CoxTime 6 | from pycox.models.deephit import DeepHitSingle, DeepHit 7 | from pycox.models.pmf import PMF 8 | from pycox.models.logistic_hazard import LogisticHazard 9 | from pycox.models.pc_hazard import PCHazard 10 | from pycox.models.mtlr import MTLR 11 | from pycox.models.bce_surv import BCESurv 12 | -------------------------------------------------------------------------------- /pycox/models/base.py: -------------------------------------------------------------------------------- 1 | import torchtuples as tt 2 | import warnings 3 | 4 | 5 | class SurvBase(tt.Model): 6 | """Base class for survival models. 7 | Essentially same as torchtuples.Model, 8 | """ 9 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, 10 | to_cpu=False, num_workers=0): 11 | """Predict the survival function for `input`. 12 | See `prediction_surv_df` to return a DataFrame instead. 13 | 14 | Arguments: 15 | input {dataloader, tuple, np.ndarray, or torch.tensor} -- Input to net. 16 | 17 | Keyword Arguments: 18 | batch_size {int} -- Batch size (default: {8224}) 19 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 20 | (default: {None}) 21 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 22 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 23 | (default: {False}) 24 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 25 | 26 | Returns: 27 | [TupleTree, np.ndarray or tensor] -- Predictions 28 | """ 29 | raise NotImplementedError 30 | 31 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0): 32 | """Predict the survival function for `input` and return as a pandas DataFrame. 33 | See `predict_surv` to return tensor or np.array instead. 34 | 35 | Arguments: 36 | input {dataloader, tuple, np.ndarray, or torch.tensor} -- Input to net. 37 | 38 | Keyword Arguments: 39 | batch_size {int} -- Batch size (default: {8224}) 40 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 41 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 42 | 43 | Returns: 44 | pd.DataFrame -- Predictions 45 | """ 46 | raise NotImplementedError 47 | 48 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, 49 | num_workers=0): 50 | """Predict the hazard function for `input`. 51 | 52 | Arguments: 53 | input {dataloader, tuple, np.ndarray, or torch.tensor} -- Input to net. 
54 | 55 | Keyword Arguments: 56 | batch_size {int} -- Batch size (default: {8224}) 57 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 58 | (default: {None}) 59 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 60 | grads {bool} -- If gradients should be computed (default: {False}) 61 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 62 | (default: {False}) 63 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 64 | 65 | Returns: 66 | [np.ndarray or tensor] -- Predicted hazards 67 | """ 68 | raise NotImplementedError 69 | 70 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, 71 | num_workers=0): 72 | """Predict the probability mass function (PMF) for `input`. 73 | 74 | Arguments: 75 | input {dataloader, tuple, np.ndarray, or torch.tensor} -- Input to net. 76 | 77 | Keyword Arguments: 78 | batch_size {int} -- Batch size (default: {8224}) 79 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 80 | (default: {None}) 81 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 82 | grads {bool} -- If gradients should be computed (default: {False}) 83 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 84 | (default: {False}) 85 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 86 | 87 | Returns: 88 | [np.ndarray or tensor] -- Predictions 89 | """ 90 | raise NotImplementedError 91 | 92 | 93 | class _SurvModelBase(tt.Model): 94 | """Base class for survival models. 95 | Essentially same as torchtuples.Model, 96 | """ 97 | def __init__(self, net, loss=None, optimizer=None, device=None): 98 | warnings.warn('Will be removed shortly', DeprecationWarning) 99 | super().__init__(net, loss, optimizer, device) 100 | 101 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, 102 | to_cpu=False, num_workers=0): 103 | """Predict the survival function for `input`. 104 | See `prediction_surv_df` to return a DataFrame instead. 105 | 106 | Arguments: 107 | input {tuple, np.ndarray, or torch.tensor} -- Input to net. 108 | 109 | Keyword Arguments: 110 | batch_size {int} -- Batch size (default: {8224}) 111 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 112 | (default: {None}) 113 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 114 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 115 | (default: {False}) 116 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 117 | 118 | Returns: 119 | [TupleTree, np.ndarray or tensor] -- Predictions 120 | """ 121 | raise NotImplementedError 122 | 123 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0): 124 | """Predict the survival function for `input` and return as a pandas DataFrame. 125 | See `predict_surv` to return tensor or np.array instead. 126 | 127 | Arguments: 128 | input {tuple, np.ndarray, or torch.tensor} -- Input to net. 129 | 130 | Keyword Arguments: 131 | batch_size {int} -- Batch size (default: {8224}) 132 | eval_ {bool} -- If 'True', use 'eval' mode on net. 
(default: {True}) 133 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 134 | 135 | Returns: 136 | pd.DataFrame -- Predictions 137 | """ 138 | raise NotImplementedError 139 | -------------------------------------------------------------------------------- /pycox/models/bce_surv.py: -------------------------------------------------------------------------------- 1 | """Estimate survival curve with binomial log-likelihood. 2 | 3 | This method is not smart to use!!!!!!! 4 | """ 5 | import pandas as pd 6 | import torch 7 | from pycox import models 8 | from pycox.preprocessing import label_transforms 9 | 10 | 11 | class BCESurv(models.base.SurvBase): 12 | """ 13 | The BCESurv method is a discrete-time survival model that parametrize the survival function directly 14 | and disregards individuals as they are censored. Each output node represents a binary classifier at 15 | the corresponding time, where all censored individual are removed. 16 | See [1] for details. 17 | 18 | Arguments: 19 | net {torch.nn.Module} -- A torch module. 20 | 21 | Keyword Arguments: 22 | optimizer {Optimizer} -- A torch optimizer or similar. Preferably use torchtuples.optim instead of 23 | torch.optim, as this allows for reinitialization, etc. If 'None' set to torchtuples.optim.AdamW. 24 | (default: {None}) 25 | device {str, int, torch.device} -- Device to compute on. (default: {None}) 26 | Preferably pass a torch.device object. 27 | If 'None': use default gpu if available, else use cpu. 28 | If 'int': used that gpu: torch.device('cuda:'). 29 | If 'string': string is passed to torch.device('string'). 30 | duration_index {list, np.array} -- Array of durations that defines the discrete times. 31 | This is used to set the index of the DataFrame in `predict_surv_df`. 32 | loss {func} -- An alternative loss function (default: {None}) 33 | 34 | References: 35 | [1] Håvard Kvamme and Ørnulf Borgan. The Brier Score under Administrative Censoring: Problems 36 | and Solutions. arXiv preprint arXiv:1912.08581, 2019. 37 | https://arxiv.org/pdf/1912.08581.pdf 38 | """ 39 | label_transform = label_transforms.LabTransDiscreteTime 40 | 41 | def __init__(self, net, optimizer=None, device=None, duration_index=None, loss=None): 42 | self.duration_index = duration_index 43 | if loss is None: 44 | loss = models.loss.BCESurvLoss() 45 | super().__init__(net, loss, optimizer, device) 46 | 47 | @property 48 | def duration_index(self): 49 | """ 50 | Array of durations that defines the discrete times. This is used to set the index 51 | of the DataFrame in `predict_surv_df`. 52 | 53 | Returns: 54 | np.array -- Duration index. 55 | """ 56 | return self._duration_index 57 | 58 | @duration_index.setter 59 | def duration_index(self, val): 60 | self._duration_index = val 61 | 62 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0, is_dataloader=None): 63 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers, is_dataloader) 64 | return pd.DataFrame(surv.transpose(), self.duration_index) 65 | 66 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, 67 | num_workers=0, is_dataloader=None): 68 | return self.predict(input, batch_size, numpy, eval_, False, to_cpu, num_workers, 69 | is_dataloader, torch.sigmoid) 70 | 71 | def interpolate(self, sub=10, scheme='const_pdf', duration_index=None): 72 | """Use interpolation for predictions. 
73 | There is only one scheme: 74 | `const_pdf` and `lin_surv` which assumes pice-wise constant PMF in each interval (linear survival). 75 | 76 | Keyword Arguments: 77 | sub {int} -- Number of "sub" units in interpolation grid. If `sub` is 10 we have a grid with 78 | 10 times the number of grid points than the original `duration_index` (default: {10}). 79 | scheme {str} -- Type of interpolation {'const_pdf'}. 80 | See `InterpolateDiscrete` (default: {'const_pdf'}) 81 | duration_index {np.array} -- Cuts used for discretization. Does not affect interpolation, 82 | only for setting index in `predict_surv_df` (default: {None}) 83 | 84 | Returns: 85 | [InterpolateLogisticHazard] -- Object for prediction with interpolation. 86 | """ 87 | if duration_index is None: 88 | duration_index = self.duration_index 89 | return models.interpolation.InterpolateDiscrete(self, scheme, duration_index, sub) 90 | -------------------------------------------------------------------------------- /pycox/models/cox.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | import torchtuples as tt 7 | from pycox import models 8 | 9 | def search_sorted_idx(array, values): 10 | '''For sorted array, get index of values. 11 | If value not in array, give left index of value. 12 | ''' 13 | n = len(array) 14 | idx = np.searchsorted(array, values) 15 | idx[idx == n] = n-1 # We can't have indexes higher than the length-1 16 | not_exact = values != array[idx] 17 | idx -= not_exact 18 | if any(idx < 0): 19 | warnings.warn('Given value smaller than first value') 20 | idx[idx < 0] = 0 21 | return idx 22 | 23 | 24 | class _CoxBase(models.base.SurvBase): 25 | duration_col = 'duration' 26 | event_col = 'event' 27 | 28 | def fit(self, input, target, batch_size=256, epochs=1, callbacks=None, verbose=True, 29 | num_workers=0, shuffle=True, metrics=None, val_data=None, val_batch_size=8224, 30 | **kwargs): 31 | """Fit model with inputs and targets. Where 'input' is the covariates, and 32 | 'target' is a tuple with (durations, events). 33 | 34 | Arguments: 35 | input {np.array, tensor or tuple} -- Input x passed to net. 36 | target {np.array, tensor or tuple} -- Target [durations, events]. 37 | 38 | Keyword Arguments: 39 | batch_size {int} -- Elements in each batch (default: {256}) 40 | epochs {int} -- Number of epochs (default: {1}) 41 | callbacks {list} -- list of callbacks (default: {None}) 42 | verbose {bool} -- Print progress (default: {True}) 43 | num_workers {int} -- Number of workers used in the dataloader (default: {0}) 44 | shuffle {bool} -- If we should shuffle the order of the dataset (default: {True}) 45 | **kwargs are passed to 'make_dataloader' method. 
46 | 47 | Returns: 48 | TrainingLogger -- Training log 49 | """ 50 | self.training_data = tt.tuplefy(input, target) 51 | return super().fit(input, target, batch_size, epochs, callbacks, verbose, 52 | num_workers, shuffle, metrics, val_data, val_batch_size, 53 | **kwargs) 54 | 55 | def _compute_baseline_hazards(self, input, df, max_duration, batch_size, eval_=True, num_workers=0): 56 | raise NotImplementedError 57 | 58 | def target_to_df(self, target): 59 | durations, events = tt.tuplefy(target).to_numpy() 60 | df = pd.DataFrame({self.duration_col: durations, self.event_col: events}) 61 | return df 62 | 63 | def compute_baseline_hazards(self, input=None, target=None, max_duration=None, sample=None, batch_size=8224, 64 | set_hazards=True, eval_=True, num_workers=0): 65 | """Computes the Breslow estimates form the data defined by `input` and `target` 66 | (if `None` use training data). 67 | 68 | Typically call 69 | model.compute_baseline_hazards() after fitting. 70 | 71 | Keyword Arguments: 72 | input -- Input data (train input) (default: {None}) 73 | target -- Target data (train target) (default: {None}) 74 | max_duration {float} -- Don't compute estimates for duration higher (default: {None}) 75 | sample {float or int} -- Compute estimates of subsample of data (default: {None}) 76 | batch_size {int} -- Batch size (default: {8224}) 77 | set_hazards {bool} -- Set hazards in model object, or just return hazards. (default: {True}) 78 | 79 | Returns: 80 | pd.Series -- Pandas series with baseline hazards. Index is duration_col. 81 | """ 82 | if (input is None) and (target is None): 83 | if not hasattr(self, 'training_data'): 84 | raise ValueError("Need to give a 'input' and 'target' to this function.") 85 | input, target = self.training_data 86 | df = self.target_to_df(target)#.sort_values(self.duration_col) 87 | if sample is not None: 88 | if sample >= 1: 89 | df = df.sample(n=sample) 90 | else: 91 | df = df.sample(frac=sample) 92 | input = tt.tuplefy(input).to_numpy().iloc[df.index.values] 93 | base_haz = self._compute_baseline_hazards(input, df, max_duration, batch_size, 94 | eval_=eval_, num_workers=num_workers) 95 | if set_hazards: 96 | self.compute_baseline_cumulative_hazards(set_hazards=True, baseline_hazards_=base_haz) 97 | return base_haz 98 | 99 | def compute_baseline_cumulative_hazards(self, input=None, target=None, max_duration=None, sample=None, 100 | batch_size=8224, set_hazards=True, baseline_hazards_=None, 101 | eval_=True, num_workers=0): 102 | """See `compute_baseline_hazards. This is the cumulative version.""" 103 | if ((input is not None) or (target is not None)) and (baseline_hazards_ is not None): 104 | raise ValueError("'input', 'target' and 'baseline_hazards_' can not both be different from 'None'.") 105 | if baseline_hazards_ is None: 106 | baseline_hazards_ = self.compute_baseline_hazards(input, target, max_duration, sample, batch_size, 107 | set_hazards=False, eval_=eval_, num_workers=num_workers) 108 | assert baseline_hazards_.index.is_monotonic_increasing,\ 109 | 'Need index of baseline_hazards_ to be monotonic increasing, as it represents time.' 
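        # The cumulative baseline hazard below is the running sum of the Breslow
        # increments; keeping it as a named pandas Series preserves the duration
        # index for later predictions.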
110 | bch = (baseline_hazards_ 111 | .cumsum() 112 | .rename('baseline_cumulative_hazards')) 113 | if set_hazards: 114 | self.baseline_hazards_ = baseline_hazards_ 115 | self.baseline_cumulative_hazards_ = bch 116 | return bch 117 | 118 | def predict_cumulative_hazards(self, input, max_duration=None, batch_size=8224, verbose=False, 119 | baseline_hazards_=None, eval_=True, num_workers=0): 120 | """See `predict_survival_function`.""" 121 | if type(input) is pd.DataFrame: 122 | input = self.df_to_input(input) 123 | if baseline_hazards_ is None: 124 | if not hasattr(self, 'baseline_hazards_'): 125 | raise ValueError('Need to compute baseline_hazards_. E.g run `model.compute_baseline_hazards()`') 126 | baseline_hazards_ = self.baseline_hazards_ 127 | assert baseline_hazards_.index.is_monotonic_increasing,\ 128 | 'Need index of baseline_hazards_ to be monotonic increasing, as it represents time.' 129 | return self._predict_cumulative_hazards(input, max_duration, batch_size, verbose, baseline_hazards_, 130 | eval_, num_workers=num_workers) 131 | 132 | def _predict_cumulative_hazards(self, input, max_duration, batch_size, verbose, baseline_hazards_, 133 | eval_=True, num_workers=0): 134 | raise NotImplementedError 135 | 136 | def predict_surv_df(self, input, max_duration=None, batch_size=8224, verbose=False, baseline_hazards_=None, 137 | eval_=True, num_workers=0): 138 | """Predict survival function for `input`. S(x, t) = exp(-H(x, t)) 139 | Require computed baseline hazards. 140 | 141 | Arguments: 142 | input {np.array, tensor or tuple} -- Input x passed to net. 143 | 144 | Keyword Arguments: 145 | max_duration {float} -- Don't compute estimates for duration higher (default: {None}) 146 | batch_size {int} -- Batch size (default: {8224}) 147 | baseline_hazards_ {pd.Series} -- Baseline hazards. If `None` used `model.baseline_hazards_` (default: {None}) 148 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 149 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 150 | 151 | Returns: 152 | pd.DataFrame -- Survival estimates. One columns for each individual. 153 | """ 154 | return np.exp(-self.predict_cumulative_hazards(input, max_duration, batch_size, verbose, baseline_hazards_, 155 | eval_, num_workers)) 156 | 157 | def predict_surv(self, input, max_duration=None, batch_size=8224, numpy=None, verbose=False, 158 | baseline_hazards_=None, eval_=True, num_workers=0): 159 | """Predict survival function for `input`. S(x, t) = exp(-H(x, t)) 160 | Require compueted baseline hazards. 161 | 162 | Arguments: 163 | input {np.array, tensor or tuple} -- Input x passed to net. 164 | 165 | Keyword Arguments: 166 | max_duration {float} -- Don't compute estimates for duration higher (default: {None}) 167 | batch_size {int} -- Batch size (default: {8224}) 168 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 169 | (default: {None}) 170 | baseline_hazards_ {pd.Series} -- Baseline hazards. If `None` used `model.baseline_hazards_` (default: {None}) 171 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 172 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 173 | 174 | Returns: 175 | pd.DataFrame -- Survival estimates. One columns for each individual. 
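        Example:
            A rough usage sketch; `net`, `x_train`, `y_train`, and `x_test` are
            placeholders and not defined in this module:

            >>> model = CoxPH(net)
            >>> model.fit(x_train, y_train, batch_size=256, epochs=10)
            >>> model.compute_baseline_hazards()
            >>> surv = model.predict_surv(x_test, numpy=True)  # one row per individual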
176 | """ 177 | surv = self.predict_surv_df(input, max_duration, batch_size, verbose, baseline_hazards_, 178 | eval_, num_workers) 179 | surv = torch.from_numpy(surv.values.transpose()) 180 | return tt.utils.array_or_tensor(surv, numpy, input) 181 | 182 | def save_net(self, path, **kwargs): 183 | """Save self.net and baseline hazards to file. 184 | 185 | Arguments: 186 | path {str} -- Path to file. 187 | **kwargs are passed to torch.save 188 | 189 | Returns: 190 | None 191 | """ 192 | path, extension = os.path.splitext(path) 193 | if extension == "": 194 | extension = '.pt' 195 | super().save_net(path+extension, **kwargs) 196 | if hasattr(self, 'baseline_hazards_'): 197 | self.baseline_hazards_.to_pickle(path+'_blh.pickle') 198 | 199 | def load_net(self, path, **kwargs): 200 | """Load net and hazards from file. 201 | 202 | Arguments: 203 | path {str} -- Path to file. 204 | **kwargs are passed to torch.load 205 | 206 | Returns: 207 | None 208 | """ 209 | path, extension = os.path.splitext(path) 210 | if extension == "": 211 | extension = '.pt' 212 | super().load_net(path+extension, **kwargs) 213 | blh_path = path+'_blh.pickle' 214 | if os.path.isfile(blh_path): 215 | self.baseline_hazards_ = pd.read_pickle(blh_path) 216 | self.baseline_cumulative_hazards_ = self.baseline_hazards_.cumsum() 217 | 218 | def df_to_input(self, df): 219 | input = df[self.input_cols].values 220 | return input 221 | 222 | 223 | class _CoxPHBase(_CoxBase): 224 | def _compute_baseline_hazards(self, input, df_target, max_duration, batch_size, eval_=True, num_workers=0): 225 | if max_duration is None: 226 | max_duration = np.inf 227 | 228 | # Here we are computing when expg when there are no events. 229 | # Could be made faster, by only computing when there are events. 230 | return (df_target 231 | .assign(expg=np.exp(self.predict(input, batch_size, True, eval_, num_workers=num_workers))) 232 | .groupby(self.duration_col) 233 | .agg({'expg': 'sum', self.event_col: 'sum'}) 234 | .sort_index(ascending=False) 235 | .assign(expg=lambda x: x['expg'].cumsum()) 236 | .pipe(lambda x: x[self.event_col]/x['expg']) 237 | .fillna(0.) 238 | .iloc[::-1] 239 | .loc[lambda x: x.index <= max_duration] 240 | .rename('baseline_hazards')) 241 | 242 | def _predict_cumulative_hazards(self, input, max_duration, batch_size, verbose, baseline_hazards_, 243 | eval_=True, num_workers=0): 244 | max_duration = np.inf if max_duration is None else max_duration 245 | if baseline_hazards_ is self.baseline_hazards_: 246 | bch = self.baseline_cumulative_hazards_ 247 | else: 248 | bch = self.compute_baseline_cumulative_hazards(set_hazards=False, 249 | baseline_hazards_=baseline_hazards_) 250 | bch = bch.loc[lambda x: x.index <= max_duration] 251 | expg = np.exp(self.predict(input, batch_size, True, eval_, num_workers=num_workers)).reshape(1, -1) 252 | return pd.DataFrame(bch.values.reshape(-1, 1).dot(expg), 253 | index=bch.index) 254 | 255 | def partial_log_likelihood(self, input, target, g_preds=None, batch_size=8224, eps=1e-7, eval_=True, 256 | num_workers=0): 257 | '''Calculate the partial log-likelihood for the events in datafram df. 258 | This likelihood does not sample the controls. 259 | Note that censored data (non events) does not have a partial log-likelihood. 260 | 261 | Arguments: 262 | input {tuple, np.ndarray, or torch.tensor} -- Input to net. 263 | target {tuple, np.ndarray, or torch.tensor} -- Target labels. 
264 | 265 | Keyword Arguments: 266 | g_preds {np.array} -- Predictions from `model.predict` (default: {None}) 267 | batch_size {int} -- Batch size (default: {8224}) 268 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 269 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 270 | 271 | Returns: 272 | Partial log-likelihood. 273 | ''' 274 | df = self.target_to_df(target) 275 | if g_preds is None: 276 | g_preds = self.predict(input, batch_size, True, eval_, num_workers=num_workers) 277 | return (df 278 | .assign(_g_preds=g_preds) 279 | .sort_values(self.duration_col, ascending=False) 280 | .assign(_cum_exp_g=(lambda x: x['_g_preds'] 281 | .pipe(np.exp) 282 | .cumsum() 283 | .groupby(x[self.duration_col]) 284 | .transform('max'))) 285 | .loc[lambda x: x[self.event_col] == 1] 286 | .assign(pll=lambda x: x['_g_preds'] - np.log(x['_cum_exp_g'] + eps)) 287 | ['pll']) 288 | 289 | 290 | class CoxPH(_CoxPHBase): 291 | """Cox proportional hazards model parameterized with a neural net. 292 | This is essentially the DeepSurv method [1]. 293 | 294 | The loss function is not quite the partial log-likelihood, but close. 295 | The difference is that for tied events, we use a random order instead of 296 | including all individuals that had an event at that point in time. 297 | 298 | Arguments: 299 | net {torch.nn.Module} -- A pytorch net. 300 | 301 | Keyword Arguments: 302 | optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None}) 303 | device {str, int, torch.device} -- Device to compute on. (default: {None}) 304 | Preferably pass a torch.device object. 305 | If 'None': use default gpu if available, else use cpu. 306 | If 'int': used that gpu: torch.device('cuda:'). 307 | If 'string': string is passed to torch.device('string'). 308 | 309 | [1] Jared L. Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger. 310 | Deepsurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. 311 | BMC Medical Research Methodology, 18(1), 2018. 312 | https://bmcmedresmethodol.biomedcentral.com/articles/10.1186/s12874-018-0482-1 313 | """ 314 | def __init__(self, net, optimizer=None, device=None, loss=None): 315 | if loss is None: 316 | loss = models.loss.CoxPHLoss() 317 | super().__init__(net, loss, optimizer, device) 318 | 319 | 320 | class CoxPHSorted(_CoxPHBase): 321 | """Cox proportional hazards model parameterized with a neural net. 322 | This is essentially the DeepSurv method [1]. 323 | 324 | The loss function is not quite the partial log-likelihood, but close. 325 | The difference is that for tied events, we use a random order instead of 326 | including all individuals that had an event at that point in time. 327 | 328 | Arguments: 329 | net {torch.nn.Module} -- A pytorch net. 330 | 331 | Keyword Arguments: 332 | optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None}) 333 | device {str, int, torch.device} -- Device to compute on. (default: {None}) 334 | Preferably pass a torch.device object. 335 | If 'None': use default gpu if available, else use cpu. 336 | If 'int': used that gpu: torch.device('cuda:'). 337 | If 'string': string is passed to torch.device('string'). 338 | 339 | [1] Jared L. Katzman, Uri Shaham, Alexander Cloninger, Jonathan Bates, Tingting Jiang, and Yuval Kluger. 340 | Deepsurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. 341 | BMC Medical Research Methodology, 18(1), 2018. 
342 | https://bmcmedresmethodol.biomedcentral.com/articles/10.1186/s12874-018-0482-1 343 | """ 344 | def __init__(self, net, optimizer=None, device=None, loss=None): 345 | warnings.warn('Use `CoxPH` instead. This will be removed', DeprecationWarning) 346 | if loss is None: 347 | loss = models.loss.CoxPHLossSorted() 348 | super().__init__(net, loss, optimizer, device) 349 | 350 | @staticmethod 351 | def make_dataloader(data, batch_size, shuffle, num_workers=0): 352 | dataloader = tt.make_dataloader(data, batch_size, shuffle, num_workers, 353 | make_dataset=models.data.DurationSortedDataset) 354 | return dataloader 355 | 356 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0): 357 | dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers) 358 | return dataloader 359 | -------------------------------------------------------------------------------- /pycox/models/cox_cc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchtuples as tt 3 | from pycox import models 4 | 5 | 6 | class _CoxCCBase(models.cox._CoxBase): 7 | make_dataset = NotImplementedError 8 | 9 | def __init__(self, net, optimizer=None, device=None, shrink=0., loss=None): 10 | if loss is None: 11 | loss = models.loss.CoxCCLoss(shrink) 12 | super().__init__(net, loss, optimizer, device) 13 | 14 | def fit(self, input, target, batch_size=256, epochs=1, callbacks=None, verbose=True, 15 | num_workers=0, shuffle=True, metrics=None, val_data=None, val_batch_size=8224, 16 | n_control=1, shrink=None, **kwargs): 17 | """Fit model with inputs and targets. Where 'input' is the covariates, and 18 | 'target' is a tuple with (durations, events). 19 | 20 | Arguments: 21 | input {np.array, tensor or tuple} -- Input x passed to net. 22 | target {np.array, tensor or tuple} -- Target [durations, events]. 23 | 24 | Keyword Arguments: 25 | batch_size {int} -- Elements in each batch (default: {256}) 26 | epochs {int} -- Number of epochs (default: {1}) 27 | callbacks {list} -- list of callbacks (default: {None}) 28 | verbose {bool} -- Print progress (default: {True}) 29 | num_workers {int} -- Number of workers used in the dataloader (default: {0}) 30 | shuffle {bool} -- If we should shuffle the order of the dataset (default: {True}) 31 | n_control {int} -- Number of control samples. 32 | **kwargs are passed to 'make_dataloader' method. 33 | 34 | Returns: 35 | TrainingLogger -- Training log 36 | """ 37 | input, target = self._sorted_input_target(input, target) 38 | if shrink is not None: 39 | self.loss.shrink = shrink 40 | return super().fit(input, target, batch_size, epochs, callbacks, verbose, 41 | num_workers, shuffle, metrics, val_data, val_batch_size, 42 | n_control=n_control, **kwargs) 43 | 44 | def compute_metrics(self, input, metrics): 45 | if (self.loss is None) and (self.loss in metrics.values()): 46 | raise RuntimeError(f"Need to specify a loss (self.loss). 
It's currently None") 47 | input = self._to_device(input) 48 | batch_size = input.lens().flatten().get_if_all_equal() 49 | if batch_size is None: 50 | raise RuntimeError("All elements in input does not have the same length.") 51 | case, control = input # both are TupleTree 52 | input_all = tt.TupleTree((case,) + control).cat() 53 | g_all = self.net(*input_all) 54 | g_all = tt.tuplefy(g_all).split(batch_size).flatten() 55 | g_case = g_all[0] 56 | g_control = g_all[1:] 57 | res = {name: metric(g_case, g_control) for name, metric in metrics.items()} 58 | return res 59 | 60 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0): 61 | """Dataloader for prediction. The input is either the regular input, or a tuple 62 | with input and label. 63 | 64 | Arguments: 65 | input {np.array, tensor, tuple} -- Input to net, or tuple with input and labels. 66 | batch_size {int} -- Batch size. 67 | 68 | Keyword Arguments: 69 | shuffle {bool} -- If we should shuffle in the dataloader. (default: {False}) 70 | num_workers {int} -- Number of worker in dataloader. (default: {0}) 71 | 72 | Returns: 73 | dataloader -- A dataloader. 74 | """ 75 | dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers) 76 | return dataloader 77 | 78 | def make_dataloader(self, data, batch_size, shuffle=True, num_workers=0, n_control=1): 79 | """Dataloader for training. Data is on the form (input, target), where 80 | target is (durations, events). 81 | 82 | Arguments: 83 | data {tuple} -- Tuple containing (input, (durations, events)). 84 | batch_size {int} -- Batch size. 85 | 86 | Keyword Arguments: 87 | shuffle {bool} -- If shuffle in dataloader (default: {True}) 88 | num_workers {int} -- Number of workers in dataloader. (default: {0}) 89 | n_control {int} -- Number of control samples in dataloader (default: {1}) 90 | 91 | Returns: 92 | dataloader -- Dataloader for training. 93 | """ 94 | input, target = self._sorted_input_target(*data) 95 | durations, events = target 96 | dataset = self.make_dataset(input, durations, events, n_control) 97 | dataloader = tt.data.DataLoaderBatch(dataset, batch_size=batch_size, 98 | shuffle=shuffle, num_workers=num_workers) 99 | return dataloader 100 | 101 | @staticmethod 102 | def _sorted_input_target(input, target): 103 | input, target = tt.tuplefy(input, target).to_numpy() 104 | durations, _ = target 105 | idx_sort = np.argsort(durations) 106 | if (idx_sort == np.arange(0, len(idx_sort))).all(): 107 | return input, target 108 | input = tt.tuplefy(input).iloc[idx_sort] 109 | target = tt.tuplefy(target).iloc[idx_sort] 110 | return input, target 111 | 112 | 113 | class CoxCC(_CoxCCBase, models.cox._CoxPHBase): 114 | """Cox proportional hazards model parameterized with a neural net and 115 | trained with case-control sampling [1]. 116 | This is similar to DeepSurv, but use an approximation of the loss function. 117 | 118 | Arguments: 119 | net {torch.nn.Module} -- A PyTorch net. 120 | 121 | Keyword Arguments: 122 | optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None}) 123 | device {str, int, torch.device} -- Device to compute on. (default: {None}) 124 | Preferably pass a torch.device object. 125 | If 'None': use default gpu if available, else use cpu. 126 | If 'int': used that gpu: torch.device('cuda:'). 127 | If 'string': string is passed to torch.device('string'). 128 | 129 | References: 130 | [1] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel. 131 | Time-to-event prediction with neural networks and Cox regression. 
132 | Journal of Machine Learning Research, 20(129):1–30, 2019. 133 | http://jmlr.org/papers/v20/18-424.html 134 | """ 135 | make_dataset = models.data.CoxCCDataset 136 | -------------------------------------------------------------------------------- /pycox/models/cox_time.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | from torch import nn 5 | import torchtuples as tt 6 | 7 | from pycox import models 8 | from pycox.preprocessing.label_transforms import LabTransCoxTime 9 | 10 | class CoxTime(models.cox_cc._CoxCCBase): 11 | """The Cox-Time model from [1]. A relative risk model without proportional hazards, trained 12 | with case-control sampling. 13 | 14 | Arguments: 15 | net {torch.nn.Module} -- A PyTorch net. 16 | 17 | Keyword Arguments: 18 | optimizer {torch or torchtuples optimizer} -- Optimizer (default: {None}) 19 | device {str, int, torch.device} -- Device to compute on. (default: {None}) 20 | Preferably pass a torch.device object. 21 | If 'None': use default gpu if available, else use cpu. 22 | If 'int': used that gpu: torch.device('cuda:'). 23 | If 'string': string is passed to torch.device('string'). 24 | shrink {float} -- Shrinkage that encourage the net got give g_case and g_control 25 | closer to zero (a regularizer in a sense). (default: {0.}) 26 | labtrans {pycox.preprocessing.label_tranforms.LabTransCoxTime} -- A object for transforming 27 | durations. Useful for prediction as we can obtain durations on the original scale. 28 | (default: {None}) 29 | 30 | References: 31 | [1] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel. 32 | Time-to-event prediction with neural networks and Cox regression. 33 | Journal of Machine Learning Research, 20(129):1–30, 2019. 
34 | http://jmlr.org/papers/v20/18-424.html 35 | """ 36 | make_dataset = models.data.CoxTimeDataset 37 | label_transform = LabTransCoxTime 38 | 39 | def __init__(self, net, optimizer=None, device=None, shrink=0., labtrans=None, loss=None): 40 | self.labtrans = labtrans 41 | super().__init__(net, optimizer, device, shrink, loss) 42 | 43 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0): 44 | input, durations = input 45 | input = tt.tuplefy(input) 46 | durations = tt.tuplefy(durations) 47 | new_input = input + durations 48 | dataloader = super().make_dataloader_predict(new_input, batch_size, shuffle, num_workers) 49 | return dataloader 50 | 51 | def predict_surv_df(self, input, max_duration=None, batch_size=8224, verbose=False, baseline_hazards_=None, 52 | eval_=True, num_workers=0): 53 | surv = super().predict_surv_df(input, max_duration, batch_size, verbose, baseline_hazards_, 54 | eval_, num_workers) 55 | if self.labtrans is not None: 56 | surv.index = self.labtrans.map_scaled_to_orig(surv.index) 57 | return surv 58 | 59 | def compute_baseline_hazards(self, input=None, target=None, max_duration=None, sample=None, batch_size=8224, 60 | set_hazards=True, eval_=True, num_workers=0): 61 | if (input is None) and (target is None): 62 | if not hasattr(self, 'training_data'): 63 | raise ValueError('Need to fit, or supply a input and target to this function.') 64 | input, target = self.training_data 65 | df = self.target_to_df(target) 66 | if sample is not None: 67 | if sample >= 1: 68 | df = df.sample(n=sample) 69 | else: 70 | df = df.sample(frac=sample) 71 | df = df.sort_values(self.duration_col) 72 | input = tt.tuplefy(input).to_numpy().iloc[df.index.values] 73 | base_haz = self._compute_baseline_hazards(input, df, max_duration, batch_size, eval_, num_workers) 74 | if set_hazards: 75 | self.compute_baseline_cumulative_hazards(set_hazards=True, baseline_hazards_=base_haz) 76 | return base_haz 77 | 78 | def _compute_baseline_hazards(self, input, df_train_target, max_duration, batch_size, eval_=True, 79 | num_workers=0): 80 | if max_duration is None: 81 | max_duration = np.inf 82 | def compute_expg_at_risk(ix, t): 83 | sub = input.iloc[ix:] 84 | n = sub.lens().flatten().get_if_all_equal() 85 | t = np.repeat(t, n).reshape(-1, 1).astype('float32') 86 | return np.exp(self.predict((sub, t), batch_size, True, eval_, num_workers=num_workers)).flatten().sum() 87 | 88 | if not df_train_target[self.duration_col].is_monotonic_increasing: 89 | raise RuntimeError(f"Need 'df_train_target' to be sorted by {self.duration_col}") 90 | input = tt.tuplefy(input) 91 | df = df_train_target.reset_index(drop=True) 92 | times = (df 93 | .loc[lambda x: x[self.event_col] != 0] 94 | [self.duration_col] 95 | .loc[lambda x: x <= max_duration] 96 | .drop_duplicates(keep='first')) 97 | at_risk_sum = (pd.Series([compute_expg_at_risk(ix, t) for ix, t in times.items()], 98 | index=times.values) 99 | .rename('at_risk_sum')) 100 | events = (df 101 | .groupby(self.duration_col) 102 | [[self.event_col]] 103 | .agg('sum') 104 | .loc[lambda x: x.index <= max_duration]) 105 | base_haz = (events 106 | .join(at_risk_sum, how='left', sort=True) 107 | .pipe(lambda x: x[self.event_col] / x['at_risk_sum']) 108 | .fillna(0.) 
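                    # The ratio above (events at each time divided by the sum of exp(g)
                    # over those still at risk) is the baseline hazard increment; times
                    # missing from at_risk_sum were set to zero by fillna.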
109 | .rename('baseline_hazards')) 110 | return base_haz 111 | 112 | def _predict_cumulative_hazards(self, input, max_duration, batch_size, verbose, baseline_hazards_, 113 | eval_=True, num_workers=0): 114 | def expg_at_time(t): 115 | t = np.repeat(t, n_cols).reshape(-1, 1).astype('float32') 116 | if tt.tuplefy(input).type() is torch.Tensor: 117 | t = torch.from_numpy(t) 118 | return np.exp(self.predict((input, t), batch_size, True, eval_, num_workers=num_workers)).flatten() 119 | 120 | if tt.utils.is_dl(input): 121 | raise NotImplementedError(f"Prediction with a dataloader as input is not supported ") 122 | input = tt.tuplefy(input) 123 | max_duration = np.inf if max_duration is None else max_duration 124 | baseline_hazards_ = baseline_hazards_.loc[lambda x: x.index <= max_duration] 125 | n_rows, n_cols = baseline_hazards_.shape[0], input.lens().flatten().get_if_all_equal() 126 | hazards = np.empty((n_rows, n_cols)) 127 | for idx, t in enumerate(baseline_hazards_.index): 128 | if verbose: 129 | print(idx, 'of', len(baseline_hazards_)) 130 | hazards[idx, :] = expg_at_time(t) 131 | hazards[baseline_hazards_.values == 0] = 0. # in case hazards are inf here 132 | hazards *= baseline_hazards_.values.reshape(-1, 1) 133 | return pd.DataFrame(hazards, index=baseline_hazards_.index).cumsum() 134 | 135 | def partial_log_likelihood(self, input, target, batch_size=8224, eval_=True, num_workers=0): 136 | def expg_sum(t, i): 137 | sub = input_sorted.iloc[i:] 138 | n = sub.lens().flatten().get_if_all_equal() 139 | t = np.repeat(t, n).reshape(-1, 1).astype('float32') 140 | return np.exp(self.predict((sub, t), batch_size, True, eval_, num_workers=num_workers)).flatten().sum() 141 | 142 | durations, events = target 143 | df = pd.DataFrame({self.duration_col: durations, self.event_col: events}) 144 | df = df.sort_values(self.duration_col) 145 | input = tt.tuplefy(input) 146 | input_sorted = input.iloc[df.index.values] 147 | 148 | times = (df 149 | .assign(_idx=np.arange(len(df))) 150 | .loc[lambda x: x[self.event_col] == True] 151 | .drop_duplicates(self.duration_col, keep='first') 152 | .assign(_expg_sum=lambda x: [expg_sum(t, i) for t, i in zip(x[self.duration_col], x['_idx'])]) 153 | .drop([self.event_col, '_idx'], axis=1)) 154 | 155 | idx_name_old = df.index.name 156 | idx_name = '__' + idx_name_old if idx_name_old else '__index' 157 | df.index.name = idx_name 158 | 159 | pll = df.loc[lambda x: x[self.event_col] == True] 160 | input_event = input.iloc[pll.index.values] 161 | durations_event = pll[self.duration_col].values.reshape(-1, 1) 162 | g_preds = self.predict((input_event, durations_event), batch_size, True, eval_, num_workers=num_workers).flatten() 163 | pll = (pll 164 | .assign(_g_preds=g_preds) 165 | .reset_index() 166 | .merge(times, on=self.duration_col) 167 | .set_index(idx_name) 168 | .assign(pll=lambda x: x['_g_preds'] - np.log(x['_expg_sum'])) 169 | ['pll']) 170 | 171 | pll.index.name = idx_name_old 172 | return pll 173 | 174 | 175 | class MLPVanillaCoxTime(nn.Module): 176 | """A version of torchtuples.practical.MLPVanilla that works for CoxTime. 177 | The difference is that it takes `time` as an additional input and removes the output bias and 178 | output activation. 
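    Example:
        A minimal sketch of wiring this net into `CoxTime` (the sizes are arbitrary
        illustration values, not recommendations):

        >>> net = MLPVanillaCoxTime(in_features=10, num_nodes=[32, 32], dropout=0.1)
        >>> model = CoxTime(net)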
179 | """ 180 | def __init__(self, in_features, num_nodes, batch_norm=True, dropout=None, activation=nn.ReLU, 181 | w_init_=lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu')): 182 | super().__init__() 183 | in_features += 1 184 | out_features = 1 185 | output_activation = None 186 | output_bias=False 187 | self.net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout, 188 | activation, output_activation, output_bias, w_init_) 189 | 190 | def forward(self, input, time): 191 | input = torch.cat([input, time], dim=1) 192 | return self.net(input) 193 | 194 | 195 | class MixedInputMLPCoxTime(nn.Module): 196 | """A version of torchtuples.practical.MixedInputMLP that works for CoxTime. 197 | The difference is that it takes `time` as an additional input and removes the output bias and 198 | output activation. 199 | """ 200 | def __init__(self, in_features, num_embeddings, embedding_dims, num_nodes, batch_norm=True, 201 | dropout=None, activation=nn.ReLU, dropout_embedding=0., 202 | w_init_=lambda w: nn.init.kaiming_normal_(w, nonlinearity='relu')): 203 | super().__init__() 204 | in_features += 1 205 | out_features = 1 206 | output_activation = None 207 | output_bias=False 208 | self.net = tt.practical.MixedInputMLP(in_features, num_embeddings, embedding_dims, num_nodes, 209 | out_features, batch_norm, dropout, activation, 210 | dropout_embedding, output_activation, output_bias, w_init_) 211 | 212 | def forward(self, input_numeric, input_categoric, time): 213 | input_numeric = torch.cat([input_numeric, time], dim=1) 214 | return self.net(input_numeric, input_categoric) 215 | -------------------------------------------------------------------------------- /pycox/models/data.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pandas as pd 4 | import numba 5 | import torch 6 | import torchtuples as tt 7 | 8 | 9 | def sample_alive_from_dates(dates, at_risk_dict, n_control=1): 10 | '''Sample index from living at time given in dates. 11 | dates: np.array of times (or pd.Series). 12 | at_risk_dict: dict with at_risk_dict[time] = . 13 | n_control: number of samples. 14 | ''' 15 | lengths = np.array([at_risk_dict[x].shape[0] for x in dates]) # Can be moved outside 16 | idx = (np.random.uniform(size=(n_control, dates.size)) * lengths).astype('int') 17 | samp = np.empty((dates.size, n_control), dtype=int) 18 | 19 | for it, time in enumerate(dates): 20 | samp[it, :] = at_risk_dict[time][idx[:, it]] 21 | return samp 22 | 23 | def make_at_risk_dict(durations): 24 | """Create dict(duration: indices) from sorted df. 25 | A dict mapping durations to indices. 26 | For each time => index of all individual alive. 27 | 28 | Arguments: 29 | durations {np.arrary} -- durations. 30 | """ 31 | assert type(durations) is np.ndarray, 'Need durations to be a numpy array' 32 | durations = pd.Series(durations) 33 | assert durations.is_monotonic_increasing, 'Requires durations to be monotonic' 34 | allidx = durations.index.values 35 | keys = durations.drop_duplicates(keep='first') 36 | at_risk_dict = dict() 37 | for ix, t in keys.items(): 38 | at_risk_dict[t] = allidx[ix:] 39 | return at_risk_dict 40 | 41 | 42 | class DurationSortedDataset(tt.data.DatasetTuple): 43 | """We assume the dataset contrain `(input, durations, events)`, and 44 | sort the batch based on descending `durations`. 45 | 46 | See `torchtuples.data.DatasetTuple`. 
47 | """ 48 | def __getitem__(self, index): 49 | batch = super().__getitem__(index) 50 | input, (duration, event) = batch 51 | idx_sort = duration.sort(descending=True)[1] 52 | event = event.float() 53 | batch = tt.tuplefy(input, event).iloc[idx_sort] 54 | return batch 55 | 56 | 57 | class CoxCCDataset(torch.utils.data.Dataset): 58 | def __init__(self, input, durations, events, n_control=1): 59 | df_train_target = pd.DataFrame(dict(duration=durations, event=events)) 60 | self.durations = df_train_target.loc[lambda x: x['event'] == 1]['duration'] 61 | self.at_risk_dict = make_at_risk_dict(durations) 62 | 63 | self.input = tt.tuplefy(input) 64 | assert type(self.durations) is pd.Series 65 | self.n_control = n_control 66 | 67 | def __getitem__(self, index): 68 | if (not hasattr(index, '__iter__')) and (type(index) is not slice): 69 | index = [index] 70 | fails = self.durations.iloc[index] 71 | x_case = self.input.iloc[fails.index] 72 | control_idx = sample_alive_from_dates(fails.values, self.at_risk_dict, self.n_control) 73 | x_control = tt.TupleTree(self.input.iloc[idx] for idx in control_idx.transpose()) 74 | return tt.tuplefy(x_case, x_control).to_tensor() 75 | 76 | def __len__(self): 77 | return len(self.durations) 78 | 79 | 80 | class CoxTimeDataset(CoxCCDataset): 81 | def __init__(self, input, durations, events, n_control=1): 82 | super().__init__(input, durations, events, n_control) 83 | self.durations_tensor = tt.tuplefy(self.durations.values.reshape(-1, 1)).to_tensor() 84 | 85 | def __getitem__(self, index): 86 | if not hasattr(index, '__iter__'): 87 | index = [index] 88 | durations = self.durations_tensor.iloc[index] 89 | case, control = super().__getitem__(index) 90 | case = case + durations 91 | control = control.apply_nrec(lambda x: x + durations) 92 | return tt.tuplefy(case, control) 93 | 94 | @numba.njit 95 | def _pair_rank_mat(mat, idx_durations, events, dtype='float32'): 96 | n = len(idx_durations) 97 | for i in range(n): 98 | dur_i = idx_durations[i] 99 | ev_i = events[i] 100 | if ev_i == 0: 101 | continue 102 | for j in range(n): 103 | dur_j = idx_durations[j] 104 | ev_j = events[j] 105 | if (dur_i < dur_j) or ((dur_i == dur_j) and (ev_j == 0)): 106 | mat[i, j] = 1 107 | return mat 108 | 109 | def pair_rank_mat(idx_durations, events, dtype='float32'): 110 | """Indicator matrix R with R_ij = 1{T_i < T_j and D_i = 1}. 111 | So it takes value 1 if we observe that i has an event before j and zero otherwise. 112 | 113 | Arguments: 114 | idx_durations {np.array} -- Array with durations. 115 | events {np.array} -- Array with event indicators. 116 | 117 | Keyword Arguments: 118 | dtype {str} -- dtype of array (default: {'float32'}) 119 | 120 | Returns: 121 | np.array -- n x n matrix indicating if i has an observerd event before j. 
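    Example:
        A hand-worked illustration with made-up values:

        >>> pair_rank_mat(np.array([1, 2, 2]), np.array([1, 1, 0]))

        Individual 0 has an observed event strictly before individuals 1 and 2,
        and individual 1 has an event at the same time as the censored individual 2,
        so the result should be [[0, 1, 1], [0, 0, 1], [0, 0, 0]].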
122 | """ 123 | idx_durations = idx_durations.reshape(-1) 124 | events = events.reshape(-1) 125 | n = len(idx_durations) 126 | mat = np.zeros((n, n), dtype=dtype) 127 | mat = _pair_rank_mat(mat, idx_durations, events, dtype) 128 | return mat 129 | 130 | 131 | class DeepHitDataset(tt.data.DatasetTuple): 132 | def __getitem__(self, index): 133 | input, target = super().__getitem__(index) 134 | target = target.to_numpy() 135 | rank_mat = pair_rank_mat(*target) 136 | target = tt.tuplefy(*target, rank_mat).to_tensor() 137 | return tt.tuplefy(input, target) 138 | -------------------------------------------------------------------------------- /pycox/models/deephit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import torchtuples as tt 5 | 6 | from pycox import models 7 | from pycox.models.utils import pad_col 8 | 9 | class DeepHitSingle(models.pmf.PMFBase): 10 | """The DeepHit methods by [1] but only for single event (not competing risks). 11 | 12 | Note that `alpha` is here defined differently than in [1], as `alpha` is weighting between 13 | the likelihood and rank loss (see Appendix D in [2]) 14 | loss = alpha * nll + (1 - alpha) rank_loss(sigma). 15 | 16 | Also, unlike [1], this implementation allows for survival past the max durations, i.e., it 17 | does not assume all events happen within the defined duration grid. See [3] for details. 18 | 19 | Keyword Arguments: 20 | alpha {float} -- Weighting (0, 1) likelihood and rank loss (L2 in paper). 21 | 1 gives only likelihood, and 0 gives only rank loss. (default: {0.2}) 22 | sigma {float} -- from eta in rank loss (L2 in paper) (default: {0.1}) 23 | 24 | References: 25 | [1] Changhee Lee, William R Zame, Jinsung Yoon, and Mihaela van der Schaar. Deephit: A deep learning 26 | approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial 27 | Intelligence, 2018. 28 | http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit 29 | 30 | [2] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel. 31 | Time-to-event prediction with neural networks and Cox regression. 32 | Journal of Machine Learning Research, 20(129):1–30, 2019. 33 | http://jmlr.org/papers/v20/18-424.html 34 | 35 | [3] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction 36 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. 37 | https://arxiv.org/pdf/1910.06724.pdf 38 | """ 39 | def __init__(self, net, optimizer=None, device=None, duration_index=None, alpha=0.2, sigma=0.1, loss=None): 40 | if loss is None: 41 | loss = models.loss.DeepHitSingleLoss(alpha, sigma) 42 | super().__init__(net, loss, optimizer, device, duration_index) 43 | 44 | def make_dataloader(self, data, batch_size, shuffle, num_workers=0): 45 | dataloader = super().make_dataloader(data, batch_size, shuffle, num_workers, 46 | make_dataset=models.data.DeepHitDataset) 47 | return dataloader 48 | 49 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0): 50 | dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers) 51 | return dataloader 52 | 53 | 54 | class DeepHit(tt.Model): 55 | """DeepHit for competing risks [1]. 56 | For single risk (only one event type) use `DeepHitSingle` instead! 57 | 58 | Note that `alpha` is here defined differently than in [1], as `alpha` is weighting between 59 | the likelihood and rank loss (see Appendix D in [2]) 60 | loss = alpha * nll + (1 - alpha) rank_loss(sigma). 
61 | 62 | Also, unlike [1], this implementation allows for survival past the max durations, i.e., it 63 | does not assume all events happen within the defined duration grid. See [3] for details. 64 | 65 | Keyword Arguments: 66 | alpha {float} -- Weighting (0, 1) likelihood and rank loss (L2 in paper). 67 | 1 gives only likelihood, and 0 gives only rank loss. (default: {0.2}) 68 | sigma {float} -- from eta in rank loss (L2 in paper) (default: {0.1}) 69 | 70 | References: 71 | [1] Changhee Lee, William R Zame, Jinsung Yoon, and Mihaela van der Schaar. Deephit: A deep learning 72 | approach to survival analysis with competing risks. In Thirty-Second AAAI Conference on Artificial 73 | Intelligence, 2018. 74 | http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit 75 | 76 | [2] Håvard Kvamme, Ørnulf Borgan, and Ida Scheel. 77 | Time-to-event prediction with neural networks and Cox regression. 78 | Journal of Machine Learning Research, 20(129):1–30, 2019. 79 | http://jmlr.org/papers/v20/18-424.html 80 | 81 | [3] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction 82 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. 83 | https://arxiv.org/pdf/1910.06724.pdf 84 | """ 85 | def __init__(self, net, optimizer=None, device=None, alpha=0.2, sigma=0.1, duration_index=None, loss=None): 86 | self.duration_index = duration_index 87 | if loss is None: 88 | loss = models.loss.DeepHitLoss(alpha, sigma) 89 | super().__init__(net, loss, optimizer, device) 90 | 91 | @property 92 | def duration_index(self): 93 | """ 94 | Array of durations that defines the discrete times. This is used to set the index 95 | of the DataFrame in `predict_surv_df`. 96 | 97 | Returns: 98 | np.array -- Duration index. 99 | """ 100 | return self._duration_index 101 | 102 | @duration_index.setter 103 | def duration_index(self, val): 104 | self._duration_index = val 105 | 106 | def make_dataloader(self, data, batch_size, shuffle, num_workers=0): 107 | dataloader = super().make_dataloader(data, batch_size, shuffle, num_workers, 108 | make_dataset=models.data.DeepHitDataset) 109 | return dataloader 110 | 111 | def make_dataloader_predict(self, input, batch_size, shuffle=False, num_workers=0): 112 | dataloader = super().make_dataloader(input, batch_size, shuffle, num_workers) 113 | return dataloader 114 | 115 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0): 116 | """Predict the survival function for `input`, i.e., survive all of the event types, 117 | and return as a pandas DataFrame. 118 | See `prediction_surv_df` to return a DataFrame instead. 119 | 120 | Arguments: 121 | input {tuple, np.ndarra, or torch.tensor} -- Input to net. 122 | 123 | Keyword Arguments: 124 | batch_size {int} -- Batch size (default: {8224}) 125 | eval_ {bool} -- If 'True', use 'eval' modede on net. (default: {True}) 126 | num_workers {int} -- Number of workes in created dataloader (default: {0}) 127 | 128 | Returns: 129 | pd.DataFrame -- Predictions 130 | """ 131 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers) 132 | return pd.DataFrame(surv, self.duration_index) 133 | 134 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, 135 | to_cpu=False, num_workers=0): 136 | """Predict the survival function for `input`, i.e., survive all of the event types. 137 | See `prediction_surv_df` to return a DataFrame instead. 138 | 139 | Arguments: 140 | input {tuple, np.ndarra, or torch.tensor} -- Input to net. 
141 | 142 | Keyword Arguments: 143 | batch_size {int} -- Batch size (default: {8224}) 144 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 145 | (default: {None}) 146 | eval_ {bool} -- If 'True', use 'eval' modede on net. (default: {True}) 147 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 148 | (default: {False}) 149 | num_workers {int} -- Number of workes in created dataloader (default: {0}) 150 | 151 | Returns: 152 | [TupleTree, np.ndarray or tensor] -- Predictions 153 | """ 154 | cif = self.predict_cif(input, batch_size, False, eval_, to_cpu, num_workers) 155 | surv = 1. - cif.sum(0) 156 | return tt.utils.array_or_tensor(surv, numpy, input) 157 | 158 | def predict_cif(self, input, batch_size=8224, numpy=None, eval_=True, 159 | to_cpu=False, num_workers=0): 160 | """Predict the cumulative incidence function (cif) for `input`. 161 | 162 | Arguments: 163 | input {tuple, np.ndarray, or torch.tensor} -- Input to net. 164 | 165 | Keyword Arguments: 166 | batch_size {int} -- Batch size (default: {8224}) 167 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 168 | (default: {None}) 169 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 170 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 171 | (default: {False}) 172 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 173 | 174 | Returns: 175 | [np.ndarray or tensor] -- Predictions 176 | """ 177 | pmf = self.predict_pmf(input, batch_size, False, eval_, to_cpu, num_workers) 178 | cif = pmf.cumsum(1) 179 | return tt.utils.array_or_tensor(cif, numpy, input) 180 | 181 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, 182 | to_cpu=False, num_workers=0): 183 | """Predict the probability mass fuction (PMF) for `input`. 184 | 185 | Arguments: 186 | input {tuple, np.ndarray, or torch.tensor} -- Input to net. 187 | 188 | Keyword Arguments: 189 | batch_size {int} -- Batch size (default: {8224}) 190 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 191 | (default: {None}) 192 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 193 | grads {bool} -- If gradients should be computed (default: {False}) 194 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 195 | (default: {False}) 196 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 197 | 198 | Returns: 199 | [np.ndarray or tensor] -- Predictions 200 | """ 201 | preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers) 202 | pmf = pad_col(preds.view(preds.size(0), -1)).softmax(1)[:, :-1] 203 | pmf = pmf.view(preds.shape).transpose(0, 1).transpose(1, 2) 204 | return tt.utils.array_or_tensor(pmf, numpy, input) 205 | -------------------------------------------------------------------------------- /pycox/models/interpolation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | import torchtuples as tt 4 | from pycox.models import utils 5 | 6 | 7 | class InterpolateDiscrete: 8 | """Interpolation of discrete models, for continuous predictions. 9 | There are two schemes: 10 | `const_hazard` and `exp_surv` which assumes pice-wise constant hazard in each interval (exponential survival). 11 | `const_pdf` and `lin_surv` which assumes pice-wise constant pmf in each interval (linear survival). 
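    Concretely, a sketch of the `const_pdf` scheme (matching `_surv_const_pdf` below):
    for t in [tau_j, tau_{j+1}),
        S(t) = S(tau_j) + (t - tau_j) / (tau_{j+1} - tau_j) * (S(tau_{j+1}) - S(tau_j)),
    evaluated on a grid with `sub` points per original interval.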
12 | 13 | Arguments: 14 | model {[type]} -- [description] 15 | 16 | Keyword Arguments: 17 | duration_index {np.array} -- Cuts used for discretization. Does not affect interpolation, 18 | only for setting index in `predict_surv_df` (default: {None}) 19 | scheme {str} -- Type of interpolation {'const_hazard', 'const_pdf'} (default: {'const_pdf'}) 20 | sub {int} -- Number of "sub" units in interpolation grid. If `sub` is 10 we have a grid with 21 | 10 times the number of grid points than the original `duration_index` (default: {10}). 22 | 23 | Keyword Arguments: 24 | """ 25 | def __init__(self, model, scheme='const_pdf', duration_index=None, sub=10, epsilon=1e-7): 26 | self.model = model 27 | self.scheme = scheme 28 | self.duration_index = duration_index 29 | self.sub = sub 30 | 31 | @property 32 | def sub(self): 33 | return self._sub 34 | 35 | @sub.setter 36 | def sub(self, sub): 37 | if type(sub) is not int: 38 | raise ValueError(f"Need `sub` to have type `int`, got {type(sub)}") 39 | self._sub = sub 40 | 41 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 42 | raise NotImplementedError 43 | 44 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 45 | raise NotImplementedError 46 | 47 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 48 | """Predict the survival function for `input`. 49 | See `prediction_surv_df` to return a DataFrame instead. 50 | 51 | Arguments: 52 | input {tuple, np.ndarray, or torch.tensor} -- Input to net. 53 | 54 | Keyword Arguments: 55 | batch_size {int} -- Batch size (default: {8224}) 56 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 57 | (default: {None}) 58 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 59 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 60 | (default: {False}) 61 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 62 | 63 | Returns: 64 | [np.ndarray or tensor] -- Predictions 65 | """ 66 | return self._surv_const_pdf(input, batch_size, numpy, eval_, to_cpu, num_workers) 67 | 68 | def _surv_const_pdf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, 69 | num_workers=0): 70 | """Basic method for constant PDF interpolation that use `self.model.predict_surv`. 71 | 72 | Arguments: 73 | input {tuple, np.ndarray, or torch.tensor} -- Input to net. 74 | 75 | Keyword Arguments: 76 | batch_size {int} -- Batch size (default: {8224}) 77 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 78 | (default: {None}) 79 | eval_ {bool} -- If 'True', use 'eval' mode on net. 
(default: {True}) 80 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 81 | (default: {False}) 82 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 83 | 84 | Returns: 85 | [np.ndarray or tensor] -- Predictions 86 | """ 87 | s = self.model.predict_surv(input, batch_size, False, eval_, to_cpu, num_workers) 88 | n, m = s.shape 89 | device = s.device 90 | diff = (s[:, 1:] - s[:, :-1]).contiguous().view(-1, 1).repeat(1, self.sub).view(n, -1) 91 | rho = torch.linspace(0, 1, self.sub+1, device=device)[:-1].contiguous().repeat(n, m-1) 92 | s_prev = s[:, :-1].contiguous().view(-1, 1).repeat(1, self.sub).view(n, -1) 93 | surv = torch.zeros(n, int((m-1)*self.sub + 1)) 94 | surv[:, :-1] = diff * rho + s_prev 95 | surv[:, -1] = s[:, -1] 96 | return tt.utils.array_or_tensor(surv, numpy, input) 97 | 98 | def predict_surv_df(self, input, batch_size=8224, eval_=True, to_cpu=False, num_workers=0): 99 | """Predict the survival function for `input` and return as a pandas DataFrame. 100 | See `predict_surv` to return tensor or np.array instead. 101 | 102 | Arguments: 103 | input {tuple, np.ndarray, or torch.tensor} -- Input to net. 104 | 105 | Keyword Arguments: 106 | batch_size {int} -- Batch size (default: {8224}) 107 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 108 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 109 | 110 | Returns: 111 | pd.DataFrame -- Predictions 112 | """ 113 | surv = self.predict_surv(input, batch_size, True, eval_, to_cpu, num_workers) 114 | index = None 115 | if self.duration_index is not None: 116 | index = utils.make_subgrid(self.duration_index, self.sub) 117 | return pd.DataFrame(surv.transpose(), index) 118 | 119 | 120 | class InterpolatePMF(InterpolateDiscrete): 121 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 122 | if not self.scheme in ['const_pdf', 'lin_surv']: 123 | raise NotImplementedError 124 | pmf = self.model.predict_pmf(input, batch_size, False, eval_, to_cpu, num_workers) 125 | n, m = pmf.shape 126 | pmf_cdi = pmf[:, 1:].contiguous().view(-1, 1).repeat(1, self.sub).div(self.sub).view(n, -1) 127 | pmf_cdi = utils.pad_col(pmf_cdi, where='start') 128 | pmf_cdi[:, 0] = pmf[:, 0] 129 | return tt.utils.array_or_tensor(pmf_cdi, numpy, input) 130 | 131 | def _surv_const_pdf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 132 | pmf = self.predict_pmf(input, batch_size, False, eval_, to_cpu, num_workers) 133 | surv = 1 - pmf.cumsum(1) 134 | return tt.utils.array_or_tensor(surv, numpy, input) 135 | 136 | 137 | class InterpolateLogisticHazard(InterpolateDiscrete): 138 | epsilon = 1e-7 139 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 140 | if self.scheme in ['const_hazard', 'exp_surv']: 141 | haz = self._hazard_const_haz(input, batch_size, numpy, eval_, to_cpu, num_workers) 142 | else: 143 | raise NotImplementedError 144 | return haz 145 | 146 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 147 | if self.scheme in ['const_hazard', 'exp_surv']: 148 | surv = self._surv_const_haz(input, batch_size, numpy, eval_, to_cpu, num_workers) 149 | elif self.scheme in ['const_pdf', 'lin_surv']: 150 | surv = self._surv_const_pdf(input, batch_size, numpy, eval_, to_cpu, num_workers) 151 | else: 152 | raise NotImplementedError 153 | return surv 154 | 155 | def _hazard_const_haz(self, 
input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, 156 | num_workers=0): 157 | """Computes the continuous-time constant hazard interpolation. 158 | Essentially we want the discrete survival estimates to match the continuous-time estimates at the knots. 159 | That is, we want 160 | $$S(tau_j) = prod_{k=1}^j [1 - h_k] = prod_{k=1}^j exp[-eta_k],$$ 161 | where $h_k$ are the discrete hazard estimates and $eta_k$ the continuous-time hazards multiplied 162 | by the length of the duration interval, as they are defined for the PC-Hazard method. 163 | Thus we get 164 | $$eta_k = - log[1 - h_k]$$ 165 | which can be divided by the length of the time interval to get the continuous-time hazards. 166 | """ 167 | haz_orig = self.model.predict_hazard(input, batch_size, False, eval_, to_cpu, num_workers) 168 | haz = (1 - haz_orig).add(self.epsilon).log().mul(-1).relu()[:, 1:].contiguous() 169 | n = haz.shape[0] 170 | haz = haz.view(-1, 1).repeat(1, self.sub).view(n, -1).div(self.sub) 171 | haz = utils.pad_col(haz, where='start') 172 | haz[:, 0] = haz_orig[:, 0] 173 | return tt.utils.array_or_tensor(haz, numpy, input) 174 | 175 | def _surv_const_haz(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 176 | haz = self._hazard_const_haz(input, batch_size, False, eval_, to_cpu, num_workers) 177 | surv_0 = 1 - haz[:, :1] 178 | surv = utils.pad_col(haz[:, 1:], where='start').cumsum(1).mul(-1).exp().mul(surv_0) 179 | return tt.utils.array_or_tensor(surv, numpy, input) 180 | -------------------------------------------------------------------------------- /pycox/models/logistic_hazard.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | import torchtuples as tt 6 | from pycox import models 7 | from pycox.models.utils import pad_col, make_subgrid 8 | from pycox.preprocessing import label_transforms 9 | from pycox.models.interpolation import InterpolateLogisticHazard 10 | 11 | class LogisticHazard(models.base.SurvBase): 12 | """ 13 | A discrete-time survival model that minimizes the likelihood for right-censored data by 14 | parameterizing the hazard function. Also known as "Nnet-survival" [3]. 15 | 16 | The Logistic-Hazard was first proposed by [2], but this implementation follows [1]. 17 | 18 | Arguments: 19 | net {torch.nn.Module} -- A torch module. 20 | 21 | Keyword Arguments: 22 | optimizer {Optimizer} -- A torch optimizer or similar. Preferably use torchtuples.optim instead of 23 | torch.optim, as this allows for reinitialization, etc. If 'None' set to torchtuples.optim.AdamW. 24 | (default: {None}) 25 | device {str, int, torch.device} -- Device to compute on. (default: {None}) 26 | Preferably pass a torch.device object. 27 | If 'None': use default gpu if available, else use cpu. 28 | If 'int': use that gpu: torch.device('cuda:'). 29 | If 'string': string is passed to torch.device('string'). 30 | duration_index {list, np.array} -- Array of durations that defines the discrete times. 31 | This is used to set the index of the DataFrame in `predict_surv_df`. 32 | 33 | References: 34 | [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction 35 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. 36 | https://arxiv.org/pdf/1910.06724.pdf 37 | 38 | [2] Charles C. Brown. On the use of indicator variables for studying the time-dependence of parameters 39 | in a response-time model. Biometrics, 31(4):863–872, 1975.
40 | https://www.jstor.org/stable/2529811?seq=1#metadata_info_tab_contents 41 | 42 | [3] Michael F. Gensheimer and Balasubramanian Narasimhan. A scalable discrete-time survival model for 43 | neural networks. PeerJ, 7:e6257, 2019. 44 | https://peerj.com/articles/6257/ 45 | """ 46 | label_transform = label_transforms.LabTransDiscreteTime 47 | 48 | def __init__(self, net, optimizer=None, device=None, duration_index=None, loss=None): 49 | self.duration_index = duration_index 50 | if loss is None: 51 | loss = models.loss.NLLLogistiHazardLoss() 52 | super().__init__(net, loss, optimizer, device) 53 | 54 | @property 55 | def duration_index(self): 56 | """ 57 | Array of durations that defines the discrete times. This is used to set the index 58 | of the DataFrame in `predict_surv_df`. 59 | 60 | Returns: 61 | np.array -- Duration index. 62 | """ 63 | return self._duration_index 64 | 65 | @duration_index.setter 66 | def duration_index(self, val): 67 | self._duration_index = val 68 | 69 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0): 70 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers) 71 | return pd.DataFrame(surv.transpose(), self.duration_index) 72 | 73 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, 74 | num_workers=0, epsilon=1e-7): 75 | hazard = self.predict_hazard(input, batch_size, False, eval_, to_cpu, num_workers) 76 | surv = (1 - hazard).add(epsilon).log().cumsum(1).exp() 77 | return tt.utils.array_or_tensor(surv, numpy, input) 78 | 79 | 80 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, 81 | num_workers=0): 82 | hazard = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers).sigmoid() 83 | return tt.utils.array_or_tensor(hazard, numpy, input) 84 | 85 | def interpolate(self, sub=10, scheme='const_pdf', duration_index=None): 86 | """Use interpolation for predictions. 87 | There are two schemes: 88 | `const_hazard` and `exp_surv` which assumes pice-wise constant hazard in each interval (exponential survival). 89 | `const_pdf` and `lin_surv` which assumes pice-wise constant PMF in each interval (linear survival). 90 | 91 | Keyword Arguments: 92 | sub {int} -- Number of "sub" units in interpolation grid. If `sub` is 10 we have a grid with 93 | 10 times the number of grid points than the original `duration_index` (default: {10}). 94 | scheme {str} -- Type of interpolation {'const_hazard', 'const_pdf'}. 95 | See `InterpolateDiscrete` (default: {'const_pdf'}) 96 | duration_index {np.array} -- Cuts used for discretization. Does not affect interpolation, 97 | only for setting index in `predict_surv_df` (default: {None}) 98 | 99 | Returns: 100 | [InterpolateLogisticHazard] -- Object for prediction with interpolation. 101 | """ 102 | if duration_index is None: 103 | duration_index = self.duration_index 104 | return InterpolateLogisticHazard(self, scheme, duration_index, sub) 105 | -------------------------------------------------------------------------------- /pycox/models/mtlr.py: -------------------------------------------------------------------------------- 1 | 2 | from pycox import models 3 | import torchtuples as tt 4 | from pycox.models import utils 5 | 6 | class MTLR(models.pmf.PMFBase): 7 | """ 8 | The (Neural) Multi-Task Logistic Regression, MTLR [1] and N-MTLR [2]. 9 | A discrete-time survival model that minimize the likelihood for right-censored data. 
10 | 11 | This is essentially a PMF parametrization with an extra cumulative sum, as explained in [3]. 12 | 13 | Arguments: 14 | net {torch.nn.Module} -- A torch module. 15 | 16 | Keyword Arguments: 17 | optimizer {Optimizer} -- A torch optimizer or similar. Preferably use torchtuples.optim instead of 18 | torch.optim, as this allows for reinitialization, etc. If 'None' set to torchtuples.optim.AdamW. 19 | (default: {None}) 20 | device {str, int, torch.device} -- Device to compute on. (default: {None}) 21 | Preferably pass a torch.device object. 22 | If 'None': use default gpu if available, else use cpu. 23 | If 'int': used that gpu: torch.device('cuda:'). 24 | If 'string': string is passed to torch.device('string'). 25 | duration_index {list, np.array} -- Array of durations that defines the discrete times. 26 | This is used to set the index of the DataFrame in `predict_surv_df`. 27 | 28 | References: 29 | [1] Chun-Nam Yu, Russell Greiner, Hsiu-Chin Lin, and Vickie Baracos. 30 | Learning patient- specific cancer survival distributions as a sequence of dependent regressors. 31 | In Advances in Neural Information Processing Systems 24, pages 1845–1853. 32 | Curran Associates, Inc., 2011. 33 | https://papers.nips.cc/paper/4210-learning-patient-specific-cancer-survival-distributions-as-a-sequence-of-dependent-regressors.pdf 34 | 35 | [2] Stephane Fotso. Deep neural networks for survival analysis based on a multi-task framework. 36 | arXiv preprint arXiv:1801.05512, 2018. 37 | https://arxiv.org/pdf/1801.05512.pdf 38 | 39 | [3] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction 40 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. 41 | https://arxiv.org/pdf/1910.06724.pdf 42 | """ 43 | def __init__(self, net, optimizer=None, device=None, duration_index=None, loss=None): 44 | if loss is None: 45 | loss = models.loss.NLLMTLRLoss() 46 | super().__init__(net, loss, optimizer, device, duration_index) 47 | 48 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 49 | preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers) 50 | preds = utils.cumsum_reverse(preds, dim=1) 51 | pmf = utils.pad_col(preds).softmax(1)[:, :-1] 52 | return tt.utils.array_or_tensor(pmf, numpy, input) 53 | -------------------------------------------------------------------------------- /pycox/models/pc_hazard.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import pandas as pd 3 | import torch 4 | import torch.nn.functional as F 5 | import torchtuples as tt 6 | from pycox import models 7 | from pycox.models.utils import pad_col, make_subgrid 8 | from pycox.preprocessing import label_transforms 9 | 10 | class PCHazard(models.base.SurvBase): 11 | """The PC-Hazard (piecewise constant hazard) method from [1]. 12 | The Piecewise Constant Hazard (PC-Hazard) model from [1] which assumes that the continuous-time 13 | hazard function is constant in a set of predefined intervals. It is similar to the Piecewise 14 | Exponential Models [2] but with a softplus activation instead of the exponential function. 15 | 16 | Note that the label_transform is slightly different than that of the LogistcHazard and PMF methods. 17 | This typically results in one less output node. 18 | 19 | References: 20 | [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction 21 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. 
22 | https://arxiv.org/pdf/1910.06724.pdf 23 | 24 | [2] Michael Friedman. Piecewise exponential models for survival data with covariates. 25 | The Annals of Statistics, 10(1):101–113, 1982. 26 | https://projecteuclid.org/euclid.aos/1176345693 27 | """ 28 | label_transform = label_transforms.LabTransPCHazard 29 | 30 | def __init__(self, net, optimizer=None, device=None, duration_index=None, sub=1, loss=None): 31 | self.duration_index = duration_index 32 | self.sub = sub 33 | if loss is None: 34 | loss = models.loss.NLLPCHazardLoss() 35 | super().__init__(net, loss, optimizer, device) 36 | if self.duration_index is not None: 37 | self._check_out_features() 38 | 39 | @property 40 | def sub(self): 41 | return self._sub 42 | 43 | @sub.setter 44 | def sub(self, sub): 45 | if type(sub) is not int: 46 | raise ValueError(f"Need `sub` to have type `int`, got {type(sub)}") 47 | self._sub = sub 48 | 49 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 50 | hazard = self.predict_hazard(input, batch_size, False, eval_, to_cpu, num_workers) 51 | surv = hazard.cumsum(1).mul(-1).exp() 52 | return tt.utils.array_or_tensor(surv, numpy, input) 53 | 54 | def predict_hazard(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, num_workers=0): 55 | """Predict the hazard function for `input`. 56 | 57 | Arguments: 58 | input {tuple, np.ndarra, or torch.tensor} -- Input to net. 59 | 60 | Keyword Arguments: 61 | batch_size {int} -- Batch size (default: {8224}) 62 | numpy {bool} -- 'False' gives tensor, 'True' gives numpy, and None give same as input 63 | (default: {None}) 64 | eval_ {bool} -- If 'True', use 'eval' mode on net. (default: {True}) 65 | to_cpu {bool} -- For larger data sets we need to move the results to cpu 66 | (default: {False}) 67 | num_workers {int} -- Number of workers in created dataloader (default: {0}) 68 | 69 | Returns: 70 | [np.ndarray or tensor] -- Predicted hazards 71 | """ 72 | preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers) 73 | n = preds.shape[0] 74 | hazard = F.softplus(preds).view(-1, 1).repeat(1, self.sub).view(n, -1).div(self.sub) 75 | hazard = pad_col(hazard, where='start') 76 | return tt.utils.array_or_tensor(hazard, numpy, input) 77 | 78 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0): 79 | self._check_out_features() 80 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers) 81 | index = None 82 | if self.duration_index is not None: 83 | index = make_subgrid(self.duration_index, self.sub) 84 | return pd.DataFrame(surv.transpose(), index) 85 | 86 | def fit(self, input, target, batch_size=256, epochs=1, callbacks=None, verbose=True, 87 | num_workers=0, shuffle=True, metrics=None, val_data=None, val_batch_size=8224, 88 | check_out_features=True, **kwargs): 89 | if check_out_features: 90 | self._check_out_features(target) 91 | return super().fit(input, target, batch_size, epochs, callbacks, verbose, num_workers, 92 | shuffle, metrics, val_data, val_batch_size, **kwargs) 93 | 94 | def fit_dataloader(self, dataloader, epochs=1, callbacks=None, verbose=True, metrics=None, 95 | val_dataloader=None, check_out_features=True): 96 | if check_out_features: 97 | self._check_out_features() 98 | return super().fit_dataloader(dataloader, epochs, callbacks, verbose, metrics, val_dataloader) 99 | 100 | def _check_out_features(self, target=None): 101 | last = list(self.net.modules())[-1] 102 | if hasattr(last, 'out_features'): 103 | m_output = 
last.out_features 104 | if self.duration_index is not None: 105 | n_grid = len(self.duration_index) 106 | if n_grid == m_output: 107 | raise ValueError("Output of `net` is one too large. Should have length "+ 108 | f"{len(self.duration_index)-1}") 109 | if n_grid != (m_output + 1): 110 | raise ValueError(f"Output of `net` does not correspond with `duration_index`") 111 | if target is not None: 112 | max_idx = tt.tuplefy(target).to_numpy()[0].max() 113 | if m_output != (max_idx + 1): 114 | raise ValueError(f"Output of `net` is {m_output}, but the data only contains {max_idx + 1} duration indices. "+ 115 | f"Output of `net` should be {max_idx + 1}."+ 116 | "Set `check_out_features=False` to suppress this error.") 117 | -------------------------------------------------------------------------------- /pycox/models/pmf.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torchtuples as tt 3 | from pycox import models 4 | from pycox.models.utils import pad_col 5 | from pycox.preprocessing import label_transforms 6 | from pycox.models.interpolation import InterpolatePMF 7 | 8 | 9 | class PMFBase(models.base.SurvBase): 10 | """Base class for PMF methods. 11 | """ 12 | label_transform = label_transforms.LabTransDiscreteTime 13 | 14 | def __init__(self, net, loss=None, optimizer=None, device=None, duration_index=None): 15 | self.duration_index = duration_index 16 | super().__init__(net, loss, optimizer, device) 17 | 18 | @property 19 | def duration_index(self): 20 | """ 21 | Array of durations that defines the discrete times. This is used to set the index 22 | of the DataFrame in `predict_surv_df`. 23 | 24 | Returns: 25 | np.array -- Duration index. 26 | """ 27 | return self._duration_index 28 | 29 | @duration_index.setter 30 | def duration_index(self, val): 31 | self._duration_index = val 32 | 33 | def predict_surv(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, 34 | num_workers=0): 35 | pmf = self.predict_pmf(input, batch_size, False, eval_, to_cpu, num_workers) 36 | surv = 1 - pmf.cumsum(1) 37 | return tt.utils.array_or_tensor(surv, numpy, input) 38 | 39 | def predict_pmf(self, input, batch_size=8224, numpy=None, eval_=True, to_cpu=False, 40 | num_workers=0): 41 | preds = self.predict(input, batch_size, False, eval_, False, to_cpu, num_workers) 42 | pmf = pad_col(preds).softmax(1)[:, :-1] 43 | return tt.utils.array_or_tensor(pmf, numpy, input) 44 | 45 | def predict_surv_df(self, input, batch_size=8224, eval_=True, num_workers=0): 46 | surv = self.predict_surv(input, batch_size, True, eval_, True, num_workers) 47 | return pd.DataFrame(surv.transpose(), self.duration_index) 48 | 49 | def interpolate(self, sub=10, scheme='const_pdf', duration_index=None): 50 | """Use interpolation for predictions. 51 | There is only one scheme: 52 | `const_pdf` and `lin_surv` which assumes piece-wise constant PMF in each interval (linear survival). 53 | 54 | Keyword Arguments: 55 | sub {int} -- Number of "sub" units in interpolation grid. If `sub` is 10 we have a grid with 56 | 10 times as many grid points as the original `duration_index` (default: {10}). 57 | scheme {str} -- Type of interpolation {'const_hazard', 'const_pdf'}. 58 | See `InterpolateDiscrete` (default: {'const_pdf'}) 59 | duration_index {np.array} -- Cuts used for discretization. Does not affect interpolation, 60 | only for setting index in `predict_surv_df` (default: {None}) 61 | 62 | Returns: 63 | [InterpolatePMF] -- Object for prediction with interpolation.
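        Example (an editorial sketch, not part of the original docstring; assumes a fitted `PMF` model `model` and test covariates `x_test`):
            cdi = model.interpolate(sub=10, scheme='const_pdf')
            surv = cdi.predict_surv_df(x_test)  # survival curves on a time grid 10 times finer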
64 | """ 65 | if duration_index is None: 66 | duration_index = self.duration_index 67 | return InterpolatePMF(self, scheme, duration_index, sub) 68 | 69 | 70 | class PMF(PMFBase): 71 | """ 72 | The PMF is a discrete-time survival model that parametrize the probability mass function (PMF) 73 | and optimizer the survival likelihood. It is the foundation of methods such as DeepHit and MTLR. 74 | See [1] for details. 75 | 76 | Arguments: 77 | net {torch.nn.Module} -- A torch module. 78 | 79 | Keyword Arguments: 80 | optimizer {Optimizer} -- A torch optimizer or similar. Preferably use torchtuples.optim instead of 81 | torch.optim, as this allows for reinitialization, etc. If 'None' set to torchtuples.optim.AdamW. 82 | (default: {None}) 83 | device {str, int, torch.device} -- Device to compute on. (default: {None}) 84 | Preferably pass a torch.device object. 85 | If 'None': use default gpu if available, else use cpu. 86 | If 'int': used that gpu: torch.device('cuda:'). 87 | If 'string': string is passed to torch.device('string'). 88 | duration_index {list, np.array} -- Array of durations that defines the discrete times. 89 | This is used to set the index of the DataFrame in `predict_surv_df`. 90 | 91 | References: 92 | [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction 93 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. 94 | https://arxiv.org/pdf/1910.06724.pdf 95 | """ 96 | def __init__(self, net, optimizer=None, device=None, duration_index=None, loss=None): 97 | if loss is None: 98 | loss = models.loss.NLLPMFLoss() 99 | super().__init__(net, loss, optimizer, device, duration_index) 100 | 101 | -------------------------------------------------------------------------------- /pycox/models/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torchtuples as tt 6 | 7 | def pad_col(input, val=0, where='end'): 8 | """Addes a column of `val` at the start of end of `input`.""" 9 | if len(input.shape) != 2: 10 | raise ValueError(f"Only works for `phi` tensor that is 2-D.") 11 | pad = torch.zeros_like(input[:, :1]) 12 | if val != 0: 13 | pad = pad + val 14 | if where == 'end': 15 | return torch.cat([input, pad], dim=1) 16 | elif where == 'start': 17 | return torch.cat([pad, input], dim=1) 18 | raise ValueError(f"Need `where` to be 'start' or 'end', got {where}") 19 | 20 | def array_or_tensor(tensor, numpy, input): 21 | warnings.warn('Use `torchtuples.utils.array_or_tensor` instead', DeprecationWarning) 22 | return tt.utils.array_or_tensor(tensor, numpy, input) 23 | 24 | def make_subgrid(grid, sub=1): 25 | """When calling `predict_surv` with sub != 1 this can help with 26 | creating the duration index of the survival estimates. 27 | 28 | E.g. 29 | sub = 5 30 | surv = model.predict_surv(test_input, sub=sub) 31 | grid = model.make_subgrid(cuts, sub) 32 | surv = pd.DataFrame(surv, index=grid) 33 | """ 34 | subgrid = tt.TupleTree(np.linspace(start, end, num=sub+1)[:-1] 35 | for start, end in zip(grid[:-1], grid[1:])) 36 | subgrid = subgrid.apply(lambda x: tt.TupleTree(x)).flatten() + (grid[-1],) 37 | return subgrid 38 | 39 | def log_softplus(input, threshold=-15.): 40 | """Equivalent to 'F.softplus(input).log()', but for 'input < threshold', 41 | we return 'input', as this is approximately the same. 
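    (Editorial illustration: for input = -20, softplus(-20) = log(1 + e^-20) is roughly 2.1e-9, and log(2.1e-9) is roughly -20.0, so returning the input directly is an excellent approximation.)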
42 | 43 | Arguments: 44 | input {torch.tensor} -- Input tensor 45 | 46 | Keyword Arguments: 47 | threshold {float} -- Threshold for when to just return input (default: {-15.}) 48 | 49 | Returns: 50 | torch.tensor -- Returns log(softplus(input)). 51 | """ 52 | output = input.clone() 53 | above = input >= threshold 54 | output[above] = F.softplus(input[above]).log() 55 | return output 56 | 57 | def cumsum_reverse(input: torch.Tensor, dim: int = 1) -> torch.Tensor: 58 | if dim != 1: 59 | raise NotImplementedError 60 | input = input.sum(1, keepdim=True) - pad_col(input, where='start').cumsum(1) 61 | return input[:, :-1] 62 | -------------------------------------------------------------------------------- /pycox/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from pycox.preprocessing import feature_transforms, label_transforms, discretization 2 | -------------------------------------------------------------------------------- /pycox/preprocessing/discretization.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import pandas as pd 4 | from pycox import utils 5 | 6 | 7 | def make_cuts(n_cuts, scheme, durations, events, min_=0., dtype='float64'): 8 | if scheme == 'equidistant': 9 | cuts = cuts_equidistant(durations.max(), n_cuts, min_, dtype) 10 | elif scheme == 'quantiles': 11 | cuts = cuts_quantiles(durations, events, n_cuts, min_, dtype) 12 | else: 13 | raise ValueError(f"Got invalid `scheme` {scheme}.") 14 | if (np.diff(cuts) == 0).any(): 15 | raise ValueError("cuts are not unique.") 16 | return cuts 17 | 18 | def _values_if_series(x): 19 | if type(x) is pd.Series: 20 | return x.values 21 | return x 22 | 23 | def cuts_equidistant(max_, num, min_=0., dtype='float64'): 24 | return np.linspace(min_, max_, num, dtype=dtype) 25 | 26 | def cuts_quantiles(durations, events, num, min_=0., dtype='float64'): 27 | """ 28 | If min_ = None, we will use durations.min() for the first cut. 29 | """ 30 | km = utils.kaplan_meier(durations, events) 31 | surv_est, surv_durations = km.values, km.index.values 32 | s_cuts = np.linspace(km.values.min(), km.values.max(), num) 33 | cuts_idx = np.searchsorted(surv_est[::-1], s_cuts)[::-1] 34 | cuts = surv_durations[::-1][cuts_idx] 35 | cuts = np.unique(cuts) 36 | if len(cuts) != num: 37 | warnings.warn(f"cuts are not unique, continue with {len(cuts)} cuts instead of {num}") 38 | cuts[0] = durations.min() if min_ is None else min_ 39 | assert cuts[-1] == durations.max(), 'something wrong...' 40 | return cuts.astype(dtype) 41 | 42 | def _is_monotonic_increasing(x): 43 | assert len(x.shape) == 1, 'Only works for 1d' 44 | return (x[1:] >= x[:-1]).all() 45 | 46 | def bin_numerical(x, right_cuts, error_on_larger=False): 47 | """ 48 | Discretize x into bins defined by right_cuts (needs to be sorted). 49 | If right_cuts = [1, 2], we have bins (-inf, 1], (1, 2], (2, inf). 50 | error_on_larger results in a ValueError if x contains larger 51 | values than right_cuts. 52 | 53 | Returns index of bins. 54 | To obtain values, do right_cuts[bin_numerical(x, right_cuts)]. 55 | """ 56 | assert _is_monotonic_increasing(right_cuts), 'Need `right_cuts` to be sorted.'
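    # Editor's note (illustrative values, not part of the library): with right_cuts = np.array([1., 2.]),
    # bin_numerical(np.array([0.5, 1.0, 1.5, 3.0]), right_cuts) gives array([0, 0, 1, 2]), matching the
    # bins (-inf, 1], (1, 2], (2, inf) described in the docstring above.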
57 | bins = np.searchsorted(right_cuts, x, side='left') 58 | if bins.max() == right_cuts.size: 59 | if error_on_larger: 60 | raise ValueError('x contrains larger values than right_cuts.') 61 | return bins 62 | 63 | def discretize(x, cuts, side='right', error_on_larger=False): 64 | """Discretize x to cuts. 65 | 66 | Arguments: 67 | x {np.array} -- Array of times. 68 | cuts {np.array} -- Sortet array of discrete times. 69 | 70 | Keyword Arguments: 71 | side {str} -- If we shold round down or up (left, right) (default: {'right'}) 72 | error_on_larger {bool} -- If we shold return an error if we pass higher values 73 | than cuts (default: {False}) 74 | 75 | Returns: 76 | np.array -- Discretized values. 77 | """ 78 | if side not in ['right', 'left']: 79 | raise ValueError('side argument needs to be right or left.') 80 | bins = bin_numerical(x, cuts, error_on_larger) 81 | if side == 'right': 82 | cuts = np.concatenate((cuts, np.array([np.inf]))) 83 | return cuts[bins] 84 | bins_cut = bins.copy() 85 | bins_cut[bins_cut == cuts.size] = -1 86 | exact = cuts[bins_cut] == x 87 | left_bins = bins - 1 + exact 88 | vals = cuts[left_bins] 89 | vals[left_bins == -1] = - np.inf 90 | return vals 91 | 92 | 93 | class _OnlyTransform: 94 | """Abstract class for sklearn preprocessing methods. 95 | Only implements fit and fit_transform. 96 | """ 97 | def fit(self, *args): 98 | return self 99 | 100 | def transform(self, *args): 101 | raise NotImplementedError 102 | 103 | def fit_transform(self, *args): 104 | return self.fit(*args).transform(*args) 105 | 106 | 107 | class DiscretizeUnknownC(_OnlyTransform): 108 | """Implementation of scheme 2. 109 | 110 | cuts should be [t0, t1, ..., t_m], where t_m is right sensored value. 111 | """ 112 | def __init__(self, cuts, right_censor=False, censor_side='left'): 113 | self.cuts = cuts 114 | self.right_censor = right_censor 115 | self.censor_side = censor_side 116 | 117 | def transform(self, duration, event): 118 | dtype_event = event.dtype 119 | event = event.astype('bool') 120 | if self.right_censor: 121 | duration = duration.copy() 122 | censor = duration > self.cuts.max() 123 | duration[censor] = self.cuts.max() 124 | event[censor] = False 125 | if duration.max() > self.cuts.max(): 126 | raise ValueError("`duration` contains larger values than cuts. 
Set `right_censor`=True to censor these") 127 | td = np.zeros_like(duration) 128 | c = event == False 129 | td[event] = discretize(duration[event], self.cuts, side='right', error_on_larger=True) 130 | if c.any(): 131 | td[c] = discretize(duration[c], self.cuts, side=self.censor_side, error_on_larger=True) 132 | return td, event.astype(dtype_event) 133 | 134 | 135 | def duration_idx_map(duration): 136 | duration = np.unique(duration) 137 | duration = np.sort(duration) 138 | idx = np.arange(duration.shape[0]) 139 | return {d: i for i, d in zip(idx, duration)} 140 | 141 | 142 | class Duration2Idx(_OnlyTransform): 143 | def __init__(self, durations=None): 144 | self.durations = durations 145 | if durations is None: 146 | raise NotImplementedError() 147 | if self.durations is not None: 148 | self.duration_to_idx = self._make_map(self.durations) 149 | 150 | @staticmethod 151 | def _make_map(durations): 152 | return np.vectorize(duration_idx_map(durations).get) 153 | 154 | def transform(self, duration, y=None): 155 | if duration.dtype is not self.durations.dtype: 156 | raise ValueError('Need `time` to have same type as `self.durations`.') 157 | idx = self.duration_to_idx(duration) 158 | if np.isnan(idx).any(): 159 | raise ValueError('Encountered `nans` in transformed indexes.') 160 | return idx 161 | 162 | 163 | class IdxDiscUnknownC: 164 | """Get indexed for discrete data using cuts. 165 | 166 | Arguments: 167 | cuts {np.array} -- Array or right cuts. 168 | 169 | Keyword Arguments: 170 | label_cols {tuple} -- Name of label columns in dataframe (default: {None}). 171 | """ 172 | def __init__(self, cuts, label_cols=None, censor_side='left'): 173 | self.cuts = cuts 174 | self.duc = DiscretizeUnknownC(cuts, right_censor=True, censor_side=censor_side) 175 | self.di = Duration2Idx(cuts) 176 | self.label_cols= label_cols 177 | 178 | def transform(self, time, d): 179 | time, d = self.duc.transform(time, d) 180 | idx = self.di.transform(time) 181 | return idx, d 182 | 183 | def transform_df(self, df): 184 | if self.label_cols is None: 185 | raise RuntimeError("Need to set 'label_cols' to use this. Use 'transform instead'") 186 | col_duration, col_event = self.label_cols 187 | time = df[col_duration].values 188 | d = df[col_event].values 189 | return self.transform(time, d) 190 | -------------------------------------------------------------------------------- /pycox/preprocessing/feature_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | class OrderedCategoricalLong: 5 | """Transform pandas series or numpy array to categorical, and get (long) values, 6 | i.e. index of category. Useful for entity embeddings. 7 | Zero is reserved for unknown categories or nans. 8 | 9 | Keyword Arguments: 10 | min_per_category {int} -- Number of instances required to not be set to nan (default: {20}) 11 | return_series {bool} -- If return a array or pd.Series (default: {False}) 12 | 13 | Returns: 14 | [pd.Series] -- Series with long values reffering to categories. 
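    Example (an editorial sketch; `df_train['city']` and `df_test['city']` are assumed pandas columns, not part of the library):
        encoder = OrderedCategoricalLong(min_per_category=20)
        codes_train = encoder.fit_transform(df_train['city'])  # int64 codes, 0 reserved for rare/unseen/nan
        codes_test = encoder.transform(df_test['city'])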
15 | """ 16 | def __init__(self, min_per_category=20, return_series=False): 17 | 18 | self.min_per_category = min_per_category 19 | self.return_series = return_series 20 | 21 | def fit(self, series, y=None): 22 | series = pd.Series(series).copy() 23 | smaller = series.value_counts() < self.min_per_category 24 | values = smaller[smaller].index.values 25 | for v in values: 26 | series[series == v] = np.nan 27 | self.categories = series.astype('category').cat.categories 28 | return self 29 | 30 | def transform(self, series, y=None): 31 | series = pd.Series(series).copy() 32 | transformed = pd.Categorical(series, categories=self.categories, ordered=True) 33 | transformed = pd.Series(transformed, index=series.index) 34 | transformed = transformed.cat.codes.astype('int64') + 1 35 | return transformed if self.return_series else transformed.values 36 | 37 | def fit_transform(self, series, y=None): 38 | return self.fit(series, y).transform(series, y) 39 | -------------------------------------------------------------------------------- /pycox/preprocessing/label_transforms.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | from sklearn.preprocessing import StandardScaler 4 | from pycox.preprocessing.discretization import (make_cuts, IdxDiscUnknownC, _values_if_series, 5 | DiscretizeUnknownC, Duration2Idx) 6 | 7 | 8 | class LabTransCoxTime: 9 | """ 10 | Label transforms useful for CoxTime models. It can log-transform and standardize the durations. 11 | 12 | It also creates `map_scaled_to_orig` which is the inverse transform of the durations data, 13 | enabling us to set the correct time scale for predictions. 14 | This can be done by passing the object to the CoxTime init: 15 | model = CoxTime(net, labrans=labtrans) 16 | which gives the correct time scale of survival predictions 17 | surv = model.predict_surv_df(x) 18 | 19 | Keyword Arguments: 20 | log_duration {bool} -- Log-transform durations, i.e. 'log(1+x)'. (default: {False}) 21 | with_mean {bool} -- Center the duration before scaling. 22 | Passed to `sklearn.preprocessing.StandardScaler` (default: {True}) 23 | with_std {bool} -- Scale duration to unit variance. 24 | Passed to `sklearn.preprocessing.StandardScaler` (default: {True}) 25 | """ 26 | def __init__(self, log_duration=False, with_mean=True, with_std=True): 27 | self.log_duration = log_duration 28 | self.duration_scaler = StandardScaler(copy=True, with_mean=with_mean, with_std=with_std) 29 | 30 | @property 31 | def map_scaled_to_orig(self): 32 | """Map from transformed durations back to the original durations, i.e. inverse transform. 33 | 34 | Use it to e.g. 
set index of survival predictions: 35 | surv = model.predict_surv_df(x_test) 36 | surv.index = labtrans.map_scaled_to_orig(surv.index) 37 | """ 38 | if not hasattr(self, '_inverse_duration_map'): 39 | raise ValueError('Need to fit the models before you can call this method') 40 | return self._inverse_duration_map 41 | 42 | def fit(self, durations, events): 43 | self.fit_transform(durations, events) 44 | return self 45 | 46 | def fit_transform(self, durations, events): 47 | train_durations = durations 48 | durations = durations.astype('float32') 49 | events = events.astype('float32') 50 | if self.log_duration: 51 | durations = np.log1p(durations) 52 | durations = self.duration_scaler.fit_transform(durations.reshape(-1, 1)).flatten() 53 | self._inverse_duration_map = {scaled: orig for orig, scaled in zip(train_durations, durations)} 54 | self._inverse_duration_map = np.vectorize(self._inverse_duration_map.get) 55 | return durations, events 56 | 57 | def transform(self, durations, events): 58 | durations = durations.astype('float32') 59 | events = events.astype('float32') 60 | if self.log_duration: 61 | durations = np.log1p(durations) 62 | durations = self.duration_scaler.transform(durations.reshape(-1, 1)).flatten() 63 | return durations, events 64 | 65 | @property 66 | def out_features(self): 67 | """Returns the number of output features that should be used in the torch model. 68 | This always returns 1, and is just included for api design purposes. 69 | 70 | Returns: 71 | [int] -- Number of output features. 72 | """ 73 | return 1 74 | 75 | 76 | class LabTransDiscreteTime: 77 | """ 78 | Discretize continuous (duration, event) pairs based on a set of cut points. 79 | One can either determine the cut points in form of passing an array to this class, 80 | or one can obtain cut points based on the training data. 81 | 82 | The discretization learned from fitting to data will move censorings to the left cut point, 83 | and events to right cut point. 84 | 85 | Arguments: 86 | cuts {int, array} -- Defining cut points, either the number of cuts, or the actual cut points. 87 | 88 | Keyword Arguments: 89 | scheme {str} -- Scheme used for discretization. Either 'equidistant' or 'quantiles' 90 | (default: {'equidistant}) 91 | min_ {float} -- Starting duration (default: {0.}) 92 | dtype {str, dtype} -- dtype of discretization. 93 | """ 94 | def __init__(self, cuts, scheme='equidistant', min_=0., dtype=None): 95 | self._cuts = cuts 96 | self._scheme = scheme 97 | self._min = min_ 98 | self._dtype_init = dtype 99 | self._predefined_cuts = False 100 | self.cuts = None 101 | if hasattr(cuts, '__iter__'): 102 | if type(cuts) is list: 103 | cuts = np.array(cuts) 104 | self.cuts = cuts 105 | self.idu = IdxDiscUnknownC(self.cuts) 106 | assert dtype is None, "Need `dtype` to be `None` for specified cuts" 107 | self._dtype = type(self.cuts[0]) 108 | self._dtype_init = self._dtype 109 | self._predefined_cuts = True 110 | 111 | def fit(self, durations, events): 112 | if self._predefined_cuts: 113 | warnings.warn("Calling fit method, when 'cuts' are already defined. 
Leaving cuts unchanged.") 114 | return self 115 | self._dtype = self._dtype_init 116 | if self._dtype is None: 117 | if isinstance(durations[0], np.floating): 118 | self._dtype = durations.dtype 119 | else: 120 | self._dtype = np.dtype('float64') 121 | durations = durations.astype(self._dtype) 122 | self.cuts = make_cuts(self._cuts, self._scheme, durations, events, self._min, self._dtype) 123 | self.idu = IdxDiscUnknownC(self.cuts) 124 | return self 125 | 126 | def fit_transform(self, durations, events): 127 | self.fit(durations, events) 128 | idx_durations, events = self.transform(durations, events) 129 | return idx_durations, events 130 | 131 | def transform(self, durations, events): 132 | durations = _values_if_series(durations) 133 | durations = durations.astype(self._dtype) 134 | events = _values_if_series(events) 135 | idx_durations, events = self.idu.transform(durations, events) 136 | return idx_durations.astype('int64'), events.astype('float32') 137 | 138 | @property 139 | def out_features(self): 140 | """Returns the number of output features that should be used in the torch model. 141 | 142 | Returns: 143 | [int] -- Number of output features. 144 | """ 145 | if self.cuts is None: 146 | raise ValueError("Need to call `fit` before this is accessible.") 147 | return len(self.cuts) 148 | 149 | 150 | class LabTransPCHazard: 151 | """ 152 | Defining time intervals (`cuts`) needed for the `PCHazard` method [1]. 153 | One can either determine the cut points in form of passing an array to this class, 154 | or one can obtain cut points based on the training data. 155 | 156 | Arguments: 157 | cuts {int, array} -- Defining cut points, either the number of cuts, or the actual cut points. 158 | 159 | Keyword Arguments: 160 | scheme {str} -- Scheme used for discretization. Either 'equidistant' or 'quantiles' 161 | (default: {'equidistant}) 162 | min_ {float} -- Starting duration (default: {0.}) 163 | dtype {str, dtype} -- dtype of discretization. 164 | 165 | References: 166 | [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction 167 | with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. 168 | https://arxiv.org/pdf/1910.06724.pdf 169 | """ 170 | def __init__(self, cuts, scheme='equidistant', min_=0., dtype=None): 171 | self._cuts = cuts 172 | self._scheme = scheme 173 | self._min = min_ 174 | self._dtype_init = dtype 175 | self._predefined_cuts = False 176 | self.cuts = None 177 | if hasattr(cuts, '__iter__'): 178 | if type(cuts) is list: 179 | cuts = np.array(cuts) 180 | self.cuts = cuts 181 | self.idu = IdxDiscUnknownC(self.cuts) 182 | assert dtype is None, "Need `dtype` to be `None` for specified cuts" 183 | self._dtype = type(self.cuts[0]) 184 | self._dtype_init = self._dtype 185 | self._predefined_cuts = True 186 | else: 187 | self._cuts += 1 188 | 189 | def fit(self, durations, events): 190 | if self._predefined_cuts: 191 | warnings.warn("Calling fit method, when 'cuts' are already defined. 
Leaving cuts unchanged.") 192 | return self 193 | self._dtype = self._dtype_init 194 | if self._dtype is None: 195 | if isinstance(durations[0], np.floating): 196 | self._dtype = durations.dtype 197 | else: 198 | self._dtype = np.dtype('float64') 199 | durations = durations.astype(self._dtype) 200 | self.cuts = make_cuts(self._cuts, self._scheme, durations, events, self._min, self._dtype) 201 | self.duc = DiscretizeUnknownC(self.cuts, right_censor=True, censor_side='right') 202 | self.di = Duration2Idx(self.cuts) 203 | return self 204 | 205 | def fit_transform(self, durations, events): 206 | self.fit(durations, events) 207 | return self.transform(durations, events) 208 | 209 | def transform(self, durations, events): 210 | durations = _values_if_series(durations) 211 | durations = durations.astype(self._dtype) 212 | events = _values_if_series(events) 213 | dur_disc, events = self.duc.transform(durations, events) 214 | idx_durations = self.di.transform(dur_disc) 215 | cut_diff = np.diff(self.cuts) 216 | assert (cut_diff > 0).all(), 'Cuts are not unique.' 217 | t_frac = 1. - (dur_disc - durations) / cut_diff[idx_durations-1] 218 | if idx_durations.min() == 0: 219 | warnings.warn("""Got event/censoring at start time. Should be removed! It is set s.t. it has no contribution to loss.""") 220 | t_frac[idx_durations == 0] = 0 221 | events[idx_durations == 0] = 0 222 | idx_durations = idx_durations - 1 223 | return idx_durations.astype('int64'), events.astype('float32'), t_frac.astype('float32') 224 | 225 | @property 226 | def out_features(self): 227 | """Returns the number of output features that should be used in the torch model. 228 | 229 | Returns: 230 | [int] -- Number of output features. 231 | """ 232 | if self.cuts is None: 233 | raise ValueError("Need to call `fit` before this is accessible.") 234 | return len(self.cuts) - 1 235 | -------------------------------------------------------------------------------- /pycox/simulations/__init__.py: -------------------------------------------------------------------------------- 1 | from pycox.simulations import relative_risk, discrete_logit_hazard 2 | from pycox.simulations.relative_risk import (SimStudyLinearPH, SimStudyNonLinearPH, 3 | SimStudyNonLinearNonPH) 4 | from pycox.simulations.discrete_logit_hazard import (SimStudySACCensorConst, SimStudySACAdmin, 5 | SimStudySingleSurvUniformAdmin) 6 | -------------------------------------------------------------------------------- /pycox/simulations/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | def dict2df(data, add_true=True, add_censor_covs=False): 5 | """Make a pd.DataFrame from the dict obtained when simulating. 6 | 7 | Arguments: 8 | data {dict} -- Dict from simulation. 9 | 10 | Keyword Arguments: 11 | add_true {bool} -- If we should include the true duration and censoring times 12 | (default: {True}) 13 | add_censor_covs {bool} -- If we should include the censor covariates as covariates. 
14 | (default: {False}) 15 | 16 | Returns: 17 | pd.DataFrame -- A DataFrame 18 | """ 19 | covs = data['covs'] 20 | if add_censor_covs: 21 | covs = np.concatenate([covs, data['censor_covs']], axis=1) 22 | df = (pd.DataFrame(covs, columns=[f"x{i}" for i in range(covs.shape[1])]) 23 | .assign(duration=data['durations'].astype('float32'), 24 | event=data['events'].astype('float32'))) 25 | if add_true: 26 | df = df.assign(duration_true=data['durations_true'].astype('float32'), 27 | event_true=data['events_true'].astype('float32'), 28 | censoring_true=data['censor_durations'].astype('float32')) 29 | return df 30 | 31 | 32 | class _SimBase: 33 | def simulate(self, n, surv_df=False): 34 | """Simulate dataset of size `n`. 35 | 36 | Arguments: 37 | n {int} -- Number of simulations 38 | 39 | Keyword Arguments: 40 | surv_df {bool} -- If a dataframe containing the survival function should be returned. 41 | (default: {False}) 42 | 43 | Returns: 44 | [dict] -- A dictionary with the results. 45 | """ 46 | raise NotImplementedError 47 | 48 | def surv_df(self, *args): 49 | """Returns a data frame containing the survival function. 50 | """ 51 | raise NotImplementedError 52 | -------------------------------------------------------------------------------- /pycox/simulations/relative_risk.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from pycox.simulations import base 5 | 6 | 7 | class _SimStudyRelativeRisk(base._SimBase): 8 | '''Abstract class for simulation relative risk survival data, 9 | with constant baseline, and constant censoring distribution 10 | 11 | Parameters: 12 | h0: Is baseline constant. 13 | right_c: Time for right censoring. 14 | c0: Constant censoring distribution 15 | ''' 16 | def __init__(self, h0, right_c=30., c0=30., surv_grid=None): 17 | self.h0 = h0 18 | self.right_c = right_c 19 | self.c0 = c0 20 | self.surv_grid = surv_grid 21 | 22 | def simulate(self, n, surv_df=False): 23 | covs = self.sample_covs(n).astype('float32') 24 | v = np.random.exponential(size=n) 25 | t = self.inv_cum_hazard(v, covs) 26 | c = self.c0 * np.random.exponential(size=n) 27 | tt = t.copy() 28 | tt[c < t] = c[c < t] 29 | tt[tt > self.right_c] = self.right_c 30 | d = tt == t 31 | surv_df = self.surv_df(covs, self.surv_grid) if surv_df else None 32 | # censor_surv_df = NotImplemented if censor_df else None 33 | return dict(covs=covs, durations=tt, events=d, surv_df=surv_df, durations_true=t, 34 | events_true=np.ones_like(t), censor_durations=c, 35 | censor_events=np.ones_like(c)) 36 | 37 | @staticmethod 38 | def sample_covs(n): 39 | raise NotImplementedError 40 | 41 | def inv_cum_hazard(self, v, covs): 42 | '''The inverse of the cumulative hazard.''' 43 | raise NotImplementedError 44 | 45 | def cum_hazard(self, t, covs): 46 | '''The the cumulative hazard function.''' 47 | raise NotImplementedError 48 | 49 | def survival_func(self, t, covs): 50 | '''Returns the survival function.''' 51 | return np.exp(-self.cum_hazard(t, covs)) 52 | 53 | def survival_grid_single(self, covs, t=None): 54 | covs = covs.reshape(1, -1) 55 | if t is None: 56 | t = np.arange(0, 31, 0.5) 57 | return pd.Series(self.survival_func(t, covs), index=t) 58 | 59 | def surv_df(self, covs, t=None): 60 | if t is None: 61 | t = np.linspace(0, 30, 100) 62 | s = [self.survival_grid_single(xx, t) for xx in covs] 63 | return pd.concat(s, axis=1) 64 | 65 | @staticmethod 66 | def dict2df(data, add_true=True): 67 | """Make a pd.DataFrame from the dict obtained when 
simulating. 68 | 69 | Arguments: 70 | data {dict} -- Dict from simulation. 71 | 72 | Keyword Arguments: 73 | add_true {bool} -- If we should include the true duration and censoring times 74 | (default: {True}) 75 | 76 | Returns: 77 | pd.DataFrame -- A DataFrame 78 | """ 79 | return base.dict2df(data, add_true) 80 | 81 | 82 | class SimStudyLinearPH(_SimStudyRelativeRisk): 83 | '''Survival simulations study for linear prop. hazard model 84 | h(t | x) = h0 exp[g(x)], where g(x) is linear. 85 | 86 | Parameters: 87 | h0: Is baseline constant. 88 | right_c: Time for right censoring. 89 | ''' 90 | def __init__(self, h0=0.1, right_c=30., c0=30., surv_grid=None): 91 | super().__init__(h0, right_c, c0, surv_grid) 92 | 93 | @staticmethod 94 | def sample_covs(n): 95 | return np.random.uniform(-1, 1, size=(n, 3)) 96 | 97 | @staticmethod 98 | def g(covs): 99 | x = covs 100 | x0, x1, x2 = x[:, 0], x[:, 1], x[:, 2] 101 | return 0.44 * x0 + 0.66 * x1 + 0.88 * x2 102 | 103 | def inv_cum_hazard(self, v, covs): 104 | '''The inverse of the cumulative hazard.''' 105 | return v / (self.h0 * np.exp(self.g(covs))) 106 | 107 | def cum_hazard(self, t, covs): 108 | '''The the cumulative hazard function.''' 109 | return self.h0 * t * np.exp(self.g(covs)) 110 | 111 | 112 | class SimStudyNonLinearPH(SimStudyLinearPH): 113 | '''Survival simulations study for non-linear prop. hazard model 114 | h(t | x) = h0 exp[g(x)], where g(x) is non-linear. 115 | 116 | Parameters: 117 | h0: Is baseline constant. 118 | right_c: Time for right censoring. 119 | ''' 120 | @staticmethod 121 | def g(covs): 122 | x = covs 123 | x0, x1, x2 = x[:, 0], x[:, 1], x[:, 2] 124 | beta = 2/3 125 | linear = SimStudyLinearPH.g(x) 126 | nonlinear = beta * (x0**2 + x2**2 + x0*x1 + x1*x2 + x1*x2) 127 | return linear + nonlinear 128 | 129 | 130 | class SimStudyNonLinearNonPH(SimStudyNonLinearPH): 131 | '''Survival simulations study for non-linear non-prop. hazard model. 132 | h(t | x) = h0 * exp[g(t, x)], 133 | with constant h_0, and g(t, x) = a(x) + b(x)*t. 134 | 135 | Cumulative hazard: 136 | H(t | x) = h0 / b(x) * exp[a(x)] * (exp[b(x) * t] - 1) 137 | Inverse: 138 | H^{-1}(v, x) = 1/b(x) log{1 + v * b(x) / h0 exp[-a(x)]} 139 | 140 | Parameters: 141 | h0: Is baseline constant. 142 | right_c: Time for right censoring. 143 | ''' 144 | def __init__(self, h0=0.02, right_c=30., c0=30., surv_grid=None): 145 | super().__init__(h0, right_c, c0, surv_grid) 146 | 147 | @staticmethod 148 | def a(x): 149 | _, _, x2 = x[:, 0], x[:, 1], x[:, 2] 150 | return np.sign(x2) + SimStudyNonLinearPH.g(x) 151 | 152 | @staticmethod 153 | def b(x): 154 | x0, x1, _ = x[:, 0], x[:, 1], x[:, 2] 155 | return np.abs(0.2 * (x0 + x1) + 0.5 * x0 * x1) 156 | 157 | @staticmethod 158 | def g(t, covs): 159 | x = covs 160 | return SimStudyNonLinearNonPH.a(x) + SimStudyNonLinearNonPH.b(x) * t 161 | 162 | def inv_cum_hazard(self, v, covs): 163 | x = covs 164 | return 1 / self.b(x) * np.log(1 + v * self.b(x) / self.h0 * np.exp(-self.a(x))) 165 | 166 | def cum_hazard(self, t, covs): 167 | x = covs 168 | return self.h0 / self.b(x) * np.exp(self.a(x)) * (np.exp(self.b(x)*t) - 1) 169 | -------------------------------------------------------------------------------- /pycox/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import pandas as pd 3 | import numpy as np 4 | import numba 5 | 6 | def idx_at_times(index_surv, times, steps='pre', assert_sorted=True): 7 | """Gives index of `index_surv` corresponding to `time`, i.e. 
8 | `index_surv[idx_at_times(index_surv, times)]` give the values of `index_surv` 9 | closet to `times`. 10 | 11 | Arguments: 12 | index_surv {np.array} -- Durations of survival estimates 13 | times {np.array} -- Values one want to match to `index_surv` 14 | 15 | Keyword Arguments: 16 | steps {str} -- Round 'pre' (closest value higher) or 'post' 17 | (closest value lower) (default: {'pre'}) 18 | assert_sorted {bool} -- Assert that index_surv is monotone (default: {True}) 19 | 20 | Returns: 21 | np.array -- Index of `index_surv` that is closest to `times` 22 | """ 23 | if assert_sorted: 24 | assert pd.Series(index_surv).is_monotonic_increasing, "Need 'index_surv' to be monotonic increasing" 25 | if steps == 'pre': 26 | idx = np.searchsorted(index_surv, times) 27 | elif steps == 'post': 28 | idx = np.searchsorted(index_surv, times, side='right') - 1 29 | return idx.clip(0, len(index_surv)-1) 30 | 31 | @numba.njit 32 | def _group_loop(n, surv_idx, durations, events, di, ni): 33 | idx = 0 34 | for i in range(n): 35 | idx += durations[i] != surv_idx[idx] 36 | di[idx] += events[i] 37 | ni[idx] += 1 38 | return di, ni 39 | 40 | def kaplan_meier(durations, events, start_duration=0): 41 | """A very simple Kaplan-Meier fitter. For a more complete implementation 42 | see `lifelines`. 43 | 44 | Arguments: 45 | durations {np.array} -- durations array 46 | events {np.arrray} -- events array 0/1 47 | 48 | Keyword Arguments: 49 | start_duration {int} -- Time start as `start_duration`. (default: {0}) 50 | 51 | Returns: 52 | pd.Series -- Kaplan-Meier estimates. 53 | """ 54 | n = len(durations) 55 | assert n == len(events) 56 | if start_duration > durations.min(): 57 | warnings.warn(f"start_duration {start_duration} is larger than minimum duration {durations.min()}. " 58 | "If intentional, consider changing start_duration when calling kaplan_meier.") 59 | order = np.argsort(durations) 60 | durations = durations[order] 61 | events = events[order] 62 | surv_idx = np.unique(durations) 63 | ni = np.zeros(len(surv_idx), dtype='int') 64 | di = np.zeros_like(ni) 65 | di, ni = _group_loop(n, surv_idx, durations, events, di, ni) 66 | ni = n - ni.cumsum() 67 | ni[1:] = ni[:-1] 68 | ni[0] = n 69 | survive = 1 - di / ni 70 | zero_survive = survive == 0 71 | if zero_survive.any(): 72 | i = np.argmax(zero_survive) 73 | surv = np.zeros_like(survive) 74 | surv[:i] = np.exp(np.log(survive[:i]).cumsum()) 75 | # surv[i:] = surv[i-1] 76 | surv[i:] = 0. 
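        # Editor's worked example (illustrative, not library output): durations = [1, 2, 2, 3], events = [1, 0, 1, 1]
        # give at-risk counts ni = [4, 3, 1] and event counts di = [1, 1, 1], so S(1) = 3/4, S(2) = 3/4 * 2/3 = 1/2,
        # and S(3) = 1/2 * 0 = 0, which is exactly the zero-survival case handled by this branch.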
77 | else: 78 | surv = np.exp(np.log(1 - di / ni).cumsum()) 79 | if start_duration < surv_idx.min(): 80 | tmp = np.ones(len(surv)+ 1, dtype=surv.dtype) 81 | tmp[1:] = surv 82 | surv = tmp 83 | tmp = np.zeros(len(surv_idx)+ 1, dtype=surv_idx.dtype) 84 | tmp[1:] = surv_idx 85 | surv_idx = tmp 86 | surv = pd.Series(surv, surv_idx) 87 | return surv 88 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest>=4.0.2 2 | lifelines>=0.22.8 3 | sklearn-pandas>=1.8.0 4 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.3.0 3 | commit = True 4 | tag = False 5 | 6 | [metadata] 7 | license_files = LICENSE 8 | 9 | [bumpversion:file:setup.py] 10 | search = version='{current_version}' 11 | replace = version='{new_version}' 12 | 13 | [bumpversion:file:pycox/__init__.py] 14 | search = __version__ = '{current_version}' 15 | replace = __version__ = '{new_version}' 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """The setup script.""" 5 | 6 | from setuptools import setup, find_packages 7 | 8 | 9 | long_description = """ 10 | **pycox** is a python package for survival analysis and time-to-event prediction with [PyTorch](https://pytorch.org/). 11 | It is built on the [torchtuples](https://github.com/havakv/torchtuples) package for training [PyTorch](https://pytorch.org/) models. 12 | 13 | Read the documentation at: https://github.com/havakv/pycox 14 | 15 | The package contains 16 | 17 | - survival models: (Logistic-Hazard, DeepHit, DeepSurv, Cox-Time, MTLR, etc.) 18 | - evaluation criteria (concordance, Brier score, Binomial log-likelihood, etc.) 
19 | - event-time datasets (SUPPORT, METABRIC, KKBox, etc) 20 | - simulation studies 21 | - illustrative examples 22 | """ 23 | 24 | requirements = [ 25 | 'torchtuples>=0.2.0', 26 | 'feather-format>=0.4.0', 27 | 'h5py>=2.9.0', 28 | 'numba>=0.44', 29 | 'scikit-learn>=0.21.2', 30 | 'requests>=2.22.0', 31 | 'py7zr>=0.11.3', 32 | ] 33 | 34 | setup( 35 | name='pycox', 36 | version='0.3.0', 37 | description="Survival analysis with PyTorch", 38 | long_description=long_description, 39 | long_description_content_type='text/markdown', 40 | author="Haavard Kvamme", 41 | author_email='haavard.kvamme@gmail.com', 42 | url='https://github.com/havakv/pycox', 43 | packages=find_packages(), 44 | include_package_data=True, 45 | install_requires=requirements, 46 | license="BSD license", 47 | zip_safe=False, 48 | keywords='pycox', 49 | classifiers=[ 50 | 'Development Status :: 2 - Pre-Alpha', 51 | 'Intended Audience :: Developers', 52 | 'License :: OSI Approved :: BSD License', 53 | 'Natural Language :: English', 54 | ], 55 | python_requires='>=3.8' 56 | ) 57 | -------------------------------------------------------------------------------- /tests/evaluation/test_admin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from pycox.evaluation import admin 4 | from pycox.evaluation import EvalSurv 5 | 6 | 7 | def test_brier_score_no_censor(): 8 | n = 4 9 | durations = np.ones(n) * 50 10 | durations_c = np.ones_like(durations) * 100 11 | events = durations <= durations_c 12 | m = 5 13 | index_surv = np.array([0, 25., 50., 75., 100.]) 14 | 15 | surv_ones = np.ones((m, n)) 16 | time_grid = np.array([5., 40., 60., 100.]) 17 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv_ones, index_surv) 18 | assert (bs == np.array([0., 0., 1., 1.])).all() 19 | surv_zeros = surv_ones * 0 20 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv_zeros, index_surv) 21 | assert (bs == np.array([1., 1., 0., 0.])).all() 22 | surv_05 = surv_ones * 0.5 23 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv_05, index_surv) 24 | assert (bs == np.array([0.25, 0.25, 0.25, 0.25])).all() 25 | time_grid = np.array([110.]) 26 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv_05, index_surv) 27 | assert np.isnan(bs).all() 28 | 29 | def test_brier_score_censor(): 30 | n = 4 31 | durations = np.ones(n) * 50 32 | durations_c = np.array([25, 50, 60, 100]) 33 | events = durations <= durations_c 34 | durations[~events] = durations_c[~events] 35 | m = 5 36 | index_surv = np.array([0, 25., 50., 75., 100.]) 37 | 38 | surv = np.ones((m, n)) 39 | surv[:, 0] = 0 40 | time_grid = np.array([5., 25., 40., 60., 100.]) 41 | bs = admin.brier_score(time_grid, durations, durations_c, events, surv, index_surv) 42 | assert (bs == np.array([0.25, 0.25, 0., 1., 1.])).all() 43 | 44 | def test_brier_score_evalsurv(): 45 | n = 4 46 | durations = np.ones(n) * 50 47 | durations_c = np.array([25, 50, 60, 100]) 48 | events = durations <= durations_c 49 | durations[~events] = durations_c[~events] 50 | m = 5 51 | index_surv = np.array([0, 25., 50., 75., 100.]) 52 | 53 | surv = np.ones((m, n)) 54 | surv[:, 0] = 0 55 | surv = pd.DataFrame(surv, index_surv) 56 | time_grid = np.array([5., 25., 40., 60., 100.]) 57 | ev = EvalSurv(surv, durations, events, censor_durations=durations_c) 58 | bs = ev.brier_score_admin(time_grid) 59 | assert (bs.values == np.array([0.25, 0.25, 0., 1., 1.])).all() 60 | 61 | 62 | def 
test_binomial_log_likelihood_no_censor(): 63 | n = 4 64 | durations = np.ones(n) * 50 65 | durations_c = np.ones_like(durations) * 100 66 | events = durations <= durations_c 67 | m = 5 68 | index_surv = np.array([0, 25., 50., 75., 100.]) 69 | 70 | surv_ones = np.ones((m, n)) 71 | time_grid = np.array([5., 40., 60., 100.]) 72 | bll = admin.binomial_log_likelihood(time_grid, durations, durations_c, events, surv_ones, index_surv) 73 | eps = 1e-7 74 | assert abs(bll - np.log([1-eps, 1-eps, eps, eps])).max() < 1e-7 75 | surv_zeros = surv_ones * 0 76 | bll = admin.binomial_log_likelihood(time_grid, durations, durations_c, events, surv_zeros, index_surv) 77 | assert abs(bll - np.log([eps, eps, 1-eps, 1-eps])).max() < 1e-7 78 | surv_05 = surv_ones * 0.5 79 | bll = admin.binomial_log_likelihood(time_grid, durations, durations_c, events, surv_05, index_surv) 80 | assert abs(bll - np.log([0.5, 0.5, 1-0.5, 1-0.5])).max() < 1e-7 81 | time_grid = np.array([110.]) 82 | bll = admin.binomial_log_likelihood(time_grid, durations, durations_c, events, surv_05, index_surv) 83 | assert np.isnan(bll).all() 84 | -------------------------------------------------------------------------------- /tests/models/test_bce_surv.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pycox.models import BCESurv 3 | import torchtuples as tt 4 | 5 | from utils_model_testing import make_dataset, fit_model, assert_survs 6 | 7 | 8 | @pytest.mark.parametrize('numpy', [True, False]) 9 | @pytest.mark.parametrize('num_durations', [2, 5]) 10 | def test_bce_surv_runs(numpy, num_durations): 11 | data = make_dataset(True) 12 | input, target = data 13 | labtrans = BCESurv.label_transform(num_durations) 14 | target = labtrans.fit_transform(*target) 15 | data = tt.tuplefy(input, target) 16 | if not numpy: 17 | data = data.to_tensor() 18 | net = tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features) 19 | model = BCESurv(net) 20 | fit_model(data, model) 21 | assert_survs(input, model) 22 | model.duration_index = labtrans.cuts 23 | assert_survs(input, model) 24 | cdi = model.interpolate(3, 'const_pdf') 25 | assert_survs(input, cdi) 26 | -------------------------------------------------------------------------------- /tests/models/test_cox.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torchtuples as tt 3 | from pycox.models import CoxPH 4 | from pycox.models.cox_time import MLPVanillaCoxTime 5 | 6 | from utils_model_testing import make_dataset, fit_model, assert_survs 7 | 8 | 9 | @pytest.mark.parametrize('numpy', [True, False]) 10 | def test_cox_ph_runs(numpy): 11 | data = make_dataset(False).apply(lambda x: x.float()).to_numpy() 12 | if not numpy: 13 | data = data.to_tensor() 14 | net = tt.practical.MLPVanilla(data[0].shape[1], [4], 1, False, output_bias=False) 15 | model = CoxPH(net) 16 | fit_model(data, model) 17 | model.compute_baseline_hazards() 18 | assert_survs(data[0], model) 19 | -------------------------------------------------------------------------------- /tests/models/test_cox_cc.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torchtuples as tt 3 | from pycox.models import CoxCC 4 | from pycox.models.cox_time import MLPVanillaCoxTime 5 | 6 | from utils_model_testing import make_dataset, fit_model, assert_survs 7 | 8 | 9 | @pytest.mark.parametrize('numpy', [True, False]) 10 | def test_cox_cc_runs(numpy): 11 | data = 
make_dataset(False).apply(lambda x: x.float()).to_numpy() 12 | if not numpy: 13 | data = data.to_tensor() 14 | net = tt.practical.MLPVanilla(data[0].shape[1], [4], 1, False, output_bias=False) 15 | model = CoxCC(net) 16 | fit_model(data, model) 17 | model.compute_baseline_hazards() 18 | assert_survs(data[0], model) 19 | -------------------------------------------------------------------------------- /tests/models/test_cox_time.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torchtuples as tt 3 | from pycox.models import CoxTime 4 | from pycox.models.cox_time import MLPVanillaCoxTime 5 | 6 | from utils_model_testing import make_dataset, fit_model, assert_survs 7 | 8 | 9 | @pytest.mark.parametrize('numpy', [True, False]) 10 | def test_cox_time_runs(numpy): 11 | input, target = make_dataset(False).apply(lambda x: x.float()).to_numpy() 12 | labtrans = CoxTime.label_transform() 13 | target = labtrans.fit_transform(*target) 14 | data = tt.tuplefy(input, target) 15 | if not numpy: 16 | data = data.to_tensor() 17 | net = MLPVanillaCoxTime(data[0].shape[1], [4], False) 18 | model = CoxTime(net) 19 | fit_model(data, model) 20 | model.compute_baseline_hazards() 21 | assert_survs(data[0], model, with_dl=False) 22 | -------------------------------------------------------------------------------- /tests/models/test_deephit.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pycox.models import DeepHitSingle 3 | import torchtuples as tt 4 | 5 | from utils_model_testing import make_dataset, fit_model, assert_survs 6 | 7 | 8 | @pytest.mark.parametrize('numpy', [True, False]) 9 | @pytest.mark.parametrize('num_durations', [2, 5]) 10 | def test_deep_hit_single_runs(numpy, num_durations): 11 | data = make_dataset(True) 12 | input, target = data 13 | labtrans = DeepHitSingle.label_transform(num_durations) 14 | target = labtrans.fit_transform(*target) 15 | data = tt.tuplefy(input, target) 16 | if not numpy: 17 | data = data.to_tensor() 18 | net = tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features) 19 | model = DeepHitSingle(net) 20 | fit_model(data, model) 21 | assert_survs(input, model) 22 | model.duration_index = labtrans.cuts 23 | assert_survs(input, model) 24 | cdi = model.interpolate(3, 'const_pdf') 25 | assert_survs(input, cdi) 26 | -------------------------------------------------------------------------------- /tests/models/test_interpolation.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from pycox import models 4 | from pycox.models.interpolation import InterpolateDiscrete 5 | 6 | 7 | class MockPMF(models.PMF): 8 | def __init__(self, duration_index=None): 9 | self.duration_index = duration_index 10 | 11 | def predict(self, input, *args, **kwargs): 12 | return input 13 | 14 | class MockLogisticHazard(models.LogisticHazard): 15 | def __init__(self, duration_index=None): 16 | self.duration_index = duration_index 17 | 18 | def predict(self, input, *args, **kwargs): 19 | return input 20 | 21 | class MockMTLR(models.MTLR): 22 | def __init__(self, duration_index=None): 23 | self.duration_index = duration_index 24 | 25 | def predict(self, input, *args, **kwargs): 26 | return input 27 | 28 | 29 | @pytest.mark.parametrize('m', [2, 5, 10]) 30 | @pytest.mark.parametrize('sub', [2, 5]) 31 | def test_pmf_cdi_equals_base(m, sub): 32 | torch.manual_seed(12345) 33 | n = 20 34 | idx = 
torch.randn(m).abs().sort()[0].numpy() 35 | input = torch.randn(n, m) 36 | model = MockPMF(idx) 37 | surv_pmf = model.interpolate(sub).predict_surv_df(input) 38 | surv_base = InterpolateDiscrete(model, duration_index=idx, sub=sub).predict_surv_df(input) 39 | assert (surv_pmf.index == surv_base.index).all() 40 | assert (surv_pmf - surv_base).abs().max().max() < 1e-7 41 | 42 | 43 | @pytest.mark.parametrize('m', [2, 5, 10]) 44 | @pytest.mark.parametrize('sub', [2, 5]) 45 | def test_base_values_at_knots(m, sub): 46 | torch.manual_seed(12345) 47 | n = 20 48 | idx = torch.randn(m).abs().sort()[0].numpy() 49 | input = torch.randn(n, m) 50 | model = MockPMF(idx) 51 | surv_cdi = InterpolateDiscrete(model, duration_index=idx, sub=sub).predict_surv_df(input) 52 | surv = model.predict_surv_df(input) 53 | diff = (surv - surv_cdi).dropna() 54 | assert diff.shape == surv.shape 55 | assert (diff == 0).all().all() 56 | 57 | @pytest.mark.parametrize('m', [2, 5, 10]) 58 | @pytest.mark.parametrize('sub', [2, 5]) 59 | def test_pmf_values_at_knots(m, sub): 60 | torch.manual_seed(12345) 61 | n = 20 62 | idx = torch.randn(m).abs().sort()[0].numpy() 63 | input = torch.randn(n, m) 64 | model = MockPMF(idx) 65 | surv = model.predict_surv_df(input) 66 | surv_cdi = model.interpolate(sub, 'const_pdf').predict_surv_df(input) 67 | diff = (surv - surv_cdi).dropna() 68 | assert diff.shape == surv.shape 69 | assert diff.max().max() < 1e-7 70 | 71 | @pytest.mark.parametrize('m', [2, 5, 10]) 72 | @pytest.mark.parametrize('sub', [2, 5]) 73 | def test_logistic_hazard_values_at_knots(m, sub): 74 | torch.manual_seed(12345) 75 | n = 20 76 | idx = torch.randn(m).abs().sort()[0].numpy() 77 | input = torch.randn(n, m) 78 | model = MockLogisticHazard(idx) 79 | surv = model.predict_surv_df(input) 80 | surv_cdi = model.interpolate(sub, 'const_pdf').predict_surv_df(input) 81 | diff = (surv - surv_cdi).dropna() 82 | assert diff.shape == surv.shape 83 | assert (diff == 0).all().all() 84 | surv_chi = model.interpolate(sub, 'const_hazard').predict_surv_df(input) 85 | diff = (surv - surv_chi).dropna() 86 | assert diff.shape == surv.shape 87 | assert (diff.index == surv.index).all() 88 | assert diff.max().max() < 1e-6 89 | 90 | @pytest.mark.parametrize('m', [2, 5, 10]) 91 | @pytest.mark.parametrize('sub', [2, 5]) 92 | def test_mtlr_values_at_knots(m, sub): 93 | torch.manual_seed(12345) 94 | n = 20 95 | idx = torch.randn(m).abs().sort()[0].numpy() 96 | input = torch.randn(n, m) 97 | model = MockMTLR(idx) 98 | surv = model.predict_surv_df(input) 99 | surv_cdi = model.interpolate(sub, 'const_pdf').predict_surv_df(input) 100 | diff = (surv - surv_cdi).dropna() 101 | assert diff.shape == surv.shape 102 | assert diff.max().max() < 1e-7 103 | -------------------------------------------------------------------------------- /tests/models/test_logistic_hazard.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pycox.models import LogisticHazard 3 | import torchtuples as tt 4 | 5 | from utils_model_testing import make_dataset, fit_model, assert_survs 6 | 7 | 8 | @pytest.mark.parametrize('numpy', [True, False]) 9 | @pytest.mark.parametrize('num_durations', [2, 5]) 10 | def test_logistic_hazard_runs(numpy, num_durations): 11 | data = make_dataset(True) 12 | input, target = data 13 | labtrans = LogisticHazard.label_transform(num_durations) 14 | target = labtrans.fit_transform(*target) 15 | data = tt.tuplefy(input, target) 16 | if not numpy: 17 | data = data.to_tensor() 18 | net = 
tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features) 19 | model = LogisticHazard(net) 20 | fit_model(data, model) 21 | assert_survs(input, model) 22 | model.duration_index = labtrans.cuts 23 | assert_survs(input, model) 24 | cdi = model.interpolate(3, 'const_pdf') 25 | assert_survs(input, cdi) 26 | -------------------------------------------------------------------------------- /tests/models/test_loss.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | import torchtuples as tt 5 | 6 | from pycox.models.data import pair_rank_mat 7 | from pycox.models import loss 8 | 9 | @pytest.mark.parametrize('seed', [0, 1, 2]) 10 | @pytest.mark.parametrize('m', [1, 5, 8]) 11 | def test_nll_pmf_cr_equals_nll_pmf(seed, m): 12 | torch.manual_seed(seed) 13 | # m = 5 14 | n_risk = 1 15 | rep = 7 16 | batch = m * (n_risk + 1) * rep 17 | phi = torch.randn(batch, n_risk, m) 18 | idx_duration = torch.arange(m).repeat(rep * (n_risk + 1)) 19 | events = torch.arange(n_risk + 1).repeat(m * rep) 20 | r1 = loss.nll_pmf_cr(phi, idx_duration, events) 21 | r2 = loss.nll_pmf(phi.view(batch * n_risk, -1), idx_duration, events.float()) 22 | assert (r1 - r2).abs() < 1e-5 23 | 24 | @pytest.mark.parametrize('seed', [0, 1, 2]) 25 | @pytest.mark.parametrize('m', [1, 5, 8]) 26 | @pytest.mark.parametrize('sigma', [0.1, 0.2, 1.]) 27 | def test_rank_loss_deephit_cr_equals_single(seed, m, sigma): 28 | torch.manual_seed(seed) 29 | n_risk = 1 30 | rep = 7 31 | batch = m * (n_risk + 1) * rep 32 | phi = torch.randn(batch, n_risk, m) 33 | idx_duration = torch.arange(m).repeat(rep * (n_risk + 1)) 34 | events = torch.arange(n_risk + 1).repeat(m * rep) 35 | rank_mat = pair_rank_mat(idx_duration.numpy(), events.numpy()) 36 | rank_mat = torch.tensor(rank_mat) 37 | r1 = loss.rank_loss_deephit_cr(phi, idx_duration, events, rank_mat, sigma) 38 | r2 = loss.rank_loss_deephit_single(phi.view(batch, -1), idx_duration, events.float(), 39 | rank_mat, sigma) 40 | assert (r1 - r2).abs() < 1e-6 41 | 42 | 43 | @pytest.mark.parametrize('seed', [0, 1]) 44 | @pytest.mark.parametrize('m', [1, 8]) 45 | @pytest.mark.parametrize('sigma', [0.1, 1.]) 46 | @pytest.mark.parametrize('alpha', [1, 0.5, 0.]) 47 | def test_loss_deephit_cr_equals_single(seed, m, sigma, alpha): 48 | torch.manual_seed(seed) 49 | n_risk = 1 50 | rep = 7 51 | batch = m * (n_risk + 1) * rep 52 | phi = torch.randn(batch, n_risk, m) 53 | idx_duration = torch.arange(m).repeat(rep * (n_risk + 1)) 54 | events = torch.arange(n_risk + 1).repeat(m * rep) 55 | rank_mat = pair_rank_mat(idx_duration.numpy(), events.numpy()) 56 | rank_mat = torch.tensor(rank_mat) 57 | loss_cr = loss.DeepHitLoss(alpha, sigma) 58 | loss_single = loss.DeepHitSingleLoss(alpha, sigma) 59 | r1 = loss_cr(phi, idx_duration, events, rank_mat) 60 | r2 = loss_single(phi.view(batch, -1), idx_duration, events.float(), rank_mat) 61 | assert (r1 - r2).abs() < 1e-5 62 | 63 | @pytest.mark.parametrize('seed', [0, 1]) 64 | @pytest.mark.parametrize('shrink', [0, 0.01, 1.]) 65 | def test_cox_cc_loss_single_ctrl(seed, shrink): 66 | np.random.seed(seed) 67 | n = 100 68 | case = np.random.uniform(-1, 1, n) 69 | ctrl = np.random.uniform(-1, 1, n) 70 | case, ctrl = tt.tuplefy(case, ctrl).to_tensor() 71 | loss_1 = loss.cox_cc_loss(case, (ctrl,), shrink) 72 | loss_2 = loss.cox_cc_loss_single_ctrl(case, ctrl, shrink) 73 | assert (loss_1 - loss_2).abs() < 1e-6 74 | 75 | @pytest.mark.parametrize('shrink', [0, 0.01, 1.]) 76 | def 
test_cox_cc_loss_single_ctrl_zero(shrink): 77 | n = 10 78 | case = ctrl = torch.zeros(n) 79 | loss_1 = loss.cox_cc_loss_single_ctrl(case, ctrl, shrink) 80 | val = torch.tensor(2.).log() 81 | assert (loss_1 - val).abs() == 0 82 | 83 | @pytest.mark.parametrize('shrink', [0, 0.01, 1.]) 84 | def test_cox_cc_loss_zero(shrink): 85 | n = 10 86 | case = ctrl = torch.zeros(n) 87 | loss_1 = loss.cox_cc_loss(case, (ctrl,), shrink) 88 | val = torch.tensor(2.).log() 89 | assert (loss_1 - val).abs() == 0 90 | 91 | def test_nll_mtlr_zero(): 92 | n_frac = 4 93 | m = 5 94 | n = m * n_frac 95 | phi = torch.zeros(n, m) 96 | idx_durations = torch.arange(m).repeat(n_frac) 97 | events = torch.ones_like(idx_durations).float() 98 | loss_pmf = loss.nll_pmf(phi, idx_durations, events) 99 | loss_mtlr = loss.nll_mtlr(phi, idx_durations, events) 100 | assert (loss_pmf - loss_mtlr).abs() == 0 101 | events = torch.zeros_like(events) 102 | loss_pmf = loss.nll_pmf(phi, idx_durations, events) 103 | loss_mtlr = loss.nll_mtlr(phi, idx_durations, events) 104 | assert (loss_pmf - loss_mtlr).abs() == 0 105 | -------------------------------------------------------------------------------- /tests/models/test_models_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import torch 4 | from pycox.models.utils import pad_col, make_subgrid, cumsum_reverse 5 | 6 | @pytest.mark.parametrize('val', [0, 1, 5]) 7 | def test_pad_col_start(val): 8 | x = torch.ones((2, 3)) 9 | x_pad = pad_col(x, val, where='start') 10 | pad = torch.ones(2, 1) * val 11 | assert (x_pad == torch.cat([pad, x], dim=1)).all() 12 | 13 | @pytest.mark.parametrize('val', [0, 1, 5]) 14 | def test_pad_col_end(val): 15 | x = torch.ones((2, 3)) 16 | x_pad = pad_col(x, val) 17 | pad = torch.ones(2, 1) * val 18 | assert (x_pad == torch.cat([x, pad], dim=1)).all() 19 | 20 | @pytest.mark.parametrize('n', [2, 13, 40]) 21 | def test_make_subgrid_1(n): 22 | grid = np.random.uniform(0, 100, n) 23 | grid = np.sort(grid) 24 | new_grid = make_subgrid(grid, 1) 25 | assert len(new_grid) == len(grid) 26 | assert (new_grid == grid).all() 27 | 28 | @pytest.mark.parametrize('sub', [2, 10, 20]) 29 | @pytest.mark.parametrize('start', [0, 2]) 30 | @pytest.mark.parametrize('stop', [4, 100]) 31 | @pytest.mark.parametrize('n', [5, 10]) 32 | def test_make_subgrid(sub, start, stop, n): 33 | grid = np.linspace(start, stop, n) 34 | new_grid = make_subgrid(grid, sub) 35 | true_new = np.linspace(start, stop, n*sub - (sub-1)) 36 | assert len(new_grid) == len(true_new) 37 | assert np.abs(true_new - new_grid).max() < 1e-13 38 | 39 | def test_cumsum_reverse_error_dim(): 40 | x = torch.randn((5, 3)) 41 | with pytest.raises(NotImplementedError): 42 | cumsum_reverse(x, dim=0) 43 | with pytest.raises(NotImplementedError): 44 | cumsum_reverse(x, dim=2) 45 | 46 | def test_cumsum_reverse_dim_1(): 47 | torch.manual_seed(1234) 48 | x = torch.randn(5, 16) 49 | res_np = x.numpy()[:, ::-1].cumsum(1)[:, ::-1] 50 | res = cumsum_reverse(x, dim=1) 51 | assert np.abs(res.numpy() - res_np).max() < 1e-6 52 | -------------------------------------------------------------------------------- /tests/models/test_mtlr.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pycox.models import MTLR 3 | import torchtuples as tt 4 | 5 | from utils_model_testing import make_dataset, fit_model, assert_survs 6 | 7 | 8 | @pytest.mark.parametrize('numpy', [True, False]) 9 | 
@pytest.mark.parametrize('num_durations', [2, 5]) 10 | def test_mtlr_runs(numpy, num_durations): 11 | data = make_dataset(True) 12 | input, target = data 13 | labtrans = MTLR.label_transform(num_durations) 14 | target = labtrans.fit_transform(*target) 15 | data = tt.tuplefy(input, target) 16 | if not numpy: 17 | data = data.to_tensor() 18 | net = tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features) 19 | model = MTLR(net) 20 | fit_model(data, model) 21 | assert_survs(input, model) 22 | model.duration_index = labtrans.cuts 23 | assert_survs(input, model) 24 | cdi = model.interpolate(3, 'const_pdf') 25 | assert_survs(input, cdi) 26 | -------------------------------------------------------------------------------- /tests/models/test_pc_hazard.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import numpy as np 4 | import torchtuples as tt 5 | from pycox.models import PCHazard 6 | 7 | from utils_model_testing import make_dataset, fit_model, assert_survs 8 | 9 | 10 | def _make_dataset(n, m): 11 | np.random.seed(0) 12 | x = np.random.normal(0, 1, (n, 4)).astype('float32') 13 | duration_index = np.arange(m+1).astype('int64') 14 | durations = np.repeat(duration_index, np.ceil(n / m))[:n] 15 | events = np.random.uniform(0, 1, n).round().astype('float32') 16 | fracs = np.random.uniform(0, 1, n).astype('float32') 17 | return x, (durations, events, fracs), duration_index 18 | 19 | @pytest.mark.parametrize('m', [5, 10]) 20 | @pytest.mark.parametrize('n_mul', [2, 3]) 21 | @pytest.mark.parametrize('mp', [1, 2, -1]) 22 | def test_wrong_net_output(m, n_mul, mp): 23 | n = m * n_mul 24 | inp, tar, dur_index = _make_dataset(n, m) 25 | net = torch.nn.Linear(inp.shape[1], m+1) 26 | with pytest.raises(ValueError): 27 | model = PCHazard(net, duration_index=dur_index) 28 | 29 | model = PCHazard(net) 30 | with pytest.raises(ValueError): 31 | model.fit(inp, tar) 32 | 33 | model.duration_index = dur_index 34 | with pytest.raises(ValueError): 35 | model.predict_surv_df(inp) 36 | 37 | model.duration_index = dur_index 38 | dl = model.make_dataloader((inp, tar), 5, True) 39 | with pytest.raises(ValueError): 40 | model.fit_dataloader(dl) 41 | 42 | @pytest.mark.parametrize('m', [5, 10]) 43 | @pytest.mark.parametrize('n_mul', [2, 3]) 44 | def test_right_net_output(m, n_mul): 45 | n = m * n_mul 46 | inp, tar, dur_index = _make_dataset(n, m) 47 | net = torch.nn.Linear(inp.shape[1], m) 48 | model = PCHazard(net) 49 | model = PCHazard(net, duration_index=dur_index) 50 | model.fit(inp, tar, verbose=False) 51 | model.predict_surv_df(inp) 52 | dl = model.make_dataloader((inp, tar), 5, True) 53 | model.fit_dataloader(dl) 54 | assert True 55 | 56 | @pytest.mark.parametrize('numpy', [True, False]) 57 | @pytest.mark.parametrize('num_durations', [3, 8]) 58 | def test_pc_hazard_runs(numpy, num_durations): 59 | data = make_dataset(True) 60 | input, (durations, events) = data 61 | durations += 1 62 | target = (durations, events) 63 | labtrans = PCHazard.label_transform(num_durations) 64 | target = labtrans.fit_transform(*target) 65 | data = tt.tuplefy(input, target) 66 | if not numpy: 67 | data = data.to_tensor() 68 | net = tt.practical.MLPVanilla(input.shape[1], [4], num_durations) 69 | model = PCHazard(net) 70 | fit_model(data, model) 71 | assert_survs(input, model) 72 | model.duration_index = labtrans.cuts 73 | assert_survs(input, model) 74 | -------------------------------------------------------------------------------- /tests/models/test_pmf.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | from pycox.models import PMF 3 | import torchtuples as tt 4 | 5 | from utils_model_testing import make_dataset, fit_model, assert_survs 6 | 7 | 8 | @pytest.mark.parametrize('numpy', [True, False]) 9 | @pytest.mark.parametrize('num_durations', [2, 5]) 10 | def test_pmf_runs(numpy, num_durations): 11 | data = make_dataset(True) 12 | input, target = data 13 | labtrans = PMF.label_transform(num_durations) 14 | target = labtrans.fit_transform(*target) 15 | data = tt.tuplefy(input, target) 16 | if not numpy: 17 | data = data.to_tensor() 18 | net = tt.practical.MLPVanilla(input.shape[1], [4], labtrans.out_features) 19 | model = PMF(net) 20 | fit_model(data, model) 21 | assert_survs(input, model) 22 | model.duration_index = labtrans.cuts 23 | assert_survs(input, model) 24 | cdi = model.interpolate(3, 'const_pdf') 25 | assert_survs(input, cdi) 26 | -------------------------------------------------------------------------------- /tests/models/utils_model_testing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import torchtuples as tt 5 | 6 | def make_dataset(numpy): 7 | n_events = 2 8 | n_frac = 4 9 | m = 10 10 | n = m * n_frac * n_events 11 | p = 5 12 | input = torch.randn((n, p)) 13 | durations = torch.arange(m).repeat(int(n / m)) 14 | events = torch.arange(n_events).repeat(int(n / n_events)).float() 15 | target = (durations, events) 16 | data = tt.tuplefy(input, target) 17 | if numpy: 18 | data = data.to_numpy() 19 | return data 20 | 21 | def fit_model(data, model): 22 | model.fit(*data, epochs=1, verbose=False, val_data=data) 23 | return model 24 | 25 | def assert_survs(input, model, with_dl=True): 26 | preds = model.predict_surv(input) 27 | assert type(preds) is type(input) 28 | assert preds.shape[0] == input.shape[0] 29 | surv_df = model.predict_surv_df(input) 30 | assert type(surv_df) is pd.DataFrame 31 | assert type(surv_df.values) is np.ndarray 32 | assert preds.shape[0] == surv_df.shape[1] 33 | assert preds.shape[1] == surv_df.shape[0] 34 | np_input = tt.tuplefy(input).to_numpy()[0] 35 | torch_input = tt.tuplefy(input).to_tensor()[0] 36 | np_preds = model.predict_surv(np_input) 37 | torch_preds = model.predict_surv(torch_input) 38 | assert (np_preds == torch_preds.cpu().numpy()).all() 39 | if with_dl: 40 | dl_input = tt.tuplefy(input).make_dataloader(512, False) 41 | dl_preds = model.predict_surv(dl_input) 42 | assert type(np_preds) is type(dl_preds), f"got {type(np_preds)}, and, {type(dl_preds)}" 43 | assert (np_preds == dl_preds).all() 44 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from pycox import utils 4 | 5 | def test_kaplan_meier(): 6 | durations = np.array([1., 1., 2., 3.]) 7 | events = np.array([1, 1, 1, 0]) 8 | surv = utils.kaplan_meier(durations, events) 9 | assert (surv.index.values == np.arange(4, dtype=float)).all() 10 | assert (surv.values == np.array([1., 0.5, 0.25, 0.25])).all() 11 | 12 | @pytest.mark.parametrize('n', [10, 85, 259]) 13 | @pytest.mark.parametrize('p_cens', [0, 0.3, 0.8]) 14 | def test_kaplan_meier_vs_lifelines(n, p_cens): 15 | from lifelines import KaplanMeierFitter 16 | np.random.seed(0) 17 | durations = np.random.uniform(0, 100, n) 18 | events = np.random.binomial(1, 
1 - p_cens, n).astype('float') 19 | km = utils.kaplan_meier(durations, events) 20 | kmf = KaplanMeierFitter().fit(durations, events).survival_function_['KM_estimate'] 21 | assert km.shape == kmf.shape 22 | assert (km - kmf).abs().max() < 1e-14 23 | assert (km.index == kmf.index).all() 24 | --------------------------------------------------------------------------------
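The test modules above all exercise the same workflow: discretise the labels with label_transform, build a small torchtuples network, fit the model, and read survival estimates back with predict_surv_df (optionally interpolated between grid points). Below is a minimal, self-contained sketch of that workflow on synthetic data; it uses only calls already exercised by the tests, and the choice of LogisticHazard, the hidden-layer sizes, the epoch count, and the simulated data are illustrative placeholders rather than recommendations.

import numpy as np
import torchtuples as tt
from pycox import utils
from pycox.models import LogisticHazard

# Synthetic data: 200 individuals, 5 covariates, roughly 30% censoring.
np.random.seed(0)
n, p = 200, 5
x = np.random.randn(n, p).astype('float32')
durations = np.random.uniform(0, 100, n).astype('float32')
events = np.random.binomial(1, 0.7, n).astype('float32')

# Discretise the event times onto a grid of 10 intervals.
labtrans = LogisticHazard.label_transform(10)
y = labtrans.fit_transform(durations, events)

# Small MLP whose output dimension is given by the label transform.
net = tt.practical.MLPVanilla(p, [32, 32], labtrans.out_features)
model = LogisticHazard(net)
model.duration_index = labtrans.cuts

model.fit(x, y, epochs=5, verbose=False)

# Survival estimates: rows are the time grid, columns are individuals.
surv = model.predict_surv_df(x)
# Optionally interpolate between the grid points.
surv_interp = model.interpolate(10, 'const_pdf').predict_surv_df(x)

# Marginal Kaplan-Meier estimate for comparison.
km = utils.kaplan_meier(durations, events)

The same pattern carries over to PMF, MTLR, DeepHitSingle, and BCESurv, which share the label_transform and interpolate interface used in the tests above.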