├── .gitignore ├── LICENSE ├── README.md ├── data └── kkbox_v1 │ ├── pycox_paper_figure4.png │ └── pycox_paper_table7.png ├── datasets └── kkbox_churn │ ├── basic_sessionization_ibis_duckdb_vs_polars.ipynb │ ├── csv_to_parquet.py │ ├── data_engineering.py │ ├── prepare_duckdb.py │ └── sessionization.py ├── demo_mlflow.py ├── model_selection ├── cross_validation.py └── wrappers.py ├── models ├── gradient_boosted_cif.py ├── kaplan_meier.py ├── kaplan_neighbors.py ├── kaplan_tree.py ├── meta_grid_bc.py ├── survival_mixin.py ├── tree_transformer.py └── yasgbt.py ├── notebooks ├── BrierScore.svg ├── censoring.png ├── kkbox_cv_benchmark.ipynb ├── kkbox_cv_yasgbt.ipynb ├── msk_mettropism.ipynb ├── truck_dataset.ipynb ├── truck_dataset.py ├── tutorial_part_1.ipynb ├── tutorial_part_1.py ├── tutorial_part_2.ipynb ├── tutorial_part_2.py └── variables.png └── plot ├── brier_score.py └── individuals.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # duckdb files 132 | *.duckdb 133 | *.duckdb.wal 134 | *.duckdb.tmp 135 | *.db 136 | *.db.wal 137 | *.db.tmp 138 | 139 | # data 140 | *.7z 141 | *.csv 142 | *.parquet 143 | *.npz 144 | 145 | .DS_Store 146 | .vscode 147 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Soda team @ Inria 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Benchmarking predictive survival analysis models 2 | 3 | This repository is dedicated to the evaluation of predictive survival models on large-ish datasets. 4 | 5 | ## Software dependencies 6 | 7 | To run the jupytercon-2023 tutorial notebooks, you will need: 8 | 9 | ``` 10 | conda create -n jupytercon-survival -c conda-forge python jupyterlab scikit-learn lifelines scikit-survival matplotlib-base plotly seaborn pandas pyarrow ibis-duckdb polars 11 | 12 | conda activate jupytercon-survival 13 | jupyter lab 14 | ``` 15 | 16 | ## Notebooks 17 | 18 | The `notebooks` folder holds the two main notebooks for the jupytercon-2023, namely: 19 | 20 | - `tutorial_part_1.ipynb` 21 | - `tutorial_part_2.ipynb` 22 | 23 | and the ancillary notebook used to generate the dataset used in "part 1", namely: 24 | 25 | - `truck_dataset.ipynb` 26 | 27 | Note that running `truck_dataset.ipynb` consumes a significant share of RAM, so you might prefer [downloading the datasets from zip](https://github.com/soda-inria/survival-analysis-benchmark/releases/download/jupytercon-2023-tutorial-data/truck_failure.zip) (500 MB) instead of generating them. 
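For convenience, here is a minimal, untested sketch of how one could download and extract that archive from Python. The URL is the release asset linked above; the extraction folder is an assumption and may need to be adapted to wherever the notebooks expect the files:

```python
# Minimal sketch: fetch and unpack the pre-generated truck failure dataset.
from pathlib import Path
from urllib.request import urlretrieve
import zipfile

url = (
    "https://github.com/soda-inria/survival-analysis-benchmark/releases/"
    "download/jupytercon-2023-tutorial-data/truck_failure.zip"
)
archive = Path("truck_failure.zip")
if not archive.exists():
    urlretrieve(url, archive)  # ~500 MB download
with zipfile.ZipFile(archive) as zf:
    zf.extractall("notebooks")  # assumption: extract next to the tutorial notebooks
```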
28 | 29 | The notebooks display our benchmark results and show how to use our wrappers to cross-validate various models. 30 | 31 | - `kkbox_cv_benchmark.ipynb` 32 | 33 | Benchmark of the KKBox challenge inspired by the [pycox paper](https://jmlr.org/papers/volume20/18-424/18-424.pdf). 34 | - `msk_mettropism.ipynb` 35 | 36 | Exploration of the MSK cancer dataset and survival probability predictions using our models. 37 | 38 | ## Datasets 39 | 40 | ### WSDM - KKBox's Churn Prediction Challenge (from Kaggle) 41 | 42 | The `datasets/kkbox_churn` folder contains Python code to efficiently 43 | preprocess the raw transaction logs of the KKBox's Churn Prediction Challenge 44 | using [ibis](https://ibis-project.org) and [duckdb](https://duckdb.org). 45 | 46 | - https://www.kaggle.com/competitions/kkbox-churn-prediction-challenge/data 47 | 48 | The objectives are to: 49 | 50 | - make everything reproducible from the event-based logs; 51 | - implement efficient, parallel and out-of-core "sessionization" of the past 52 | transactions for all members: here, a "session" is an uninterrupted 53 | sequence of transactions; 54 | - implement efficient, parallel and out-of-core tabularization (features 55 | and churn target with censoring); 56 | - make it possible to compute the cumulative state of the subscription data 57 | and the censored churn events at any point in time. 58 | 59 | Alternatively, `kkbox_cv_benchmark.ipynb` directly uses the preprocessing steps from pycox to ensure reproducibility, at the cost of memory usage and speed. 60 | 61 | ## Models 62 | 63 | This repository introduces a novel survival estimator named Gradient Boosting CIF. Under the hood, this estimator is based on the HistGradientBoostingClassifier of scikit-learn. 64 | 65 | It is named CIF because it can estimate cause-specific Cumulative Incidence Functions in a competing risks setting by minimizing a cause-specific Integrated Brier Score (IBS) objective function: 66 | 67 | ```python 68 | import numpy as np 69 | from models.gradient_boosted_cif import GradientBoostedCIF 70 | X_train = np.array([ 71 | [5.0, 0.1, 2.0], 72 | [3.0, 1.1, 2.2], 73 | [2.0, 0.3, 1.1], 74 | [4.0, 1.0, 0.9], 75 | ]) 76 | y_train_multi_event = np.array([ 77 | (2, 33.2), 78 | (0, 10.1), 79 | (0, 50.0), 80 | (1, 20.0), 81 | ], dtype=[("event", np.int32), ("duration", np.float64)] 82 | ) 83 | time_grid = np.linspace(0.0, 30.0, 10) 84 | 85 | gb_cif = GradientBoostedCIF(event_of_interest=1) 86 | gb_cif.fit(X_train, y_train_multi_event, time_grid) 87 | 88 | X_test = np.array([[3.0, 1.0, 9.0]]) 89 | cif_curves = gb_cif.predict_cumulative_incidence(X_test, time_grid) 90 | ``` 91 | 92 | Alternatively, you can estimate the probability that the event of interest has been experienced by a specific time horizon: 93 | 94 | ```python 95 | gb_cif.predict_proba(X_test, time_horizon=20) 96 | ``` 97 | 98 | Conversely, you can estimate the conditional quantile of the time to event, e.g. answering the question "at which time horizon does the CIF reach 50%?": 99 | 100 | ```python 101 | gb_cif.predict_quantile(X_test, quantile=0.5) 102 | ``` 103 | 104 | You can also estimate the survival function in the single event setting (binary event). Warning: this quantity only makes sense when `y` is binary or when setting `event_of_interest='any'`.
105 | 106 | ```python 107 | y_train_single_event = np.array([ 108 | (1, 12.0), 109 | (0, 5.1), 110 | (1, 1.1), 111 | (0, 29.0), 112 | ], dtype=[("event", np.bool_), ("duration", np.float64)] 113 | ) 114 | gb_cif.fit(X_train, y_train_single_event, time_grid) 115 | survival_curves = gb_cif.predict_survival_function(X_test, time_grid) 116 | ``` 117 | -------------------------------------------------------------------------------- /data/kkbox_v1/pycox_paper_figure4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soda-inria/survival-analysis-benchmark/2633aa5e4e6433c73bb87e99a9f3aad1033726e7/data/kkbox_v1/pycox_paper_figure4.png -------------------------------------------------------------------------------- /data/kkbox_v1/pycox_paper_table7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soda-inria/survival-analysis-benchmark/2633aa5e4e6433c73bb87e99a9f3aad1033726e7/data/kkbox_v1/pycox_paper_table7.png -------------------------------------------------------------------------------- /datasets/kkbox_churn/csv_to_parquet.py: -------------------------------------------------------------------------------- 1 | import duckdb 2 | from pathlib import Path 3 | 4 | 5 | DATA_FOLDER = Path("~/data/kkbox-churn-prediction-challenge").expanduser() 6 | 7 | members_dtypes = { 8 | "msno": "varchar", 9 | "city": "int32", 10 | "bd": "int32", 11 | "gender": "varchar", 12 | "registered_via": "int32", 13 | "registration_init_time": "date", 14 | } 15 | 16 | transactions_dtypes = { 17 | "msno": "varchar", 18 | "payment_method_id": "int32", 19 | "payment_plan_days": "int32", 20 | "plan_list_price": "int32", 21 | "actual_amount_paid": "int32", 22 | "is_auto_renew": "boolean", 23 | "transaction_date": "date", 24 | "membership_expire_date": "date", 25 | "is_cancel": "boolean", 26 | } 27 | 28 | user_logs = { 29 | "msno": "varchar", 30 | "date": "date", 31 | "num_25": "int32", 32 | "num_50": "int32", 33 | "num_75": "int32", 34 | "num_985": "int32", 35 | "num_100": "int32", 36 | "num_unq": "int32", 37 | "total_secs": "double", 38 | } 39 | 40 | tables = { 41 | "members": { 42 | "csv_filename": "members_v3.csv", 43 | "parquet_filename": "members.parquet", 44 | "dtype": members_dtypes, 45 | }, 46 | "transactions": { 47 | "csv_filename": "transactions.csv", 48 | "parquet_filename": "transactions.parquet", 49 | "dtype": transactions_dtypes, 50 | }, 51 | "user_logs": { 52 | "csv_filename": "user_logs.csv", 53 | "parquet_filename": "user_logs.parquet", 54 | "dtype": user_logs, 55 | }, 56 | } 57 | 58 | conn = duckdb.connect() 59 | 60 | for table_name, table in tables.items(): 61 | csv_file = str(DATA_FOLDER / table["csv_filename"]) 62 | if Path(table["parquet_filename"]).exists(): 63 | print(f"Skipping {csv_file}...") 64 | continue 65 | 66 | print(f"Processing {csv_file}...") 67 | df = conn.read_csv( 68 | csv_file, 69 | header=True, 70 | dtype=table["dtype"], 71 | date_format="%Y%m%d", 72 | ) 73 | df.to_parquet(table["parquet_filename"]) 74 | print(f"Done writing {table['parquet_filename']}") 75 | -------------------------------------------------------------------------------- /datasets/kkbox_churn/data_engineering.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import ibis 3 | from ibis import _ as c # to avoid name confict with Jupyter's _ 4 | 5 | 6 | def add_resubscription( 7 | transactions, 8 | 
expire_threshold=ibis.interval(days=30), 9 | transaction_threshold=ibis.interval(days=60), 10 | ): 11 | """Flag transactions that are resubscriptions. 12 | 13 | Compute the numbers of days elapsed between a given transaction and a 14 | deadline computed from the informations in the previous transaction by the 15 | same member. 16 | 17 | The new resubscription deadline is set to 30 days after the expiration date 18 | of the previous transaction or 60 days after the date of the previous 19 | transaction, which ever is the latest. We need this double criterion 20 | because some expiration dates seems invalid (too short) and would cause the 21 | detection of too many registration events, even for users which have more 22 | than one transaction per month for several contiguous months. 23 | 24 | Note: this strategy to detect resubscriptions is simplistic because it does 25 | not handle cancellation events. 26 | """ 27 | t = transactions 28 | return ( 29 | t.group_by(t.msno) 30 | .order_by([t.transaction_date, t.membership_expire_date]) 31 | # XXX: we have to assign the lag variable as expression fields 32 | # otherwise we can get a CatalogException from duckdb... There are 33 | # similar constraints for subsequent operations in the chained calls to 34 | # mutate. 35 | # 36 | # TODO: craft a minimal reproducible example and report it to the ibis 37 | # issue tracker. 38 | .mutate( 39 | transaction_lag=c.transaction_date.lag(), 40 | expire_lag=c.membership_expire_date.lag(), 41 | ) 42 | .mutate( 43 | transaction_deadline=ibis.coalesce(c.transaction_lag, c.transaction_date) 44 | + transaction_threshold, 45 | expire_deadline=ibis.coalesce(c.expire_lag, c.membership_expire_date) 46 | + expire_threshold, 47 | ) 48 | .mutate( 49 | previous_deadline_date=ibis.greatest( 50 | c.transaction_deadline, 51 | c.expire_deadline, 52 | ) 53 | ) 54 | .mutate( 55 | elapsed_days_since_previous_deadline=ibis.greatest( 56 | c.transaction_date - c.previous_deadline_date, 0 57 | ), 58 | resubscription=t.transaction_date > c.previous_deadline_date, 59 | ) 60 | .drop( 61 | "transaction_lag", 62 | "expire_lag", 63 | "transaction_deadline", 64 | "expire_deadline", 65 | "previous_deadline_date", 66 | ) 67 | ) 68 | 69 | 70 | def add_subscription_id(expr): 71 | """Generate a distinct id for each subscription. 72 | 73 | The subscription id is based on the cumulative sum of past resubscription 74 | events. 75 | """ 76 | return ( 77 | expr.group_by(c.msno) 78 | .order_by([c.transaction_date, c.membership_expire_date]) 79 | .mutate( 80 | subscription_id=c.resubscription.cast("int").cumsum().cast("string"), 81 | ) 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | from pathlib import Path 87 | 88 | database = Path(__file__).parent / "kkbox.db" 89 | duckdb_conn = ibis.duckdb.connect(database=database, read_only=True) 90 | transactions = duckdb_conn.table("transactions") 91 | example_msno = "AHDfgFvwL4roCSwVdCbzjUfgUuibJHeMMl2Nx0UDdjI=" 92 | df = ( 93 | transactions.filter(c.msno == example_msno) 94 | .pipe(add_resubscription) 95 | .pipe(add_subscription_id) 96 | .order_by([c.transaction_date, c.membership_expire_date]) 97 | ).execute() 98 | print(df) -------------------------------------------------------------------------------- /datasets/kkbox_churn/prepare_duckdb.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | from pathlib import Path 3 | import ibis 4 | 5 | # Download the original data with the kaggle cli. 
First generate an API 6 | # token as described here: 7 | # 8 | # https://github.com/Kaggle/kaggle-api#api-credentials 9 | # 10 | # Then: 11 | # 12 | # pip install kaggle 13 | # kaggle competitions download kkbox-churn-prediction-challenge 14 | # 15 | # and decompress the .7z files using: 16 | # 17 | # mamba install p7zip 18 | # 7z x members_v3.csv.7z 19 | # 7z x transactions.csv.7z 20 | # 7z x user_logs.csv.7z 21 | 22 | KKBOX_DATA_FOLDER = Path("~/data/kkbox-churn-prediction-challenge").expanduser() 23 | 24 | table_name_to_csv_filename = { 25 | "members": "members_v3.csv", 26 | } 27 | 28 | table_schemas = { 29 | "members": ibis.Schema.from_tuples( 30 | [ 31 | ("msno", "string"), 32 | ("city", "int32"), 33 | ("bd", "int32"), 34 | ("gender", "string"), 35 | ("registered_via", "int32"), 36 | ("registration_init_time", "date"), 37 | ] 38 | ), 39 | "transactions": ibis.Schema.from_tuples( 40 | [ 41 | ("msno", "string"), 42 | ("payment_method_id", "int32"), 43 | ("payment_plan_days", "int32"), 44 | ("plan_list_price", "int32"), 45 | ("actual_amount_paid", "int32"), 46 | ("is_auto_renew", "boolean"), 47 | ("transaction_date", "date"), 48 | ("membership_expire_date", "date"), 49 | ("is_cancel", "boolean"), 50 | ] 51 | ), 52 | "user_logs": ibis.Schema.from_tuples( 53 | [ 54 | ("msno", "string"), 55 | ("date", "date"), 56 | ("num_25", "int32"), 57 | ("num_50", "int32"), 58 | ("num_75", "int32"), 59 | ("num_985", "int32"), 60 | ("num_100", "int32"), 61 | ("num_unq", "int32"), 62 | ("total_secs", "float64"), 63 | ] 64 | ), 65 | } 66 | 67 | connection = ibis.duckdb.connect(database="kkbox.db") 68 | existing_tables = connection.list_tables() 69 | for table_name, schema in table_schemas.items(): 70 | if table_name not in existing_tables: 71 | csv_filename = table_name_to_csv_filename.get(table_name, table_name + ".csv") 72 | print(f"Loading {csv_filename} to table {table_name}...") 73 | tic = perf_counter() 74 | # transaction table 75 | connection.create_table(table_name, schema=schema) 76 | # XXX: use ibis.duckdb.from_csv() instead? 77 | connection.raw_sql( 78 | f"COPY {table_name} FROM '{KKBOX_DATA_FOLDER}/{csv_filename}'" 79 | f" WITH (FORMAT CSV, HEADER TRUE, DATEFORMAT '%Y%m%d');" 80 | ) 81 | toc = perf_counter() 82 | table = connection.table(table_name) 83 | print( 84 | f"Imported {table.count().execute()} rows into " 85 | f"table '{table_name}' in {toc - tic:0.3f} seconds" 86 | ) 87 | else: 88 | print(f"Table '{table_name}' already exists") 89 | table = connection.table(table_name) 90 | parquet_filename = table_name + ".parquet" 91 | if not Path(parquet_filename).exists(): 92 | print(f"Writing {table_name} to {parquet_filename}...") 93 | tic = perf_counter() 94 | table.to_parquet(parquet_filename) 95 | toc = perf_counter() 96 | print(f"Wrote {table_name} to {parquet_filename} in {toc - tic:0.3f} seconds") 97 | -------------------------------------------------------------------------------- /datasets/kkbox_churn/sessionization.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import numpy as np 3 | import ibis 4 | from ibis import deferred as c # to avoid name confict with Jupyter's _ 5 | 6 | 7 | duckdb_conn = ibis.duckdb.connect(database="kkbox.db", read_only=True) 8 | transactions = duckdb_conn.table("transactions") 9 | 10 | 11 | # %% 12 | def add_resubscription_window( 13 | transactions, 14 | expire_threshold=ibis.interval(days=30), 15 | transaction_threshold=ibis.interval(days=60), 16 | ): 17 | """Flag transactions that are resubscriptions. 
18 | 19 | Compute the numbers of days elapsed between a given transaction and a 20 | deadline computed from the informations in the previous transaction by the 21 | same member. 22 | 23 | The new resubscription deadline is set to 30 days after the expiration date 24 | of the previous transaction or 60 days after the date of the previous 25 | transaction, which ever is the latest. We need this double criterion 26 | because some expiration dates seems invalid (too short) and would cause the 27 | detection of too many registration events, even for users which have more 28 | than one transaction per month for several contiguous months. 29 | 30 | Note: this strategy to detect resubscriptions is simplistic because it does 31 | not handle cancellation events. 32 | """ 33 | w = ibis.trailing_window( 34 | group_by=c.msno, 35 | order_by=[c.transaction_date, c.membership_expire_date], 36 | preceding=1, 37 | ) 38 | previous_deadline_date = ibis.greatest( 39 | c.membership_expire_date.first().over(w) + expire_threshold, 40 | c.transaction_date.first().over(w) + transaction_threshold, 41 | ) 42 | current_transaction_date = c.transaction_date.last().over(w) 43 | resubscription = current_transaction_date > previous_deadline_date 44 | return transactions.mutate( 45 | elapsed_days_since_previous_deadline=ibis.greatest( 46 | current_transaction_date - previous_deadline_date, 0 47 | ), 48 | resubscription=resubscription, 49 | ) 50 | 51 | 52 | # %% 53 | 54 | 55 | def add_resubscription_groupby_lag( 56 | transactions, 57 | expire_threshold=ibis.interval(days=30), 58 | transaction_threshold=ibis.interval(days=60), 59 | ): 60 | """Flag transactions that are resubscriptions. 61 | 62 | Compute the numbers of days elapsed between a given transaction and a 63 | deadline computed from the informations in the previous transaction by the 64 | same member. 65 | 66 | The new resubscription deadline is set to 30 days after the expiration date 67 | of the previous transaction or 60 days after the date of the previous 68 | transaction, which ever is the latest. We need this double criterion 69 | because some expiration dates seems invalid (too short) and would cause the 70 | detection of too many registration events, even for users which have more 71 | than one transaction per month for several contiguous months. 72 | 73 | Note: this strategy to detect resubscriptions is simplistic because it does 74 | not handle cancellation events. 75 | """ 76 | t = transactions 77 | return ( 78 | t.group_by(t.msno) 79 | .order_by([t.transaction_date, t.membership_expire_date]) 80 | # XXX: we have to assign the lag variable as expression fields 81 | # otherwise we can get a CatalogException from duckdb... There are 82 | # similar constraints for subsequent operations in the chained calls to 83 | # mutate. 84 | # 85 | # TODO: craft a minimal reproducible example and report it to the ibis 86 | # issue tracker. 
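        # The lag() columns below hold the previous transaction's dates for the
        # same member (NULL for the member's first transaction); coalescing them
        # with the current row's dates makes the first transaction its own
        # reference, so it can never be flagged as a resubscription.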
87 | .mutate( 88 | transaction_lag=c.transaction_date.lag(), 89 | expire_lag=c.membership_expire_date.lag(), 90 | ) 91 | .mutate( 92 | transaction_deadline=ibis.coalesce(c.transaction_lag, c.transaction_date) 93 | + transaction_threshold, 94 | expire_deadline=ibis.coalesce(c.expire_lag, c.membership_expire_date) 95 | + expire_threshold, 96 | ) 97 | .mutate( 98 | previous_deadline_date=ibis.greatest( 99 | c.transaction_deadline, 100 | c.expire_deadline, 101 | ) 102 | ) 103 | .mutate( 104 | elapsed_days_since_previous_deadline=ibis.greatest( 105 | c.transaction_date - c.previous_deadline_date, 0 106 | ), 107 | resubscription=t.transaction_date > c.previous_deadline_date, 108 | ) 109 | .drop( 110 | "transaction_lag", 111 | "expire_lag", 112 | "transaction_deadline", 113 | "expire_deadline", 114 | "previous_deadline_date", 115 | ) 116 | ) 117 | 118 | 119 | # %% 120 | example_msno = "+8ZA0rcIhautWvUbAM58/4jZUvhNA4tWMZKhPFdfquQ=" 121 | 122 | 123 | # %% 124 | # ( 125 | # transactions.filter(c.msno == example_msno) 126 | # .pipe(add_resubscription_groupby_lag) 127 | # .order_by([c.msno, c.transaction_date, c.membership_expire_date]) 128 | # ).execute() 129 | 130 | # %% 131 | ( 132 | transactions.filter(c.msno == example_msno) 133 | .pipe(add_resubscription_window) 134 | .order_by([c.msno, c.transaction_date, c.membership_expire_date]) 135 | ).execute() 136 | 137 | 138 | # %% 139 | def count_resubscriptions(expr): 140 | return expr.group_by(expr.msno).aggregate( 141 | n_resubscriptions=expr.resubscription.sum(), 142 | ) 143 | 144 | 145 | # %% 146 | # counts_groupby_lag = ( 147 | # transactions.pipe(add_resubscription_groupby_lag) 148 | # .pipe(count_resubscriptions) 149 | # .group_by(c.n_resubscriptions) 150 | # .aggregate( 151 | # n_members=c.msno.count(), 152 | # ) 153 | # ).execute() 154 | # counts_groupby_lag 155 | 156 | # %% 157 | counts_window = ( 158 | transactions.pipe(add_resubscription_window) 159 | .pipe(count_resubscriptions) 160 | .group_by(c.n_resubscriptions) 161 | .aggregate( 162 | n_members=c.msno.count(), 163 | ) 164 | ).execute() 165 | counts_window 166 | 167 | # %% 168 | # assert (counts_window == counts_groupby_lag).all().all() 169 | 170 | # %% 171 | # Both methods return the same results on duckdb and take the same 172 | # time. From now one use the group_by + lag variant since not 173 | # all the backends support the generic Window Function API 174 | add_resubscription = add_resubscription_window 175 | # add_resubscription = add_resubscription_groupby_lag 176 | 177 | 178 | # %% 179 | ( 180 | transactions.pipe(add_resubscription) 181 | .pipe(count_resubscriptions) 182 | .order_by([c.n_resubscriptions.desc(), c.msno]) 183 | .limit(10) 184 | ).execute() 185 | 186 | 187 | # %% 188 | def add_n_resubscriptions(expr): 189 | return expr.group_by(expr.msno).mutate( 190 | n_resubscriptions=c.resubscription.sum(), 191 | ) 192 | 193 | 194 | # %% 195 | example_msno = "AHDfgFvwL4roCSwVdCbzjUfgUuibJHeMMl2Nx0UDdjI=" 196 | ( 197 | transactions.filter(c.msno == example_msno) 198 | .pipe(add_resubscription) 199 | .pipe(add_n_resubscriptions) 200 | .order_by([c.transaction_date, c.membership_expire_date]) 201 | ).execute() 202 | 203 | 204 | # %% 205 | def add_subscription_id_window( 206 | expr, relative_to_epoch=False, epoch=ibis.date(2000, 1, 1) 207 | ): 208 | """Generate a distinct id for each subscription. 209 | 210 | The subscription id is based on the cumulative sum of past resubscription 211 | events. 
212 | """ 213 | w = ibis.window( 214 | group_by=c.msno, 215 | order_by=[c.transaction_date, c.membership_expire_date], 216 | preceding=None, 217 | following=0, 218 | ) 219 | if relative_to_epoch: 220 | # use oldest transaction date as reference to make it possible 221 | # to generate a session id that can be computed in parallel 222 | # on partitions of the transaction logs. 223 | base = (c.transaction_date.first().over(w) - epoch).cast("string") 224 | counter = c.resubscription.sum().over(w).cast("string") 225 | subscription_id = base + "_" + counter 226 | else: 227 | subscription_id = c.resubscription.sum().over(w).cast("string") 228 | return expr.mutate( 229 | subscription_id=subscription_id, 230 | ) 231 | 232 | 233 | # %% 234 | def add_subscription_groupby_cumsum(expr): 235 | """Generate a distinct id for each subscription. 236 | 237 | The subscription id is based on the cumulative sum of past resubscription 238 | events. 239 | """ 240 | return ( 241 | expr.group_by(c.msno) 242 | .order_by([c.transaction_date, c.membership_expire_date]) 243 | .mutate( 244 | subscription_id=c.resubscription.cast("int").cumsum().cast("string"), 245 | ) 246 | ) 247 | 248 | 249 | # %% 250 | add_subscription_id = add_subscription_groupby_cumsum 251 | 252 | # %% 253 | example_msno = "AHDfgFvwL4roCSwVdCbzjUfgUuibJHeMMl2Nx0UDdjI=" 254 | ( 255 | transactions.filter(c.msno == example_msno) 256 | .pipe(add_resubscription) 257 | .pipe(add_n_resubscriptions) 258 | .pipe(add_subscription_id) 259 | .order_by([c.transaction_date, c.membership_expire_date]) 260 | ).execute() 261 | 262 | 263 | # %% 264 | def subsample_by_unique(expr, col_name="msno", size=1, seed=None): 265 | unique_col = expr[[col_name]].distinct() 266 | num_unique = unique_col.count().execute() 267 | assert size <= num_unique 268 | positional_values = unique_col.order_by(col_name)[ 269 | ibis.row_number().name("position"), col_name 270 | ] 271 | selected_indices = np.random.RandomState(seed).choice( 272 | num_unique, size=size, replace=False 273 | ) 274 | selected_rows = positional_values.filter( 275 | positional_values.position.isin(selected_indices) 276 | )[[col_name]] 277 | return expr.inner_join(selected_rows, col_name, suffixes=["", "_"]).select(expr) 278 | 279 | 280 | # %% 281 | ( 282 | transactions.pipe(subsample_by_unique, "msno", size=3, seed=0) 283 | .pipe(add_resubscription) 284 | .pipe(add_n_resubscriptions) 285 | .pipe(add_subscription_id) 286 | .order_by([c.msno, c.transaction_date, c.membership_expire_date]) 287 | ).execute() 288 | 289 | # %% 290 | ( 291 | transactions.pipe(subsample_by_unique, "msno", size=3, seed=0) 292 | .pipe(add_resubscription) 293 | .pipe(add_n_resubscriptions) 294 | .pipe(add_subscription_id) 295 | .order_by([c.msno, c.transaction_date, c.membership_expire_date]) 296 | ).execute() 297 | 298 | 299 | # %% 300 | from time import perf_counter 301 | 302 | 303 | def bench_sessionization(conn): 304 | tic = perf_counter() 305 | sessionized = ( 306 | conn.table("transactions") 307 | .pipe(add_resubscription) 308 | .pipe(add_n_resubscriptions) 309 | .pipe(add_subscription_id) 310 | .select(c.msno, c.subscription_id, c.n_resubscriptions, c.transaction_date) 311 | ) 312 | most_resubscribed = sessionized.msno.topk(5, by=sessionized.n_resubscriptions.max()) 313 | results = ( 314 | sessionized.semi_join(most_resubscribed, sessionized.msno == most_resubscribed.msno) 315 | .order_by( 316 | [ 317 | ibis.desc(c.transaction_date), 318 | c.msno, 319 | c.subscription_id, 320 | ] 321 | ) 322 | .execute() 323 | ) 324 | toc = 
perf_counter() 325 | print(f"Sessionization took {toc - tic:.1f} seconds") 326 | return results 327 | 328 | 329 | # %% 330 | bench_sessionization(duckdb_conn) 331 | 332 | # %% 333 | parquet_files = {"transactions": "transactions.parquet"} 334 | 335 | # %% 336 | duckdb_parquet_conn = ibis.duckdb.connect() 337 | for table_name, path in parquet_files.items(): 338 | duckdb_parquet_conn.register(path, table_name) 339 | 340 | bench_sessionization(duckdb_parquet_conn) 341 | 342 | # %% 343 | # XXX: pandas is quite slow: I never waited until the end. 344 | 345 | # import pandas as pd 346 | 347 | # pandas_conn = ibis.pandas.connect( 348 | # {k: pd.read_parquet(v) for k, v in parquet_files.items()} 349 | # ) 350 | # bench_sessionization(pandas_conn) 351 | 352 | # %% 353 | # XXX: dask does not support window functions 354 | # NotImplementedError: Window operations are unsupported in the dask backend 355 | # XXX: even the lag / cumsum variant raise NotImplementedError with dask 356 | 357 | # import dask.dataframe as dd 358 | # dask_conn = ibis.dask.connect( 359 | # {k: dd.read_parquet(v) for k, v in parquet_files.items()} 360 | # ) 361 | # bench_sessionization(dask_conn) 362 | 363 | # %% XXX: ibis' Arrow DataFusion backend does not translate Window ops, neight 364 | # with the generic Window API nor with the lag / cumsum variants: 365 | # OperationNotDefinedError: No translation rule for 367 | # datafusion_conn = ibis.datafusion.connect(parquet_files) 368 | # bench_sessionization(datafusion_conn) 369 | 370 | # %% 371 | # XXX: polars does not support window all window functions: 372 | # OperationNotDefinedError: No translation rule for 373 | # 374 | # polars_conn = ibis.polars.connect() 375 | # for table_name, path in parquet_files.items(): 376 | # polars_conn.register(path, table_name=table_name) 377 | 378 | # bench_sessionization(polars_conn) 379 | 380 | # %% 381 | # Note: to use clickhouse, one needs to first start the server with `clickhouse server`. 382 | 383 | # XXX: not possible to register parquet files, too bad. 384 | 385 | # clickouse_conn = ibis.clickhouse.connect() 386 | # for table_name, path in parquet_files.items(): 387 | # clickouse_conn.register(path, table_name=table_name) 388 | 389 | # %% 390 | # clickhouse_conn.raw_sql("DROP TABLE transactions") 391 | # CREATE_QUERY = """\ 392 | # CREATE TABLE transactions 393 | # ( 394 | # `msno` String, 395 | # `payment_method_id` Int8, 396 | # `payment_plan_days` Int16, 397 | # `plan_list_price` Int16, 398 | # `actual_amount_paid` Int16, 399 | # `is_auto_renew` Int8, 400 | # `transaction_date` Date, 401 | # `membership_expire_date` Date, 402 | # `is_cancel` Int8 403 | # ) 404 | # ENGINE = MergeTree 405 | # PARTITION BY toYYYYMM(transaction_date) 406 | # ORDER BY (msno, transaction_date, membership_expire_date) 407 | # """ 408 | # clickhouse_conn.raw_sql(CREATE_QUERY) 409 | # !cat "transactions.parquet" | clickhouse client --query="INSERT INTO transactions FORMAT Parquet" 410 | # %% 411 | clickhouse_conn = ibis.clickhouse.connect() 412 | bench_sessionization(clickhouse_conn) 413 | 414 | # %% 415 | -------------------------------------------------------------------------------- /demo_mlflow.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is a quick demo to leverage model versioning with MLFlow for production purposes. 
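It runs a single cross-validation fold of a YASGBT model on the KKBox v1 dataset, then logs the resulting metrics, the Brier score and survival curve figures, and the fitted estimator to the configured MLflow tracking server.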
3 | """ 4 | import os 5 | import numpy as np 6 | import pandas as pd 7 | import mlflow 8 | from mlflow.models.signature import infer_signature 9 | 10 | from sksurv.datasets import get_x_y 11 | from pycox.datasets import kkbox_v1 12 | 13 | from models.yasgbt import YASGBTClassifier 14 | from model_selection.cross_validation import ( 15 | run_cv, get_all_results, get_time_grid 16 | ) 17 | from plot.brier_score import plot_brier_scores 18 | from plot.individuals import plot_individuals_survival_curve 19 | 20 | 21 | assert os.getenv("MLFLOW_S3_ENDPOINT_URL"), "env variable MLFLOW_S3_ENDPOINT_URL must be set" 22 | assert os.getenv("PYCOX_DATA_DIR"), "env variable PYCOX_DATA_DIR must be set" 23 | 24 | def main(): 25 | 26 | # Download data if necessary and run preprocessing 27 | download_data() 28 | X, y = preprocess() 29 | 30 | # Define our model 31 | est_params = dict( 32 | sampling_strategy="uniform", 33 | n_iter=10, 34 | ) 35 | est = YASGBTClassifier(**est_params) 36 | est.name = "YASGBT-prod" 37 | 38 | # Run cross val 39 | run_cv(X, y, est, single_fold=True) 40 | 41 | # Fetch all scores 42 | df_tables, df_lines = get_all_results(match_filter=est.name) 43 | 44 | # Get the brier scores figure 45 | fig_bs = plot_brier_scores(df_lines) 46 | bs_filename = f"{est.name}_brier_score.png" 47 | 48 | # Get the individuals survival proba curve figure 49 | fig_indiv = plot_individuals_survival_curve(df_tables, df_lines, y) 50 | surv_curv_filename = f"{est.name}_surv_curv.png" 51 | 52 | scores = df_tables.to_dict("index")[0] 53 | metrics = dict( 54 | mean_ibs=float(scores["IBS"].split("±")[0]), 55 | mean_c_index=float(scores["C_td"].split("±")[0]), 56 | ) 57 | 58 | # Get our model signature for mlflow UI 59 | times = get_time_grid(y, y, n=100) 60 | X_sample = X.head() 61 | surv_probs = est.predict_survival_function(X_sample.values, times) 62 | signature = infer_signature(X_sample, surv_probs) 63 | 64 | # Register metrics and params to mlflow 65 | print(f"mlflow URI: {mlflow.get_tracking_uri()}") 66 | with mlflow.start_run(run_name="survival_demo"): 67 | mlflow.log_metrics(metrics) 68 | mlflow.log_params(est_params) 69 | mlflow.log_figure(fig_bs, artifact_file=bs_filename) 70 | mlflow.log_figure(fig_indiv, artifact_file=surv_curv_filename) 71 | mlflow.sklearn.log_model( 72 | est, 73 | artifact_path=est.name, 74 | signature=signature, 75 | ) 76 | 77 | 78 | def download_data(): 79 | 80 | kkbox_v1._path_dir.mkdir(exist_ok=True) 81 | 82 | train_file = kkbox_v1._path_dir / "train.csv" 83 | members_file = kkbox_v1._path_dir / "members_v3.csv" 84 | transactions_file = kkbox_v1._path_dir / "transactions.csv" 85 | 86 | any_prior_file_missing = ( 87 | not train_file.exists() 88 | or not members_file.exists() 89 | or not transactions_file.exists() 90 | ) 91 | 92 | covariate_file = kkbox_v1._path_dir / "covariates.feather" 93 | is_covariate_file_missing = not covariate_file.exists() 94 | 95 | if is_covariate_file_missing: 96 | print("Covariate file missing!") 97 | # We need to download any missing prior file 98 | # before producing the final covariate file. 
99 | if any_prior_file_missing: 100 | print("Prior files missing!") 101 | kkbox_v1._setup_download_dir() 102 | kkbox_v1._7z_from_kaggle() 103 | 104 | kkbox_v1._csv_to_feather_with_types() 105 | kkbox_v1._make_survival_data() 106 | kkbox_v1._make_survival_covariates() 107 | kkbox_v1._make_train_test_split() 108 | 109 | 110 | def preprocess(): 111 | covariates = pd.read_feather(kkbox_v1._path_dir / "covariates.feather") 112 | covariates = extra_cleaning(covariates) 113 | X, y = get_x_y(covariates, ("event", "duration"), pos_label=1) 114 | return X, y 115 | 116 | 117 | def extra_cleaning(df): 118 | # remove id 119 | df.pop("msno") 120 | 121 | # ordinal encode gender 122 | df["gender"] = df["gender"].astype(str) 123 | gender_map = dict( 124 | zip(df["gender"].unique(), range(df["gender"].nunique())) 125 | ) 126 | df["gender"] = df["gender"].map(gender_map) 127 | 128 | # remove tricky np.nan in city, encoded as int 129 | df["city"] = df["city"].astype(str).replace("nan", -1).astype(int) 130 | 131 | # same for registered via 132 | df["registered_via"] = ( 133 | df["registered_via"] 134 | .astype(str) 135 | .replace("nan", -1) 136 | .astype(int) 137 | ) 138 | 139 | return df 140 | 141 | 142 | if __name__ == "__main__": 143 | main() -------------------------------------------------------------------------------- /model_selection/cross_validation.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | import pandas as pd 4 | import os 5 | from pathlib import Path 6 | import pickle 7 | from time import perf_counter 8 | 9 | from sklearn.model_selection import KFold, train_test_split 10 | 11 | from sksurv.metrics import ( 12 | brier_score, 13 | integrated_brier_score, 14 | cumulative_dynamic_auc, 15 | concordance_index_censored, 16 | concordance_index_ipcw, 17 | ) 18 | 19 | 20 | def run_cv( 21 | X, 22 | y, 23 | estimator, 24 | subsample_train=1.0, 25 | subsample_val=1.0, 26 | cv=None, 27 | single_fold=False, 28 | random_state=42, 29 | ): 30 | """Run cross validation and save score. 31 | 32 | Parameters 33 | ---------- 34 | estimator : 35 | Instance of an estimator to evaluate. 36 | 37 | kfold_tuples : tuple of tuple of ndarray, 38 | Data for training and evaluating. 39 | 40 | times : ndarray 41 | Timesteps used for evaluation. 42 | 43 | score_func : callable, default=None 44 | Score function that generates a dictionnary of metrics. 45 | If set to None, `run_cv` will call `get_score`. 46 | 47 | subset: dict, default=None 48 | The number of first rows to select for keys 'train' and 'val'. 49 | If `subset` is not None, one value must be set for each key. 50 | """ 51 | if isinstance(X, pd.DataFrame): 52 | X = X.values 53 | 54 | cv_scores = [] 55 | cv = cv or KFold(shuffle=True, random_state=random_state) 56 | 57 | for train_idxs, val_idxs in cv.split(X): 58 | 59 | size_train = int(subsample_train * len(train_idxs)) 60 | size_val = int(subsample_val * len(val_idxs)) 61 | 62 | train_idxs = train_idxs[:size_train] 63 | val_idxs = val_idxs[:size_val] 64 | 65 | X_train, y_train = X[train_idxs], y[train_idxs] 66 | X_val, y_val = X[val_idxs], y[val_idxs] 67 | 68 | X_val, y_val = truncate_val_to_train(y_train, X_val, y_val) 69 | 70 | print(f"train set: {X_train.shape[0]}, val set: {X_val.shape[0]}") 71 | 72 | # Generate evaluation time steps as a subset of the train and val durations. 
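        # get_time_grid (defined below) returns 100 points evenly spaced between
        # the 2.5th and 97.5th percentiles of the pooled train/val durations,
        # i.e. it excludes the sparsely observed tails of the duration distribution.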
73 | times = get_time_grid(y_train, y_val) 74 | 75 | t0 = perf_counter() 76 | estimator.fit(X_train, y_train, times) 77 | t1 = perf_counter() 78 | 79 | scores = get_scores(estimator, y_train, X_val, y_val, times) 80 | scores["training_duration"] = t1 - t0 81 | cv_scores.append(scores) 82 | 83 | if single_fold: 84 | break 85 | 86 | print("-" * 8) 87 | results = {} 88 | 89 | # sufficient statistics 90 | for k in [ 91 | "ibs", "c_index", "training_duration", "prediction_duration" 92 | ]: 93 | score_mean = np.mean([score[k] for score in cv_scores]) 94 | score_std = np.std([score[k] for score in cv_scores]) 95 | results[f"mean_{k}"] = score_mean 96 | results[f"std_{k}"] = score_std 97 | print(f"{k}: {score_mean:.4f} ± {score_std:.4f}") 98 | 99 | # vectors 100 | for k in ["times", "survival_probs", "brier_scores"]: 101 | results[k] = [score[k] for score in cv_scores] 102 | 103 | results["n_sample_train"] = len(train_idxs) 104 | results["n_sample_val"] = len(val_idxs) 105 | 106 | save_scores(estimator.name, results) 107 | 108 | 109 | def survival_to_risk_estimate(survival_probs): 110 | return -np.log(survival_probs + 1e-8).sum(axis=1) 111 | 112 | 113 | def get_scores( 114 | estimator, 115 | y_train, 116 | X_val, 117 | y_val, 118 | times, 119 | use_cindex_ipcw=False, 120 | ): 121 | 122 | t0 = perf_counter() 123 | survival_probs = estimator.predict_survival_function(X_val, times) 124 | t1 = perf_counter() 125 | 126 | risk_estimate = survival_to_risk_estimate(survival_probs) 127 | 128 | _, brier_scores = brier_score(y_train, y_val, survival_probs, times) 129 | ibs = integrated_brier_score(y_train, y_val, survival_probs, times) 130 | 131 | # As the C-index is expensive to compute, we only consider a subsample of our data. 132 | N_sample_c_index = 50_000 133 | c_index = concordance_index_censored( 134 | y_val["event"][:N_sample_c_index], 135 | y_val["duration"][:N_sample_c_index], 136 | risk_estimate[:N_sample_c_index], 137 | )[0] 138 | 139 | results = dict( 140 | brier_scores=brier_scores, 141 | ibs=ibs, 142 | times=times, 143 | survival_probs=survival_probs, 144 | c_index=c_index, 145 | prediction_duration=t1 - t0, 146 | ) 147 | 148 | if use_cindex_ipcw: 149 | c_index_ipcw = concordance_index_ipcw( 150 | y_train[:N_sample_c_index], 151 | y_val[:N_sample_c_index], 152 | risk_estimate[:N_sample_c_index], 153 | )[0] 154 | results["c_index_ipcw"] = c_index_ipcw 155 | 156 | return results 157 | 158 | 159 | def truncate_val_to_train(y_train, X_val, y_val): 160 | """Enforce y_val to stay below y_train upper bound""" 161 | out_of_bound_mask = y_train["duration"].max() <= y_val["duration"] 162 | return X_val[~out_of_bound_mask, :], y_val[~out_of_bound_mask] 163 | 164 | 165 | def get_time_grid(y_train, y_val, n=100): 166 | y_time = np.hstack([y_train["duration"], y_val["duration"]]) 167 | lower, upper = np.percentile(y_time, [2.5, 97.5]) 168 | return np.linspace(lower, upper, n) 169 | 170 | 171 | def save_scores(name, scores, create_dir=True): 172 | path_results = get_path_results() 173 | if create_dir: 174 | path_results.mkdir(exist_ok=True, parents=True) 175 | path = path_results / f"{name}.pkl" 176 | 177 | pickle.dump(scores, open(path, "wb+")) 178 | 179 | 180 | def load_scores(name): 181 | path_results = get_path_results() 182 | path = path_results / f"{name}.pkl" 183 | return pickle.load(open(path, "rb")) 184 | 185 | 186 | def get_path_results(): 187 | return Path(os.getenv("PYCOX_DATA_DIR")) / "kkbox_v1" / "results" 188 | 189 | 190 | def get_all_results(match_filter: str = None): 191 | """Load all 
results matching `match_filter`, concatenate sufficient 192 | statistics in `df_tables` and times, survival probs and 193 | brier scores vectors into `df_lines`. 194 | """ 195 | path_results = get_path_results() 196 | lines, tables = [], [] 197 | match_filter = match_filter or "" 198 | 199 | for path in path_results.iterdir(): 200 | if ( 201 | path.is_file() 202 | and path.suffix == ".pkl" 203 | and match_filter in str(path) 204 | ): 205 | result = pickle.load(open(path, "rb")) 206 | model_name = path.name.split(".")[0].split("_")[-1] 207 | 208 | line = make_row_line(result, model_name) 209 | table = make_row_table(result, model_name) 210 | 211 | lines.append(line) 212 | tables.append(table) 213 | 214 | df_tables = pd.DataFrame(tables) 215 | df_lines = pd.DataFrame(lines) 216 | 217 | # sort by ibs 218 | df_tables["ibs_tmp"] = ( 219 | df_tables["IBS"].str.split("±") 220 | .str[0] 221 | .astype(np.float64) 222 | ) 223 | df_tables = ( 224 | df_tables.sort_values("ibs_tmp") 225 | .reset_index(drop=True) 226 | .drop(["ibs_tmp"], axis=1) 227 | ) 228 | 229 | return df_tables, df_lines 230 | 231 | 232 | def make_row_line(result, model_name): 233 | """Format the results of a single model into a row with 234 | times, survival probs and brier scores vectors for visualization 235 | purposes. 236 | """ 237 | # times are the same across all folds, output shape: (times) 238 | times = result["times"][0] 239 | 240 | # take the mean for each folds, output shape: (times) 241 | brier_scores = np.asarray(result["brier_scores"]).mean(axis=0) 242 | 243 | # arbitrarily take the first cross val to vizualize surv probs 244 | # output shape: (n_val, times) 245 | survival_probs = np.asarray(result["survival_probs"][0]) 246 | 247 | return dict( 248 | model=model_name, 249 | times=times, 250 | brier_scores=brier_scores, 251 | survival_probs=survival_probs, 252 | ) 253 | 254 | 255 | def make_row_table(result, model_name): 256 | """Format the results of a single model into a row with 257 | sufficient statistics. 
258 | """ 259 | row = {"Method": model_name} 260 | 261 | col_displayed = {"c_index": "C_td", "ibs": "IBS"} 262 | for col in ["c_index", "ibs"]: 263 | mean_col, std_col = f"mean_{col}", f"std_{col}" 264 | row[col_displayed[col]] = f"{result[mean_col]:.4f} ± {result[std_col]:.4f}" 265 | 266 | for col in ["training_duration", "prediction_duration"]: 267 | mean_col = f"mean_{col}" 268 | row[col] = f"{result[mean_col]:.4f}s" 269 | 270 | for col in ["n_sample_train", "n_sample_val"]: 271 | row[col] = result[col] 272 | 273 | return row -------------------------------------------------------------------------------- /model_selection/wrappers.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import numpy as np 3 | 4 | from models.survival_mixin import SurvivalMixin 5 | 6 | 7 | class BaseWrapper(ABC, SurvivalMixin): 8 | 9 | def __init__(self, estimator, name=None, fit_kwargs=None): 10 | self.estimator = estimator 11 | self.name = name or estimator.__class__.__name__ 12 | self.fit_kwargs = fit_kwargs or dict() 13 | 14 | @abstractmethod 15 | def fit(self, X_train, y_train, times=None): 16 | pass 17 | 18 | @abstractmethod 19 | def predict_survival_function(self, X_test, times): 20 | pass 21 | 22 | 23 | class PipelineWrapper(BaseWrapper): 24 | 25 | def fit(self, X_train, y_train, times=None): 26 | last_est_name = self.estimator.steps[-1][0] 27 | times_kwargs = {f"{last_est_name}__times": times} 28 | self.estimator.fit(X_train, y_train, **times_kwargs) 29 | 30 | def predict_survival_function(self, X_test, times=None): 31 | return self.estimator.predict_survival_function(X_test, times=times) 32 | 33 | def predict_cumulative_incidence(self, X_test, times=None): 34 | transformers = self.estimator[:-1] 35 | X_test = transformers.transform(X_test) 36 | estimator = self.estimator[-1] 37 | return estimator.predict_cumulative_incidence(X_test, times=times) 38 | 39 | def predict_quantile(self, X_test, quantile=0.5, times=None): 40 | transformers = self.estimator[:-1] 41 | X_test = transformers.transform(X_test) 42 | estimator = self.estimator[-1] 43 | return estimator.predict_quantile(X_test, quantile=quantile, times=times) 44 | 45 | def predict_proba(self, X_test, time_horizon=None): 46 | transformers = self.estimator[:-1] 47 | X_test = transformers.transform(X_test) 48 | estimator = self.estimator[-1] 49 | return estimator.predict_proba(X_test, time_horizon=time_horizon) 50 | 51 | 52 | class SkurvWrapper(BaseWrapper): 53 | 54 | def fit(self, X_train, y_train, times=None): 55 | self.estimator.fit(X_train, y_train) 56 | 57 | def predict_survival_function(self, X_test, times): 58 | step_funcs = self.estimator.predict_survival_function(X_test, return_array=False) 59 | survival_probs = np.vstack([step_func(times) for step_func in step_funcs]) 60 | self.survival_probs_ = survival_probs 61 | return survival_probs 62 | 63 | 64 | class XGBSEWrapper(BaseWrapper): 65 | 66 | def fit(self, X_train, y_train, times=None): 67 | self.estimator.fit(X_train, y_train, time_bins=times, **self.fit_kwargs) 68 | 69 | def predict_survival_function(self, X_test, times=None): 70 | survival_probs = self.estimator.predict(X_test, return_interval_probs=False) 71 | self.survival_probs_ = survival_probs 72 | return survival_probs 73 | 74 | 75 | class DeepHitWrapper(BaseWrapper): 76 | 77 | def fit(self, X_train, y_train, times=None): 78 | y_train_ = self.adapt_y(y_train) 79 | self.estimator.fit(X_train, y_train_, **self.fit_kwargs) 80 | 81 | def 
predict_survival_function(self, X_test, times=None): 82 | # StandardScaler 83 | X_test_trans = self.estimator[0].transform(X_test) 84 | 85 | # DeepHitestimator 86 | survival_probs = self.estimator[1].predict_surv_df(X_test_trans) 87 | time_index = np.asarray(survival_probs.index) 88 | survival_probs = survival_probs.reset_index(drop=True).T 89 | survival_probs.columns = time_index 90 | self.survival_probs_ = survival_probs 91 | 92 | return survival_probs 93 | 94 | def adapt_y(self, y): 95 | if not isinstance(y, tuple): 96 | y = (y["duration"], y["event"]) 97 | y = ( 98 | np.ascontiguousarray(y[0], dtype=int), 99 | np.ascontiguousarray(y[1], dtype=np.float32), 100 | ) 101 | return y 102 | -------------------------------------------------------------------------------- /models/gradient_boosted_cif.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | from sklearn.base import BaseEstimator, ClassifierMixin 6 | from sklearn.utils.validation import check_random_state 7 | from sklearn.ensemble import HistGradientBoostingRegressor 8 | from sklearn.ensemble import HistGradientBoostingClassifier 9 | 10 | from sksurv.functions import StepFunction 11 | from sksurv.nonparametric import kaplan_meier_estimator 12 | from sksurv.nonparametric import CensoringDistributionEstimator 13 | 14 | from models.survival_mixin import SurvivalMixin 15 | from model_selection.cross_validation import get_time_grid 16 | 17 | 18 | def cif_brier_score( 19 | y_train, 20 | y_test, 21 | cif_pred, 22 | times, 23 | event_of_interest="any", 24 | ): 25 | # XXX: make times an optional kwarg to be compatible with 26 | # sksurv.metrics.brier_score? 27 | ibsts = IBSTrainingSampler( 28 | y_train, 29 | event_of_interest=event_of_interest, 30 | ) 31 | return times, ibsts.brier_score(y_test, cif_pred, times) 32 | 33 | 34 | def cif_integrated_brier_score( 35 | y_train, 36 | y_test, 37 | cif_pred, 38 | times, 39 | event_of_interest="any", 40 | ): 41 | times, brier_scores = cif_brier_score( 42 | y_train, 43 | y_test, 44 | cif_pred, 45 | times, 46 | event_of_interest=event_of_interest, 47 | ) 48 | return np.trapz(brier_scores, times) / (times[-1] - times[0]) 49 | 50 | 51 | class IBSTrainingSampler: 52 | # XXX: this class can be used both to sample and compute (I)BS terms 53 | # we need a better name / separation of concerns. 54 | 55 | def __init__( 56 | self, 57 | y_train, 58 | event_of_interest="any", 59 | random_state=None, 60 | min_censoring_prob=1e-30, # XXX: study the effect and set a better default 61 | ): 62 | if event_of_interest != "any" and event_of_interest < 1: 63 | raise ValueError( 64 | f"event_of_interest must be a strictly positive integer or 'any', " 65 | f"got: event_of_interest={self.event_of_interest:!r}" 66 | ) 67 | self.y_train = y_train 68 | self.y_train_any_event = self._any_event(y_train) 69 | self.event_of_interest = event_of_interest 70 | self.min_censoring_prob = min_censoring_prob 71 | self.rng = check_random_state(random_state) 72 | 73 | # Estimate the censoring distribution from on the training set using Kaplan-Meier. 
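        # CensoringDistributionEstimator fits a Kaplan-Meier estimate of the
        # censoring survival function G(t) = P(C > t) by treating censoring as
        # the event; its inverse is used below as the IPCW weight of each sample.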
74 | self.censoring_dist = CensoringDistributionEstimator().fit(self.y_train_any_event) 75 | 76 | # Precompute the censoring probabilities at the time of the events on the 77 | # training set: 78 | censoring_prob_y_train = self.censoring_dist.predict_proba(y_train["duration"]) 79 | censoring_prob_y_train = np.clip(censoring_prob_y_train, self.min_censoring_prob, 1) 80 | self.censoring_prob_y_train = censoring_prob_y_train 81 | 82 | def _any_event(self, y): 83 | y_any_event = np.empty( 84 | y.shape[0], 85 | dtype=[("event", bool), ("duration", float)], 86 | ) 87 | y_any_event["event"] = (y["event"] > 0) 88 | y_any_event["duration"] = y["duration"] 89 | return y_any_event 90 | 91 | def _ibs_components(self, y, times, censoring_prob_y=None): 92 | if self.event_of_interest == "any": 93 | # y should already be provided as binary indicator 94 | k = 1 95 | else: 96 | k = self.event_of_interest 97 | 98 | # Specify the binary classification target for each record in y and 99 | # a reference time horizon: 100 | # 101 | # - 1 when event of interest was observed before the reference time 102 | # horizon, 103 | # 104 | # - 0 otherwise: any other event happening at any time, censored record 105 | # or event of interest happening after the reference time horizon. 106 | # 107 | # Note: censored events only contribute (as negative target) when 108 | # their duration is larger than the reference target horizon. 109 | # Otherwise, they are discarded by setting their weight to 0 in the 110 | # following. 111 | 112 | y_binary = np.zeros(y.shape[0], dtype=np.int32) 113 | y_binary[(y["event"] == k) & (y["duration"] <= times)] = 1 114 | 115 | # Compute the weights for each term contributing to the Brier score 116 | # at the specified time horizons. 117 | # 118 | # - error of a prediction for a time horizon before the occurence of an 119 | # event (either censored or uncensored) is weighted by the inverse 120 | # probability of censoring at that time horizon. 121 | # 122 | # - error of a prediction for a time horizon after the any observed event 123 | # is weighted by inverse censoring probability at the actual time 124 | # of the observed event. 125 | # 126 | # - "error" of a prediction for a time horizon after a censored event has 127 | # 0 weight and do not contribute to the Brier score computation. 128 | 129 | # Estimate the probability of censoring at current time point t. 
130 | censoring_prob_t = self.censoring_dist.predict_proba(times) 131 | censoring_prob_t = np.clip(censoring_prob_t, self.min_censoring_prob, 1) 132 | before = times < y["duration"] 133 | weights = np.where(before, 1 / censoring_prob_t, 0) 134 | 135 | after_any_observed_event = (y["event"] > 0) & (times >= y["duration"]) 136 | if censoring_prob_y is None: 137 | censoring_prob_y = self.censoring_dist.predict_proba(y["duration"]) 138 | censoring_prob_y = np.clip(censoring_prob_y, self.min_censoring_prob, 1) 139 | weights = np.where(after_any_observed_event, 1 / censoring_prob_y, weights) 140 | 141 | return y_binary, weights 142 | 143 | def brier_score(self, y_true, y_pred, times): 144 | 145 | if self.event_of_interest == "any": 146 | if y_true is self.y_train: 147 | y_true = self.y_train_any_event 148 | else: 149 | y_true = self._any_event(y_true) 150 | 151 | n_samples = y_true.shape[0] 152 | n_time_steps = times.shape[0] 153 | brier_scores = np.empty( 154 | shape=(n_samples, n_time_steps), 155 | dtype=np.float64, 156 | ) 157 | for t_idx, t in enumerate(times): 158 | y_true_binary, weights = self._ibs_components( 159 | y_true, np.full(shape=n_samples, fill_value=t) 160 | ) 161 | squared_error = (y_true_binary - y_pred[:, t_idx]) ** 2 162 | brier_scores[:, t_idx] = weights * squared_error 163 | 164 | return brier_scores.mean(axis=0) 165 | 166 | def draw(self): 167 | # Sample time horizons uniformly on the observed time range: 168 | min_times = self.y_train["duration"].min() 169 | max_times = self.y_train["duration"].max() 170 | times = self.rng.uniform(min_times, max_times, self.y_train.shape[0]) 171 | 172 | if self.event_of_interest == "any": 173 | # Collapse all event types together. 174 | y = self.y_train_any_event 175 | else: 176 | y = self.y_train 177 | 178 | y_binary, sample_weights = self._ibs_components( 179 | y, 180 | times, 181 | censoring_prob_y=self.censoring_prob_y_train, 182 | ) 183 | return times.reshape(-1, 1), y_binary, sample_weights 184 | 185 | 186 | class GradientBoostedCIF(BaseEstimator, SurvivalMixin, ClassifierMixin): 187 | """GBDT estimator for cause-specific Cumulative Incidence Function (CIF). 188 | 189 | This internally relies on the histogram-based gradient boosting classifier 190 | or regressor implementation of scikit-learn. 191 | 192 | Estimate a cause-specific CIF by minimizing the Brier Score for the kth 193 | cause of failure from _[1] for randomly sampled reference time horizons 194 | concatenated as extra inputs to the underlying HGB binary classification 195 | model. 196 | 197 | One can obtain the survival probabilities for any event by summing all 198 | cause-specific CIF curves and computing 1 - "sum of CIF curves". 199 | 200 | Parameters 201 | ---------- 202 | event_of_interest : int or "any" default="any" 203 | The event to compute the CIF for. When passed as an integer, it should 204 | match one of the values observed in `y_train["event"]`. Note: 0 always 205 | represents censoring and cannot be used as a valid event of interest. 206 | 207 | "any" means that all events are collapsed together and the resulting 208 | model can be used for any event survival analysis: the any 209 | event survival function can be estimated as the complement of the 210 | any event cumulative incidence function. 211 | 212 | objective : {'ibs', 'inll'}, default='ibs' 213 | The objective of the model. In practise, both objective yields 214 | comparable results. 215 | 216 | - 'ibs' : integrated brier score. 
Use a `HistGradientBoostedRegressor` 217 | with the 'squared_error' loss. As we have no guarantee that the regression 218 | yields a survival function belonging to [0, 1], we clip the probabilities 219 | to this range. 220 | - 'inll' : integrated negative log likelihood. Use a 221 | `HistGradientBoostedClassifier` with 'log_loss' loss. 222 | 223 | time_horizon : float or int, default=None 224 | A specific time horizon `t_horizon` to treat the model as a 225 | probabilistic classifier to estimate `E[T_k < t_horizon|X]` where `T_k` 226 | is a random variable representing the (uncensored) event for the type 227 | of interest. 228 | 229 | When specified, the `predict_proba` method returns an estimate of 230 | `E[T_k < t_horizon|X]` for each provided realisation of `X`. 231 | 232 | TODO: complete the docstring. 233 | 234 | References 235 | ---------- 236 | 237 | [1] M. Kretowska, "Tree-based models for survival data with competing risks", 238 | Computer Methods and Programs in Biomedicine 159 (2018) 185-198. 239 | """ 240 | 241 | name = "GradientBoostedCIF" 242 | 243 | def __init__( 244 | self, 245 | event_of_interest="any", 246 | objective="ibs", 247 | n_iter=10, 248 | n_repetitions_per_iter=5, 249 | learning_rate=0.1, 250 | max_depth=None, 251 | max_leaf_nodes=31, 252 | min_samples_leaf=50, 253 | show_progressbar=True, 254 | n_time_grid_steps=1000, 255 | time_horizon=None, 256 | random_state=None, 257 | ): 258 | self.event_of_interest = event_of_interest 259 | self.objective = objective 260 | self.n_iter = n_iter 261 | self.n_repetitions_per_iter = n_repetitions_per_iter # TODO? data augmenting for early iterations 262 | self.learning_rate = learning_rate 263 | self.max_depth = max_depth 264 | self.max_leaf_nodes = max_leaf_nodes 265 | self.min_samples_leaf = min_samples_leaf 266 | self.show_progressbar = show_progressbar 267 | self.n_time_grid_steps = n_time_grid_steps 268 | self.time_horizon = time_horizon 269 | self.random_state = random_state 270 | 271 | def _build_base_estimator(self, monotonic_cst): 272 | 273 | if self.objective == "ibs": 274 | return HistGradientBoostingRegressor( 275 | loss="squared_error", 276 | max_iter=1, 277 | warm_start=True, 278 | monotonic_cst=monotonic_cst, 279 | learning_rate=self.learning_rate, 280 | max_depth=self.max_depth, 281 | max_leaf_nodes=self.max_leaf_nodes, 282 | min_samples_leaf=self.min_samples_leaf, 283 | ) 284 | elif self.objective == "inll": 285 | return HistGradientBoostingClassifier( 286 | loss="log_loss", 287 | max_iter=1, 288 | warm_start=True, 289 | monotonic_cst=monotonic_cst, 290 | learning_rate=self.learning_rate, 291 | max_leaf_nodes=self.max_leaf_nodes, 292 | max_depth=self.max_depth, 293 | min_samples_leaf=self.min_samples_leaf, 294 | ) 295 | else: 296 | raise ValueError( 297 | "Parameter 'objective' must be either 'ibs' or 'inll', " 298 | f"got {self.objective}." 299 | ) 300 | 301 | def fit(self, X, y, times=None, validation_data=None): 302 | 303 | # TODO: add check_X_y from sksurv 304 | self.event_ids_ = np.unique(y["event"]) 305 | 306 | 307 | # The time horizon is concatenated as an additional input feature 308 | # before the features of X and we constrain the prediction function 309 | # (that estimates the CIF) to monotically increase with the time 310 | # horizon feature. 311 | monotonic_cst = np.zeros(X.shape[1] + 1) 312 | monotonic_cst[0] = 1 313 | 314 | self.estimator_ = self._build_base_estimator(monotonic_cst) 315 | 316 | # Compute the time grid used at prediction time. 
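        # If no explicit grid is provided, use (at most) n_time_grid_steps
        # quantiles of the observed (uncensored) event durations, so that the
        # grid is denser where events are frequent and sparser in the tails.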
317 | any_event_mask = y["event"] > 0 318 | observed_times = y["duration"][any_event_mask] 319 | 320 | if times is None: 321 | if observed_times.shape[0] > self.n_time_grid_steps: 322 | self.time_grid_ = np.quantile( 323 | observed_times, 324 | np.linspace(0, 1, num=self.n_time_grid_steps) 325 | ) 326 | else: 327 | self.time_grid_ = observed_times.copy() 328 | self.time_grid_.sort() 329 | else: 330 | self.time_grid_ = times 331 | 332 | ibs_training_sampler = IBSTrainingSampler( 333 | y, 334 | event_of_interest=self.event_of_interest, 335 | random_state=self.random_state, 336 | ) 337 | 338 | iterator = range(self.n_iter) 339 | if self.show_progressbar: 340 | iterator = tqdm(iterator) 341 | 342 | for idx_iter in iterator: 343 | ( 344 | sampled_times, 345 | y_binary, 346 | sample_weight, 347 | ) = ibs_training_sampler.draw() 348 | Xt = np.hstack([sampled_times, X]) 349 | self.estimator_.max_iter += 1 350 | self.estimator_.fit(Xt, y_binary, sample_weight=sample_weight) 351 | 352 | # XXX: implement verbose logging with a version of IBS that 353 | # can handle competing risks. 354 | 355 | # To be use at a fixed horizon classifier when setting time_horizon. 356 | if self.event_of_interest == "any": 357 | self.classes_ = np.array(["no_event", "any_event"]) 358 | else: 359 | self.classes_ = np.array( 360 | ["other_or_no_event", f"event_{self.event_of_interest}"] 361 | ) 362 | return self 363 | 364 | def predict_proba(self, X, time_horizon=None): 365 | """Estimate the probability of incidence for a specific time horizon. 366 | 367 | See the docstring for the `time_horizon` parameter for more details. 368 | 369 | Returns a 2d array with shape (X.shape[0], 2). The second column holds 370 | the cumulative incidence probability and the first column its 371 | complement. 372 | 373 | When `event_of_interest == "any"` the second column therefore holds the 374 | sum all individual events cumulative incidece and the first column 375 | holds the probability of remaining event free at `time_horizon`, that 376 | is, the survival probability. 377 | 378 | When `event_of_interest != "any"`, the values in the first column do 379 | not have an intuitive meaning. 380 | """ 381 | if time_horizon is None: 382 | if self.time_horizon is None: 383 | raise ValueError( 384 | "The time_horizon parameter is required to use " 385 | f"{self.__class__.__name__} as a classifier. " 386 | "This parameter can either be passed as constructor " 387 | "or method parameter." 388 | ) 389 | else: 390 | time_horizon = self.time_horizon 391 | 392 | times = np.asarray([time_horizon]) 393 | cif = self.predict_cumulative_incidence(X, times=times) 394 | 395 | # Reshape to be consistent with the expected shape returned by 396 | # the predict_proba method of scikit-learn binary classifiers. 
397 | cif = cif.reshape(-1, 1) 398 | return np.hstack([1 - cif, cif]) 399 | 400 | def predict_cumulative_incidence(self, X, times=None): 401 | all_y_cif = [] 402 | 403 | if times is None: 404 | times = self.time_grid_ 405 | 406 | if self.show_progressbar: 407 | times = tqdm(times) 408 | 409 | for t in times: 410 | t = np.full((X.shape[0], 1), t) 411 | X_with_t = np.hstack([t, X]) 412 | if self.objective == "ibs": 413 | y_cif = self.estimator_.predict(X_with_t) 414 | else: 415 | y_cif = self.estimator_.predict_proba(X_with_t)[:, 1] 416 | all_y_cif.append(y_cif) 417 | 418 | cif = np.vstack(all_y_cif).T 419 | 420 | if self.objective == "ibs": 421 | cif = np.clip(cif, 0, 1) 422 | 423 | return cif 424 | 425 | def predict_survival_function(self, X, times=None): 426 | """Compute the event specific survival function. 427 | 428 | Warning: this metric only makes sense when y_train["event"] is binary 429 | (single event) or when setting event_of_interest='any'. 430 | """ 431 | if ( 432 | (self.event_ids_ > 0).sum() > 1 433 | and self.event_of_interest != "any" 434 | ): 435 | warnings.warn( 436 | f"Values returned by predict_survival_function only make " 437 | f"sense when the model is trained with a binary event " 438 | f"indicator or when setting event_of_interest='any'. " 439 | f"Instead this model was fit on data with event ids " 440 | f"{self.event_ids_.tolist()} and with " 441 | f"event_of_interest={self.event_of_interest}." 442 | ) 443 | return 1 - self.predict_cumulative_incidence(X, times=times) 444 | 445 | def predict_quantile(self, X, quantile=0.5, times=None): 446 | """Estimate the conditional median (or other quantile) time to event 447 | 448 | Note: this can return np.inf values when the estimated CIF does not 449 | reach the `quantile` value at the maximum time horizon observed on 450 | the training set. 451 | """ 452 | if times is None: 453 | times = self.time_grid_ 454 | cif_curves = self.predict_cumulative_incidence(X, times=times) 455 | quantile_idx = np.apply_along_axis( 456 | lambda a: a.searchsorted(quantile, side='right'), 1, cif_curves 457 | ) 458 | inf_mask = quantile_idx == cif_curves.shape[1] 459 | # Change quantile_idx to avoid out-of-bound index in the subsequent 460 | # line. 
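        # The clipped index is only a placeholder: the corresponding entries
        # are overwritten with np.inf just below.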
461 | quantile_idx[inf_mask] = cif_curves.shape[1] - 1 462 | results = times[quantile_idx] 463 | # Mark out-of-index results as np.inf 464 | results[inf_mask] = np.inf 465 | return results 466 | -------------------------------------------------------------------------------- /models/kaplan_meier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sksurv.nonparametric import kaplan_meier_estimator 4 | from sksurv.functions import StepFunction 5 | 6 | from models.survival_mixin import SurvivalMixin 7 | 8 | 9 | class KaplanMeier(SurvivalMixin): 10 | 11 | name = "KaplanMeier" 12 | 13 | def fit(self, X, y, times=None): 14 | self.km_x_, self.km_y_ = kaplan_meier_estimator(y["event"], y["duration"]) 15 | return self 16 | 17 | def predict_survival_function(self, X, times): 18 | surv_probs = StepFunction(self.km_x_, self.km_y_)(times) 19 | return np.vstack([surv_probs] * X.shape[0]) 20 | -------------------------------------------------------------------------------- /models/kaplan_neighbors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.base import BaseEstimator 4 | from sklearn.neighbors import NearestNeighbors 5 | from sklearn.utils.validation import check_is_fitted 6 | 7 | from sksurv.tree.tree import _array_to_step_function 8 | 9 | from xgbse.non_parametric import calculate_kaplan_vectorized 10 | 11 | from .survival_mixin import SurvivalMixin 12 | 13 | 14 | class KaplanNeighbors(BaseEstimator, SurvivalMixin): 15 | 16 | def __init__(self, neighbors_params=None): 17 | self.neighbors_params = neighbors_params 18 | 19 | def fit(self, X, y=None, times=None): 20 | self.nearest_neighbors_ = NearestNeighbors(**self.neighbors_params).fit(X) 21 | self.y_train_ = y 22 | self.times_ = times 23 | return self 24 | 25 | def predict_survival_function(self, X, times=None, return_array=True): 26 | check_is_fitted(self, "nearest_neighbors_") 27 | X_idx = self.nearest_neighbors_.kneighbors(X, return_distance=False) 28 | y_preds = self.y_train_[X_idx] 29 | survival_probs, _, _ = calculate_kaplan_vectorized( 30 | E=y_preds["event"], 31 | T=y_preds["duration"], 32 | time_bins=self.times_, 33 | ) 34 | survival_probs = survival_probs.values 35 | if return_array: 36 | return survival_probs 37 | return _array_to_step_function(self.times_, survival_probs) 38 | -------------------------------------------------------------------------------- /models/kaplan_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from sklearn.base import BaseEstimator 4 | from sklearn.utils.multiclass import type_of_target 5 | from sklearn.utils.validation import check_is_fitted 6 | 7 | from sksurv.tree.tree import _array_to_step_function 8 | from sksurv.nonparametric import kaplan_meier_estimator 9 | from sksurv.functions import StepFunction 10 | 11 | from .survival_mixin import SurvivalMixin 12 | 13 | 14 | def _map_leaves_to_km(leaves, y_train, time_bins): 15 | unique_leaves = np.unique(leaves) 16 | leaf_to_km = dict() 17 | 18 | for leaf in unique_leaves: 19 | 20 | mask_leaf = np.where(leaves == leaf) 21 | _y_train = y_train[mask_leaf] 22 | times, km = kaplan_meier_estimator(_y_train["event"], _y_train["duration"]) 23 | 24 | # Build the grid by using the step function, remove user required time_bins 25 | # that are out of train times interval. 
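        # Out-of-range time_bins are not evaluated by the step function; they
        # are filled back in below by constant extrapolation with the first and
        # last Kaplan-Meier values of this leaf.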
26 | min_time, max_time = times[0], times[-1] 27 | mask_min_time = time_bins < min_time 28 | mask_max_time = time_bins > max_time 29 | mask_time = ~(mask_min_time | mask_max_time) 30 | km = StepFunction(times, km)(time_bins[mask_time]) 31 | 32 | # Fill the km array with its own min and max 33 | # for the time_bins that we previously excluded. 34 | min_km, max_km = km[0], km[-1] 35 | n_min_val = sum(mask_min_time) 36 | n_max_val = sum(mask_max_time) 37 | left_km = np.full(shape=n_min_val, fill_value=min_km) 38 | right_km = np.full(shape=n_max_val, fill_value=max_km) 39 | full_km = np.hstack([left_km, km, right_km]) 40 | leaf_to_km[leaf] = full_km 41 | 42 | return leaf_to_km 43 | 44 | 45 | class KaplanTree(BaseEstimator, SurvivalMixin): 46 | 47 | def fit(self, X, y=None, times=None): 48 | leaves = self._get_leaves(X) 49 | self.leaf_to_km_ = _map_leaves_to_km(leaves, y, times) 50 | self.times_ = times 51 | return self 52 | 53 | def predict_survival_function(self, X, times=None, return_array=True): 54 | check_is_fitted(self, "leaf_to_km_") 55 | x_leaves = self._get_leaves(X) 56 | survival_probs = np.vstack([self.leaf_to_km_[leaf] for leaf in x_leaves]) 57 | if return_array: 58 | return survival_probs 59 | return _array_to_step_function(self.times_, survival_probs) 60 | 61 | def _get_leaves(self, X): 62 | target_type = type_of_target(X) 63 | if target_type == "multilabel-indicator": 64 | X_leaves = np.argmax(X, axis=1) 65 | elif target_type == "multiclass": 66 | X_leaves = X 67 | else: 68 | raise ValueError(f"X must be a categorical label, got: {target_type}") 69 | x_leaves = self._ensure_leaves_1d(X_leaves) 70 | return x_leaves 71 | 72 | def _ensure_leaves_1d(self, X_leaves): 73 | if X_leaves.ndim == 2: 74 | x_leaves = np.asarray(X_leaves).ravel() # scipy sparse yields a matrix, we need a ndarray 75 | else: 76 | x_leaves = X_leaves 77 | return x_leaves 78 | -------------------------------------------------------------------------------- /models/meta_grid_bc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from joblib import Parallel, delayed 3 | 4 | from sklearn.base import BaseEstimator, clone 5 | from sklearn.utils.validation import check_is_fitted 6 | from sklearn.ensemble import HistGradientBoostingClassifier 7 | from sklearn.linear_model import LogisticRegression 8 | 9 | from sksurv.tree.tree import _array_to_step_function 10 | 11 | from xgbse._debiased_bce import _build_multi_task_targets 12 | from xgbse._base import DummyLogisticRegression 13 | 14 | from .survival_mixin import SurvivalMixin 15 | 16 | 17 | class MetaGridBC(BaseEstimator, SurvivalMixin): 18 | def __init__(self, classifier=None, n_jobs=None, verbose=0, name="MetaGridBC"): 19 | self.classifier = classifier or LogisticRegression() 20 | self.n_jobs = n_jobs 21 | self.verbose = verbose 22 | self.name = name 23 | 24 | def fit(self, X, y, times=None): 25 | e_name, t_name = y.dtype.names 26 | targets_train, _ = _build_multi_task_targets( 27 | E=y[e_name], 28 | T=y[t_name], 29 | time_bins=times, 30 | ) 31 | with Parallel(n_jobs=self.n_jobs, verbose=self.verbose) as parallel: 32 | estimators = parallel( 33 | delayed(self._fit_one_lr)(X, targets_train[:, i]) 34 | for i in range(targets_train.shape[1]) 35 | ) 36 | self.times_ = times 37 | self.estimators_ = estimators 38 | return self 39 | 40 | def _fit_one_lr(self, X, target): 41 | mask = target != -1 42 | 43 | if len(target[mask]) == 0: 44 | # If there's no observation in a time bucket we raise an error 45 | raise 
ValueError("Error: No observations in a time bucket") 46 | elif len(np.unique(target[mask])) == 1: 47 | # If there's only one class in a time bucket 48 | # we create a dummy classifier that predicts that class and send a warning 49 | classifier = DummyLogisticRegression() 50 | else: 51 | classifier = clone(self.classifier) 52 | classifier.fit(X[mask, :], target[mask]) 53 | return classifier 54 | 55 | def predict_survival_function(self, X, times=None, return_array=True): 56 | check_is_fitted(self, "estimators_") 57 | with Parallel(n_jobs=self.n_jobs) as parallel: 58 | y_preds = parallel( 59 | delayed(self._predict_one_lr)(X, estimator) 60 | for estimator in self.estimators_ 61 | ) 62 | y_preds = np.asarray(y_preds) 63 | survival_probs = np.cumprod(1 - y_preds, axis=0).T 64 | if return_array: 65 | return survival_probs 66 | return _array_to_step_function(self.times_, survival_probs) 67 | 68 | def _predict_one_lr(self, X, estimator): 69 | y_pred = estimator.predict_proba(X) 70 | # return probability of "positive" event 71 | return y_pred[:, 1] 72 | -------------------------------------------------------------------------------- /models/survival_mixin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class SurvivalMixin: 5 | 6 | def predict_cumulative_hazard_function(self, X_test, times): 7 | survival_probs = self.predict_survival_function(X_test, times) 8 | cumulative_hazards = -np.log(survival_probs + 1e-8) 9 | return cumulative_hazards 10 | 11 | def predict_risk_estimate(self, X_test, times): 12 | cumulative_hazards = self.predict_cumulative_hazard_function(X_test, times) 13 | return cumulative_hazards.sum(axis=1) 14 | -------------------------------------------------------------------------------- /models/tree_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.testing import assert_array_equal 3 | from pandas.core.frame import DataFrame 4 | 5 | from sklearn.base import TransformerMixin, BaseEstimator, clone 6 | from sklearn.ensemble import RandomForestRegressor 7 | from sklearn.preprocessing import OneHotEncoder 8 | from sklearn.utils.validation import check_array, check_is_fitted, check_random_state 9 | 10 | 11 | # inspired from sklearn _gradient_boosting.pyx 12 | def _random_sample_mask( 13 | n_sample, 14 | n_in_bag, 15 | random_state, 16 | ): 17 | sample_mask = np.hstack([ 18 | np.zeros(n_sample - n_in_bag, dtype=bool), 19 | np.ones(n_in_bag, dtype=bool), 20 | ]) 21 | random_state.shuffle(sample_mask) 22 | return sample_mask 23 | 24 | 25 | class TreeTransformer(BaseEstimator, TransformerMixin): 26 | def __init__( 27 | self, 28 | base_estimator=None, 29 | handle_survival_target=False, 30 | subsample=1.0, 31 | random_state=None, 32 | ): 33 | self.base_estimator = base_estimator 34 | self.handle_survival_target = handle_survival_target 35 | self.subsample = subsample 36 | self.random_state = random_state 37 | 38 | def fit_transform(self, X, y=None): 39 | 40 | if isinstance(X, DataFrame): 41 | raise TypeError("DataFrame not supported, convert it to a numpy ndarray") 42 | 43 | self._rng = check_random_state(self.random_state) 44 | 45 | if self.subsample < 1.0: 46 | n_sample = X.shape[0] 47 | n_in_bag = max(1, int(self.subsample * n_sample)) 48 | sample_mask = _random_sample_mask(n_sample, n_in_bag, self._rng) 49 | else: 50 | sample_mask = np.ones(X.shape[0], dtype=bool) 51 | 52 | _y = y[sample_mask] 53 | 54 | y_event = _y[_y.dtype.names[0]] 
55 | y_duration = _y[_y.dtype.names[1]] 56 | assert_array_equal(np.unique(y_event), [0, 1]) 57 | y_duration = check_array(y_duration, dtype="numeric", ensure_2d=False) 58 | 59 | if self.base_estimator is None: 60 | if self.handle_survival_target: 61 | raise TypeError( 62 | "Specify a base estimator that accepts a survival " 63 | f"structure for y: {y.dtype}" 64 | ) 65 | else: 66 | base_estimator = RandomForestRegressor() 67 | else: 68 | base_estimator = clone(self.base_estimator) 69 | 70 | if self.handle_survival_target: 71 | X_leaves = base_estimator.fit(X[sample_mask, :], _y).apply(X) 72 | else: 73 | # The base estimator will be fitted with a censoring bias 74 | X_leaves = base_estimator.fit(X[sample_mask, :], y_duration).apply(X) 75 | 76 | X_leaves = self._ensure_leaves_2d(X_leaves) 77 | self.encoder_ = OneHotEncoder(sparse=True, handle_unknown="ignore") 78 | X_ohe = self.encoder_.fit_transform(X_leaves) 79 | self.base_estimator_ = base_estimator 80 | return X_ohe 81 | 82 | def fit(self, X, y=None): 83 | self.fit_transform(X, y) 84 | return self 85 | 86 | def transform(self, X): 87 | check_is_fitted(self, "base_estimator_") 88 | X_leaves = self.base_estimator_.apply(X) 89 | X_leaves = self._ensure_leaves_2d(X_leaves) 90 | return self.encoder_.transform(X_leaves) 91 | 92 | def _ensure_leaves_2d(self, X_leaves): 93 | if X_leaves.ndim == 1: 94 | X_leaves = X_leaves.reshape(-1, 1) 95 | return X_leaves -------------------------------------------------------------------------------- /models/yasgbt.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is an archive of the previous GradientBoostedCIF model, saved for the sake 3 | of reproducibility of the benchmarks `kkbox_cv_benchmark` and `kkbox_cv_yasgbt`. 4 | """ 5 | from abc import ABC 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from sklearn.base import BaseEstimator 10 | from sklearn.utils.validation import check_random_state 11 | from sklearn.ensemble import HistGradientBoostingRegressor, HistGradientBoostingClassifier 12 | 13 | from sksurv.functions import StepFunction 14 | from sksurv.metrics import integrated_brier_score 15 | from sksurv.nonparametric import kaplan_meier_estimator, CensoringDistributionEstimator 16 | 17 | from models.survival_mixin import SurvivalMixin 18 | from model_selection.cross_validation import get_time_grid 19 | 20 | 21 | class IBSTrainingSampler: 22 | 23 | def __init__(self, y, sampling_strategy, random_state): 24 | self.y = y 25 | self.sampling_strategy = sampling_strategy 26 | self.random_state = random_state 27 | times, km = kaplan_meier_estimator(y["event"], y["duration"]) 28 | # inverse_km has x: [0.01, ..., .99] and y: [0., ..., 820.] 
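        # Reversing the Kaplan-Meier curve lets us map uniformly drawn survival
        # probabilities back to time horizons (inverse transform sampling), so
        # that horizons drawn with the "inverse_km" strategy approximately
        # follow the Kaplan-Meier estimate of the event time distribution.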
29 | self.inverse_km = StepFunction(x=km[::-1], y=times[::-1]) 30 | self.max_km, self.min_km = km[0], km[-1] 31 | 32 | def make_sample(self): 33 | y = self.y 34 | rng = check_random_state(self.random_state) 35 | 36 | if self.sampling_strategy == "inverse_km": 37 | surv_probs = rng.uniform(self.min_km, self.max_km, y.shape[0]) 38 | times = self.inverse_km(surv_probs) 39 | elif self.sampling_strategy == "uniform": 40 | min_times, max_times = y["duration"].min(), y["duration"].max() 41 | times = rng.uniform(min_times, max_times, y.shape[0]) 42 | else: 43 | raise ValueError(f"Sampling strategy must be 'inverse_km' or 'uniform', got {self.sampling_strategy}") 44 | 45 | mask_y_0 = y["event"] & (y["duration"] <= times) 46 | mask_y_1 = y["duration"] > times 47 | 48 | yc = np.zeros(y.shape[0]) 49 | yc[mask_y_1] = 1 50 | 51 | cens = CensoringDistributionEstimator().fit(y) 52 | # calculate inverse probability of censoring weight at current time point t. 53 | prob_cens_t = cens.predict_proba(times) 54 | prob_cens_t = np.clip(prob_cens_t, 1e-8, 1) 55 | # calculate inverse probability of censoring weights at observed time point 56 | prob_cens_y = cens.predict_proba(y["duration"]) 57 | prob_cens_y = np.clip(prob_cens_y, 1e-8, 1) 58 | 59 | sample_weights = np.where(mask_y_0, 1/prob_cens_y, 0) 60 | sample_weights = np.where(mask_y_1, 1/prob_cens_t, sample_weights) 61 | 62 | return times.reshape(-1, 1), yc, sample_weights 63 | 64 | 65 | class BaseYASGBT(BaseEstimator, SurvivalMixin, ABC): 66 | 67 | def __init__( 68 | self, 69 | sampling_strategy="uniform", 70 | n_iter=10, 71 | n_repetitions_per_iter=5, 72 | learning_rate=0.1, 73 | max_depth=7, 74 | min_samples_leaf=50, 75 | verbose=False, 76 | show_progressbar=True, 77 | random_state=None, 78 | ): 79 | self.sampling_strategy = sampling_strategy 80 | self.n_iter = n_iter 81 | self.n_repetitions_per_iter = n_repetitions_per_iter # TODO? 
data augmenting for early iterations 82 | self.learning_rate = learning_rate 83 | self.max_depth = max_depth 84 | self.min_samples_leaf = min_samples_leaf 85 | self.verbose = verbose 86 | self.show_progressbar = show_progressbar 87 | self.random_state = random_state 88 | 89 | def fit(self, X, y, times=None, validation_data=None): 90 | 91 | # TODO: add check_X_y from sksurv 92 | monotonic_cst = np.zeros(X.shape[1]+1) 93 | monotonic_cst[0] = -1 94 | 95 | self.hgbt_ = self._get_model(monotonic_cst) 96 | 97 | data_sampler = IBSTrainingSampler( 98 | y, 99 | sampling_strategy=self.sampling_strategy, 100 | random_state=self.random_state, 101 | ) 102 | iterator = range(self.n_iter) 103 | if self.show_progressbar: 104 | iterator = tqdm(iterator) 105 | 106 | for idx_iter in iterator: 107 | times, yc, sample_weight = data_sampler.make_sample() 108 | Xt = np.hstack([times, X]) 109 | self.hgbt_.max_iter += 1 110 | self.hgbt_.fit(Xt, yc, sample_weight=sample_weight) 111 | 112 | if self.verbose: 113 | train_ibs = self.compute_ibs(y, X_val=X) 114 | msg_ibs = f"round {idx_iter+1:03d} -- train ibs: {train_ibs:.6f}" 115 | 116 | if validation_data is not None: 117 | X_val, y_val = validation_data 118 | val_ibs = self.compute_ibs(y, X_val, y_val) 119 | msg_ibs += f" -- val ibs: {val_ibs:.6f}" 120 | 121 | print(msg_ibs) 122 | 123 | return self 124 | 125 | def compute_ibs(self, y_train, X_val, y_val=None): 126 | if y_val is None: 127 | y_val = y_train 128 | times_val = get_time_grid(y_train, y_val) 129 | survival_probs = self.predict_survival_function(X_val, times_val) 130 | 131 | return integrated_brier_score(y_train, y_val, survival_probs, times_val) 132 | 133 | 134 | class YASGBTClassifier(BaseYASGBT): 135 | 136 | name = "YASGBTClassifier" 137 | 138 | def _get_model(self, monotonic_cst): 139 | 140 | return HistGradientBoostingClassifier( 141 | loss="log_loss", 142 | max_iter=1, 143 | warm_start=True, 144 | monotonic_cst=monotonic_cst, 145 | learning_rate=self.learning_rate, 146 | max_depth=self.max_depth, 147 | min_samples_leaf=self.min_samples_leaf, 148 | ) 149 | 150 | def predict_survival_function(self, X, times): 151 | 152 | all_y_probs = [] 153 | 154 | iterator = times 155 | if self.show_progressbar: 156 | iterator = tqdm(iterator) 157 | 158 | for t in iterator: 159 | t = np.full((X.shape[0], 1), t) 160 | Xt = np.hstack([t, X]) 161 | y_probs = self.hgbt_.predict_proba(Xt)[:, 1] 162 | all_y_probs.append(y_probs) 163 | 164 | surv_probs = np.vstack(all_y_probs).T 165 | 166 | return surv_probs 167 | 168 | 169 | class YASGBTRegressor(BaseYASGBT): 170 | 171 | name = "YASGBTRegressor" 172 | 173 | def _get_model(self, monotonic_cst): 174 | 175 | return HistGradientBoostingRegressor( 176 | loss="squared_error", 177 | max_iter=1, 178 | warm_start=True, 179 | monotonic_cst=monotonic_cst, 180 | learning_rate=self.learning_rate, 181 | max_depth=self.max_depth, 182 | min_samples_leaf=self.min_samples_leaf, 183 | ) 184 | 185 | def predict_survival_function(self, X, times): 186 | 187 | all_y_probs = [] 188 | 189 | iterator = times 190 | if self.show_progressbar: 191 | iterator = tqdm(iterator) 192 | 193 | for t in iterator: 194 | t = np.full((X.shape[0], 1), t) 195 | Xt = np.hstack([t, X]) 196 | y_probs = self.hgbt_.predict(Xt) 197 | all_y_probs.append(y_probs) 198 | 199 | surv_probs = np.vstack(all_y_probs).T 200 | 201 | return np.clip(surv_probs, 0, 1) 202 | -------------------------------------------------------------------------------- /notebooks/censoring.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/soda-inria/survival-analysis-benchmark/2633aa5e4e6433c73bb87e99a9f3aad1033726e7/notebooks/censoring.png -------------------------------------------------------------------------------- /notebooks/truck_dataset.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py 5 | # text_representation: 6 | # extension: .py 7 | # format_name: light 8 | # format_version: '1.5' 9 | # jupytext_version: 1.14.5 10 | # kernelspec: 11 | # display_name: Python 3 (ipykernel) 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # # The Truck Dataset 17 | 18 | # The purpose of this notebook is to generate a synthetic dataset of failure times with and without uncensoring to study survival analysis and competing risk analysis methods with access to the ground truth (e.g. uncensored failure times and true hazard functions, conditioned on some observable features). 19 | # 20 | # We chose to simulate a predictive maintenance problem, namely the failures during the operation of a fleet of trucks. 21 | 22 | # ### Type of failures 23 | # 24 | # Survival analysis can be used for predictive maintenance in industrial settings. In this work, we will create a synthetic dataset of trucks and drivers with their associated simulated failures, in a competitive events setting. Our truck failures can be of three types: 25 | # 26 | # **1. Initial assembly failures $e_1$** 27 | # 28 | # This failure might occure during the first weeks after the operation of a newly commissioned truck. As these hazards stem from **manufacturing defects** such as incorrect wiring or components assembly, they are dependent on the quality of assembly of each truck, along with its usage rate. 29 | # 30 | # **2. Operation failure $e_2$** 31 | # 32 | # Operation failures can occur on a day to day basis because of some critical mistakes made by the driver —e.g. car accident, wrong gas fill-up. The probability of making mistakes is linked to the ease of use (UX) of the truck, the expertise of the driver and the usage rate of the truck. 33 | # 34 | # **3. Fatigue failure $e_3$** 35 | # 36 | # Fatigue failure relate the wear of the material and components of each truck through time. This type of hazard is linked to the quality of the material of the truck and also its usage rate. I could also be linked to the ability of the driver to operate it with minimal wear and tear (e.g. reduced or anticipated use of breaks, use of gears and smooth accelerations and decelarations). 37 | 38 | # ### Observed and hidden variables 39 | # We make the simplistic assumptions that the variables of interest are constant through time. To create non-linearities and make the dataset more challenging, we consider that the observer don't have access to the three truck characteristics: assembly quality, UX and material quality. 40 | # 41 | # Instead, the observer has only access to the **brand** of the truck and its **model**. They also know the **usage rate** because it is linked to the driver planning, and they have access to the **training level** of each drivers. 42 | 43 | # 44 | # 45 | # 46 | # **TODO** update this diagram to add driver skill and usage rate. 
47 | 48 | # So, in summary: 49 | 50 | # |failure id |failure name |associated features | 51 | # |-----------|-------------|----------------------------| 52 | # |$e_1$ |assembly |assembly quality, usage rate| 53 | # |$e_2$ |operation |UX, operator training, driver skill, usage rate| 54 | # |$e_3$ |fatigue |material quality, usage rate| 55 | 56 | # ## Drivers and truck properties 57 | 58 | import pandas as pd 59 | import numpy as np 60 | from sklearn.utils import check_random_state 61 | from matplotlib import pyplot as plt 62 | import seaborn as sns 63 | sns.set_style("darkgrid") 64 | 65 | 66 | # We consider 10,000 pairs (driver, truck) with constant features. The period span on 10 years. 67 | 68 | n_datapoints = 10_000 69 | total_years = 10 70 | total_days = total_years * 365 71 | 72 | # ### Sampling driver / truck pairs 73 | # 74 | # Let's assume that drivers have different experience and training. We summarize this information in a "skill" level with values in the `[0.2-1.0]` range. We make the simplifying assumpting that the skill of the drivers do not evolve during the duration of the experiment 75 | # 76 | # Furthermore each driver has a given truck model and we assume that drivers do not change truck model over that period. 77 | # 78 | # We further assume that each (driver, truck) pair has a specific usage rate that stays constant over time (for the sake of simplicity). Let's assume that usage rates are distributed as a mixture of Gaussians. 79 | 80 | # + 81 | from scipy.stats import norm 82 | 83 | def sample_usage_weights(n_datapoints, rng): 84 | rates_1 = norm.rvs(.5, .08, size=n_datapoints, random_state=rng) 85 | rates_2 = norm.rvs(.8, .05, size=n_datapoints, random_state=rng) 86 | usage_mixture_idxs = rng.choice(2, size=n_datapoints, p=[1/3, 2/3]) 87 | return np.where(usage_mixture_idxs, rates_1, rates_2).clip(0, 1) 88 | 89 | 90 | # + 91 | truck_model_names = ["RA", "C1", "C2", "RB", "C3"] 92 | 93 | def sample_driver_truck_pairs(n_datapoints, random_seed=None): 94 | rng = np.random.RandomState(random_seed) 95 | df = pd.DataFrame( 96 | { 97 | "driver_skill": rng.uniform(low=0.2, high=1.0, size=n_datapoints).round(decimals=1), 98 | "truck_model": rng.choice(truck_model_names, size=n_datapoints), 99 | "usage_rate": sample_usage_weights(n_datapoints, rng).round(decimals=2), 100 | } 101 | ) 102 | return df 103 | 104 | 105 | # - 106 | 107 | fig, axes = plt.subplots(ncols=3, figsize=(12, 3)) 108 | df = sample_driver_truck_pairs(n_datapoints, random_seed=0) 109 | df["usage_rate"].plot.hist(bins=30, xlabel="usage_rate", ax=axes[0]) 110 | df["driver_skill"].plot.hist(bins=30, xlabel="driver_skill", ax=axes[1]) 111 | df["truck_model"].value_counts().plot.bar(ax=axes[2]) 112 | df 113 | 114 | # ### Truck models, Brands, UX, Material and Assembly quality 115 | # 116 | 117 | # Let's imagine that the assembly quality only depends on the supplier brand. There are two brands on the market, Robusta (R) and Cheapz (C). 118 | 119 | brand_quality = pd.DataFrame({ 120 | "brand": ["Robusta", "Cheapz"], 121 | "assembly_quality": [0.95, 0.30], 122 | }) 123 | brand_quality 124 | 125 | # The models have user controls with different UX and driving assistance that has improved over the years. On the other hands the industry has progressively evolved to use lower quality materials over the years. 
126 | # 127 | # Each truck model come from a specific brand: 128 | 129 | # + 130 | trucks = pd.DataFrame( 131 | { 132 | "truck_model": truck_model_names, 133 | "brand": [ 134 | "Robusta" if m.startswith("R") else "Cheapz" 135 | for m in truck_model_names 136 | ], 137 | "ux": [.2, .5, .7, .9, 1.0], 138 | "material_quality": [.95, .92, .85, .7, .65], 139 | } 140 | ).merge(brand_quality) 141 | 142 | trucks 143 | 144 | 145 | # - 146 | 147 | # We can easily augment our truck driver pairs with those extra metadata by using a join: 148 | 149 | # + 150 | def sample_driver_truck_pairs_with_metadata(n_datapoints, random_seed): 151 | return ( 152 | sample_driver_truck_pairs( 153 | n_datapoints, random_seed=random_seed 154 | ) 155 | .reset_index() 156 | .merge(trucks, on="truck_model") 157 | # Sort by original index to avoid introducing an ordering 158 | # of the dataset based on the truck_model column. 159 | .sort_values("index") 160 | .drop("index", axis="columns") 161 | ) 162 | 163 | sample_driver_truck_pairs_with_metadata(10, random_seed=0) 164 | 165 | 166 | # - 167 | 168 | # ## Types of Failures 169 | # 170 | # We assume all types of failures follow a [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution) with varying shape parameters $k$: 171 | # 172 | # - k < 1: is good to model manufacturing defects, "infant mortality" and similar, monotonically decreasing hazards; 173 | # - k = 1: constant hazards (exponential distribution): random events not related to time (e.g. driving accidents); 174 | # - k > 1: "aging" process, wear and tear... monotonically increasing hazards. 175 | # 176 | # The hazard function can be implemented as: 177 | 178 | def weibull_hazard(t, k=1., s=1., t_shift=100, base_rate=1e2): 179 | # See: https://en.wikipedia.org/wiki/Weibull_distribution 180 | # t_shift is a trick to avoid avoid negative powers at t=0 when k < 1. 181 | # t_shift could be interpreted at the operation time at the factory for 182 | # quality assurance checks for instance. 183 | t = t + t_shift 184 | return base_rate * (k / s) * (t / s) ** (k - 1.) 185 | 186 | 187 | # + 188 | fig, ax = plt.subplots() 189 | 190 | t = np.linspace(0, total_days, total_days) # in days 191 | for k, s in [(0.003, 1.0), (1, 1e5), (7., 5e3)]: 192 | y = weibull_hazard(t, k=k, s=s) # rate of failures / day 193 | ax.plot(t, y, alpha=0.6, label=f"$k={k}, s={s}$"); 194 | ax.set( 195 | title="Weibull Hazard (failure rates)", 196 | ); 197 | plt.legend(); 198 | # - 199 | 200 | # ## Assembly failure $e_1$ 201 | # 202 | # Let $\lambda_1$ be the hazard related to the event $e_1$. We model the $\lambda_1$ with Weibull hazards with k << 1. 
203 | 204 | # Therefore for the assembly failure $e_1$, 205 | # 206 | # $$\lambda_1 \propto \mathrm{usage\; rate} \times (1 - \mathrm{assembly\; quality})$$ 207 | # 208 | 209 | # + 210 | t = np.linspace(0, total_days, total_days) 211 | 212 | def assembly_hazards(df, t): 213 | baseline = weibull_hazard(t, k=0.003) 214 | s = (df["usage_rate"] * (1 - df["assembly_quality"])).to_numpy() 215 | return s.reshape(-1, 1) * baseline.reshape(1, -1) 216 | 217 | fig, ax = plt.subplots() 218 | subsampled_df = sample_driver_truck_pairs_with_metadata(5, random_seed=0) 219 | hazards_1 = assembly_hazards(subsampled_df, t) 220 | for idx, h1 in enumerate(hazards_1): 221 | ax.plot(t, h1, label=f"Pair #{idx}") 222 | ax.set( 223 | title="$\lambda_1$ hazard for some (driver, truck) pairs", 224 | xlabel="time (days)", 225 | ylabel="$\lambda_1$", 226 | ) 227 | plt.legend() 228 | subsampled_df 229 | # - 230 | 231 | # This seems to indicate that drivers of Cheapz trucks have a significantly larger risk to run into a manufacturing defect during the first 2 years. Let's confirm this by computing the mean hazards per brands on a a much larger sample: 232 | 233 | # + 234 | df = sample_driver_truck_pairs_with_metadata(n_datapoints, random_seed=0) 235 | hazards_1 = assembly_hazards(df, t) 236 | 237 | fig, ax = plt.subplots() 238 | for brand in df["brand"].unique(): 239 | mask_brand = df["brand"] == brand 240 | mean_hazards = hazards_1[mask_brand].mean(axis=0) 241 | ax.plot(t, mean_hazards, label=brand) 242 | ax.set( 243 | title="Average $\lambda_1(t)$ hazard by brand", 244 | xlabel="time (days)", 245 | ylabel="$\lambda_1$", 246 | ) 247 | plt.legend(); 248 | 249 | 250 | # - 251 | 252 | # ## Operation failure $e_2$ 253 | # 254 | # We consider the operation hazard to be a constant modulated by driver skills, UX and usage rate. 255 | 256 | # + 257 | def operational_hazards(df, t): 258 | # Weibull hazards with k = 1 is just a constant over time: 259 | baseline = weibull_hazard(t, k=1, s=8e3) 260 | s = ( 261 | ((1 - df["driver_skill"]) * (1 - df["ux"]) + .001) * df["usage_rate"] 262 | ).to_numpy() 263 | return s.reshape(-1, 1) * baseline.reshape(1, -1) 264 | 265 | 266 | hazards_2 = operational_hazards(df, t) 267 | 268 | # + 269 | models = sorted(df["truck_model"].unique()) 270 | 271 | fig, ax = plt.subplots() 272 | for model in models: 273 | mask_model = df["truck_model"] == model 274 | mean_hazards = hazards_2[mask_model].mean(axis=0) 275 | ax.plot(t, mean_hazards, label=model) 276 | ax.set( 277 | title="Average $\lambda_2(t)$ by model", 278 | xlabel="time (days)", 279 | ylabel="$\lambda_2$", 280 | ) 281 | plt.legend(); 282 | 283 | 284 | # - 285 | 286 | # ## Fatigue failure $e_3$ 287 | # 288 | # We now model fatigue related features with Weibull hazards with k > 1. 
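# Concretely, in the implementation below the material quality enters through the
# Weibull shape parameter (k = 6 * material_quality): on the observed time range a
# larger k keeps the hazard low for longer before the wear-out phase kicks in, while
# the usage rate simply scales the overall level of the hazard.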
289 | 290 | # + 291 | def fatigue_hazards(df, t): 292 | return np.vstack([ 293 | 0.5 * weibull_hazard(t, k=6 * material_quality, s=4e3) * usage_rate 294 | for material_quality, usage_rate in zip(df["material_quality"], df["usage_rate"]) 295 | ]) 296 | 297 | 298 | hazards_3 = fatigue_hazards(df, t) 299 | # - 300 | 301 | fig, ax = plt.subplots() 302 | for h_3_ in hazards_3[:5]: 303 | ax.plot(t, h_3_) 304 | ax.set(title="$\lambda_3$ hazard", xlabel="time (days)"); 305 | 306 | fig, ax = plt.subplots() 307 | for model in models: 308 | mask_model = df["truck_model"] == model 309 | hazards_mean = hazards_3[mask_model].mean(axis=0) 310 | ax.plot(t, hazards_mean, label=model) 311 | ax.set( 312 | title="Average $\lambda_3(t)$", 313 | xlabel="time (days)", 314 | ) 315 | plt.legend(); 316 | 317 | # ## Additive hazard curve (any event curve) 318 | # 319 | # Let's enhance our understanding of these hazards by plotting the additive (any event) hazards for some couple (operator, machine). 320 | 321 | hazards_1.shape, hazards_2.shape, hazards_3.shape 322 | 323 | hazards_1.nbytes / 1e6 324 | 325 | total_hazards = (hazards_1[:5] + hazards_2[:5] + hazards_3[:5]) 326 | fig, ax = plt.subplots() 327 | for idx, total_hazards_ in enumerate(total_hazards): 328 | ax.plot(t, total_hazards_, label=idx) 329 | ax.set( 330 | title="$\lambda_{\mathrm{total}}$ hazard", 331 | xlabel="time (years)", 332 | ylabel="$\lambda(t)$", 333 | 334 | xlim=[None, 2000], 335 | ylim=[-0.00001, 0.005], 336 | ) 337 | plt.legend(); 338 | 339 | # ## Sampling from all hazards 340 | # 341 | # Now that we have the event hazards for the entire period of observation, we can sample the failure for all (driver, truck) pairs and define our target. 342 | # 343 | # Our target `y` is comprised of two columns: 344 | # - `event`: 1, 2, 3 or 0 if no event occured during the period or if the observation was censored 345 | # - `duration`: the day when the event or censor was observed 346 | 347 | # + 348 | from scipy.stats import bernoulli 349 | 350 | 351 | def sample_events_by_type(hazards, random_state=None): 352 | rng = check_random_state(random_state) 353 | outcomes = bernoulli.rvs(hazards, random_state=rng) 354 | any_event_mask = np.any(outcomes, axis=1) 355 | duration = np.full(outcomes.shape[0], fill_value=total_days) 356 | occurrence_rows, occurrence_cols = np.where(outcomes) 357 | # Some individuals might have more than one event occurrence, 358 | # we only keep the first one. 
359 | # ex: trials = [[0, 0, 1, 0, 1]] -> duration = 2 360 | _, first_occurrence_idxs = np.unique(occurrence_rows, return_index=True) 361 | duration[any_event_mask] = occurrence_cols[first_occurrence_idxs] 362 | jitter = rng.rand(duration.shape[0]) 363 | return pd.DataFrame(dict(event=any_event_mask, duration=duration + jitter)) 364 | 365 | 366 | # - 367 | 368 | # Let's count the number of events of each type that would occur if event types were non-competing: 369 | 370 | rng = check_random_state(0) 371 | occurrences_1 = sample_events_by_type(hazards_1, random_state=rng) 372 | print( 373 | f"total events: {occurrences_1['event'].sum()}, " 374 | f"mean duration: {occurrences_1.query('event')['duration'].mean():.2f} days" 375 | ) 376 | 377 | occurrences_2 = sample_events_by_type(hazards_2, random_state=rng) 378 | print( 379 | f"total events: {occurrences_2['event'].sum()}, " 380 | f"mean duration: {occurrences_2.query('event')['duration'].mean():.2f} days" 381 | ) 382 | 383 | occurrences_3 = sample_events_by_type(hazards_3, random_state=rng) 384 | print( 385 | f"total events: {occurrences_3['event'].sum()}, " 386 | f"mean duration: {occurrences_3.query('event')['duration'].mean():.2f} days" 387 | ) 388 | 389 | 390 | # Let's compute the result of the competing events buy only considering the first event for each driver / truck pair. 391 | 392 | # + 393 | def first_event(event_frames, event_ids, random_seed=None): 394 | rng = check_random_state(random_seed) 395 | event = np.zeros(event_frames[0].shape[0], dtype=np.int32) 396 | max_duration = np.max([ef["duration"].max() for ef in event_frames]) 397 | duration = np.full_like(event_frames[0]["duration"], fill_value=max_duration) 398 | 399 | out = pd.DataFrame( 400 | { 401 | "event": event, 402 | "duration": duration, 403 | } 404 | ) 405 | for event_id, ef in zip(event_ids, event_frames): 406 | mask = ef["event"] & (ef["duration"] < out["duration"]) 407 | out.loc[mask, "event"] = event_id 408 | out.loc[mask, "duration"] = ef.loc[mask, "duration"] 409 | return out 410 | 411 | 412 | competing_events = first_event( 413 | [occurrences_1, occurrences_2, occurrences_3], event_ids=[1, 2, 3] 414 | ) 415 | 416 | 417 | # + 418 | def plot_stacked_occurrences(occurrences): 419 | hists = [ 420 | occurrences.query("event == @idx")["duration"] 421 | for idx in range(4) 422 | ] 423 | labels = [f"$e_{idx}$" for idx in range(4)] 424 | fig, axes = plt.subplots(ncols=2, figsize=(12, 4)) 425 | axes[0].hist(hists, bins=50, stacked=True, label=labels); 426 | axes[0].set( 427 | xlabel="duration (days)", 428 | ylabel="occurrence count", 429 | title="Stacked combined duration distributions", 430 | ) 431 | axes[0].legend() 432 | occurrences["event"].value_counts().sort_index().plot.bar(rot=0, ax=axes[1]) 433 | axes[1].set(title="Event counts by type (0 == censored)") 434 | 435 | 436 | plot_stacked_occurrences(competing_events) 437 | 438 | 439 | # - 440 | 441 | # Let's now write a function to add non-informative (independent uniform censoring): 442 | 443 | # + 444 | def uniform_censoring(occurrences, censoring_weight=0.5, offset=0, random_state=None): 445 | n_datapoints = occurrences.shape[0] 446 | rng = check_random_state(random_state) 447 | max_duration = occurrences["duration"].max() 448 | censoring_durations = rng.randint( 449 | low=offset, high=max_duration, size=n_datapoints 450 | ) 451 | # reduce censoring randomly by setting durations back to the max, 452 | # effectively ensuring that a fraction of the datapoints will not 453 | # be censured. 
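    # Concretely: with probability (1 - censoring_weight) a data point keeps its
    # censoring time at max_duration and is therefore never censored by this
    # mechanism; the remaining points are censored whenever the drawn uniform
    # censoring time falls before their observed event time.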
454 | disabled_censoring_mask = rng.rand(n_datapoints) > censoring_weight 455 | censoring_durations[disabled_censoring_mask] = max_duration 456 | out = occurrences.copy() 457 | censor_mask = occurrences["duration"] > censoring_durations 458 | out.loc[censor_mask, "event"] = 0 459 | out.loc[censor_mask, "duration"] = censoring_durations[censor_mask] 460 | return out 461 | 462 | 463 | censored_events = uniform_censoring(competing_events, random_state=0) 464 | plot_stacked_occurrences(censored_events) 465 | # - 466 | 467 | # It is often the case that there is deterministic component to the censoring distribution that stems from a maximum observation duration: 468 | 469 | max_observation_duration = 2000 470 | max_duration_mask = censored_events["duration"] > max_observation_duration 471 | censored_events.loc[max_duration_mask, "duration"] = max_observation_duration 472 | censored_events.loc[max_duration_mask, "event"] = 0 473 | plot_stacked_occurrences(censored_events) 474 | 475 | 476 | # Let's put it all data generation steps together. 477 | 478 | # + 479 | def sample_competing_events( 480 | data, 481 | uniform_censoring_weight=1.0, 482 | max_observation_duration=2000, 483 | random_seed=None, 484 | ): 485 | rng = check_random_state(random_seed) 486 | t = np.linspace(0, total_days, total_days) 487 | hazard_funcs = [ 488 | assembly_hazards, 489 | operational_hazards, 490 | fatigue_hazards, 491 | ] 492 | event_ids = np.arange(len(hazard_funcs)) + 1 493 | all_hazards = np.asarray([ 494 | hazard_func(data, t) for hazard_func in hazard_funcs 495 | ]) 496 | occurrences_by_type = [ 497 | sample_events_by_type(all_hazards[i], random_state=rng) 498 | for i in range(all_hazards.shape[0]) 499 | ] 500 | occurrences = first_event(occurrences_by_type, event_ids) 501 | censored_occurrences = uniform_censoring( 502 | occurrences, censoring_weight=uniform_censoring_weight, random_state=rng 503 | ) 504 | if max_observation_duration is not None: 505 | # censor all events after max_observation_duration 506 | max_duration_mask = censored_occurrences["duration"] > max_observation_duration 507 | censored_occurrences.loc[max_duration_mask, "duration"] = max_observation_duration 508 | censored_occurrences.loc[max_duration_mask, "event"] = 0 509 | return ( 510 | censored_occurrences, 511 | occurrences, 512 | all_hazards # shape = (n_event_types, n_observations, n_timesteps) 513 | ) 514 | 515 | 516 | truck_failure_10k = sample_driver_truck_pairs_with_metadata(10_000, random_seed=0) 517 | ( 518 | truck_failure_10k_events, 519 | truck_failure_10k_events_uncensored, 520 | truck_failure_10k_all_hazards, 521 | ) = sample_competing_events(truck_failure_10k, random_seed=0) 522 | plot_stacked_occurrences(truck_failure_10k_events) 523 | # - 524 | 525 | truck_failure_10k 526 | 527 | # Let's check that the Kaplan-Meier estimator can estimate the mean "any event" survival function. 
We compare this estimate to the theoretical mean survival function computed from the conditional hazard functions from which the event data has been sampled: 528 | 529 | # + 530 | from sksurv.nonparametric import kaplan_meier_estimator 531 | from scipy.interpolate import interp1d 532 | 533 | 534 | def plot_survival_function(event_frame, all_hazards): 535 | assert all_hazards.shape[0] == event_frame.query("event != 0")["event"].nunique() 536 | assert all_hazards.shape[1] == event_frame.shape[0] # observations 537 | assert all_hazards.shape[2] >= event_frame["duration"].max() # days 538 | 539 | any_event = event_frame["event"] > 0 540 | km_times, km_surv_probs = kaplan_meier_estimator(any_event, event_frame["duration"]) 541 | 542 | # Make it possible to evaluate the survival probabilities at any time step with 543 | # with constant extrapolation if necessary. 544 | times = np.arange(total_days) 545 | surv_func = interp1d( 546 | km_times, 547 | km_surv_probs, 548 | kind="previous", 549 | bounds_error=False, 550 | fill_value="extrapolate" 551 | ) 552 | surv_probs = surv_func(times) 553 | 554 | any_event_hazards = all_hazards.sum(axis=0) 555 | true_surv = np.exp(-any_event_hazards.cumsum(axis=-1)) 556 | 557 | plt.step(times, surv_probs, label="KM estimator $\hat{S}(t)$") 558 | plt.step(times, true_surv.mean(axis=0), label="True $E_{x_i \in X} [S(t; x_i)]$") 559 | plt.legend() 560 | plt.title("Survival functions") 561 | 562 | 563 | plot_survival_function(truck_failure_10k_events, truck_failure_10k_all_hazards) 564 | # - 565 | 566 | # The Aalan-Johansen estimator allows us to compute the cumulative incidence function $P(T < t)$ for competitive events. 567 | # We compare its estimation to the ground truth by converting our fixed hazards to CIF. 568 | 569 | # $$CIF_k(t) = \int^t_0 f(u) du = \int^t_0 \lambda_k(u).S(u) du $$ 570 | # 571 | # Where $f(t)$ is the probability density, $CIF_k(t)$ is the cumulative incidence function, $\lambda_k(t)$ is the hazard rate of event $k$ and $S(t)$ is the survival probability. 
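# On the discrete daily grid used in this notebook, this integral is approximated by a
# cumulative sum, $CIF_k(t) \approx \sum_{u \le t} \lambda_k(u) S(u)$, which is what the
# code below computes with `(hazards_i * true_surv).cumsum(axis=-1)`. As a quick sanity
# check (a minimal sketch reusing a slice of the hazards sampled above; the `_check`
# variable names are ad hoc), the survival function plus the sum of the cause-specific
# CIFs should stay close to 1, up to time-discretization error:

# +
hazards_check = truck_failure_10k_all_hazards[:, :500, :]
surv_check = np.exp(-hazards_check.sum(axis=0).cumsum(axis=-1))
cif_sum_check = (hazards_check * surv_check).sum(axis=0).cumsum(axis=-1)
print(f"max |S + sum of CIFs - 1| = {np.abs(surv_check + cif_sum_check - 1).max():.4f}")
# -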
572 | 573 | # + 574 | from lifelines import AalenJohansenFitter 575 | 576 | 577 | def plot_cumulative_incidence_functions(event_frame, all_hazards): 578 | fig, axes = plt.subplots(figsize=(12, 4), ncols=all_hazards.shape[0]) 579 | 580 | any_event_hazards = all_hazards.sum(axis=0) 581 | true_surv = np.exp(-any_event_hazards.cumsum(axis=-1)) 582 | 583 | plt.suptitle("Cause-specific cumulative incidence functions") 584 | for event_id, (ax, hazards_i) in enumerate(zip(axes, all_hazards), 1): 585 | ajf = AalenJohansenFitter(calculate_variance=True) 586 | ajf.fit(event_frame["duration"], event_frame["event"], event_of_interest=event_id) 587 | ajf.plot(label=f"Aalen Johansen estimate of $CIF_{event_id}$", ax=ax) 588 | 589 | cif = (hazards_i * true_surv).cumsum(axis=-1).mean(axis=0) 590 | ax.plot(cif, label="True $E_{x_i \in X}" + f"[CIF_{event_id}(t; x_i)]$"), 591 | ax.set(ylim=[-.01, 1.01]), 592 | ax.legend() 593 | 594 | 595 | plot_cumulative_incidence_functions(truck_failure_10k_events, truck_failure_10k_all_hazards) 596 | # - 597 | 598 | # Sanity check for our Brier score computation: it should approximately return the same values irrespective of the censoring for each event type: 599 | 600 | import sys; sys.path.append("..") 601 | 602 | # + 603 | from models.gradient_boosted_cif import cif_integrated_brier_score 604 | 605 | 606 | any_event_hazards = truck_failure_10k_all_hazards.sum(axis=0) 607 | true_survival = np.exp(-any_event_hazards.cumsum(axis=-1)) 608 | max_time_ibs = int( 609 | truck_failure_10k_events.query("event > 0")["duration"].max() 610 | ) - 1 611 | 612 | for i, hazards_i in enumerate(truck_failure_10k_all_hazards): 613 | k = i + 1 614 | true_cif_k = (hazards_i * true_survival).cumsum(axis=-1) 615 | ibs_censored = cif_integrated_brier_score( 616 | truck_failure_10k_events, 617 | truck_failure_10k_events, 618 | true_cif_k[:, :max_time_ibs], 619 | np.arange(total_days)[:max_time_ibs], 620 | event_of_interest=k, 621 | ) 622 | ibs_uncensored = cif_integrated_brier_score( 623 | truck_failure_10k_events_uncensored, 624 | truck_failure_10k_events_uncensored, 625 | true_cif_k[:, :max_time_ibs], 626 | np.arange(total_days)[:max_time_ibs], 627 | event_of_interest=k, 628 | ) 629 | print(f"IBS for event {k}: censored {ibs_censored:.4f}, uncensored {ibs_uncensored:.4f}") 630 | # - 631 | 632 | # If all is well, let's save this dataset to disk: 633 | 634 | # + 635 | observed_variables = [ 636 | "driver_skill", 637 | "brand", 638 | "truck_model", 639 | "usage_rate", 640 | ] 641 | truck_failure_10k[observed_variables].to_parquet("truck_failure_10k_features.parquet", index=False) 642 | truck_failure_10k_events.to_parquet("truck_failure_10k_competing_risks.parquet", index=False) 643 | truck_failure_10k_events_uncensored.to_parquet("truck_failure_10k_competing_risks_uncensored.parquet", index=False) 644 | 645 | truck_failure_10k_any_event = truck_failure_10k_events.copy() 646 | truck_failure_10k_any_event["event"] = truck_failure_10k_any_event["event"] > 0 647 | truck_failure_10k_any_event.to_parquet("truck_failure_10k_any_event.parquet", index=False) 648 | 649 | truck_failure_10k_any_event_uncensored = truck_failure_10k_events_uncensored.copy() 650 | truck_failure_10k_any_event_uncensored["event"] = truck_failure_10k_any_event["event"] > 0 651 | truck_failure_10k_any_event_uncensored.to_parquet("truck_failure_10k_any_event_uncensored.parquet", index=False) 652 | # - 653 | 654 | # Let's also save the underlying hazard functions used to sample each event of the dataset. 
655 | 656 | np.savez_compressed( 657 | "truck_failure_10k_hazards.npz", 658 | truck_failure_10k_hazards=truck_failure_10k_all_hazards, 659 | ) 660 | 661 | # + 662 | with np.load("truck_failure_10k_hazards.npz") as hazards_file: 663 | array_names = list(hazards_file.keys()) 664 | 665 | array_names 666 | # - 667 | 668 | # ## Sampling a larger dataset (without ground truth) 669 | # 670 | # Let's sample a larger event dataset to be able to assess the sample and computational complexities of various predictive methods: 671 | 672 | # + 673 | from joblib import Parallel, delayed 674 | 675 | chunk_size = 10_000 676 | n_chunks = 10 677 | 678 | 679 | def sample_chunk(chunk_idx, chunk_size): 680 | features_chunk = sample_driver_truck_pairs_with_metadata(chunk_size, random_seed=chunk_idx) 681 | events_chunk, _, _ = sample_competing_events(features_chunk, random_seed=chunk_idx) 682 | return features_chunk, events_chunk 683 | 684 | 685 | results = Parallel(n_jobs=-1, verbose=10)( 686 | delayed(sample_chunk)(i, chunk_size) for i in range(n_chunks) 687 | ) 688 | truck_failure_100k = pd.concat([features_chunk for features_chunk, _ in results], axis="rows") 689 | truck_failure_100k_events = pd.concat([events_chunk for _, events_chunk in results], axis="rows") 690 | plot_stacked_occurrences(truck_failure_100k_events) 691 | # - 692 | 693 | # Note: we ensured that the chunk size is 10_000 and the random seed is based on the chunk index to ensure that the `truck_failure_100k` dataset is a super set of the `truck_failure_10k` dataset. 694 | 695 | # + 696 | from pandas.testing import assert_frame_equal 697 | 698 | assert_frame_equal( 699 | truck_failure_100k[observed_variables].iloc[:10_000].reset_index(drop=True), 700 | truck_failure_10k[observed_variables].reset_index(drop=True), 701 | ) 702 | assert_frame_equal( 703 | truck_failure_100k[observed_variables].iloc[:10_000].reset_index(drop=True), 704 | truck_failure_10k[observed_variables].reset_index(drop=True), 705 | ) 706 | # - 707 | 708 | truck_failure_100k[observed_variables].to_parquet("truck_failure_100k_features.parquet", index=False) 709 | truck_failure_100k_events.to_parquet("truck_failure_100k_competing_risks.parquet", index=False) 710 | truck_failure_100k_any_event = truck_failure_100k_events.copy() 711 | truck_failure_100k_any_event["event"] = truck_failure_100k_any_event["event"] > 0 712 | truck_failure_100k_any_event.to_parquet("truck_failure_100k_any_event.parquet", index=False) 713 | 714 | # ## Sampling targets at fixed conditional X 715 | # 716 | # We now fix our covariates X to the first truck-driver pair, and create a fixed dataset by sampling $N$ times our first user multi-event hazards. The goal is to check that an unconditional estimator designed for competing events, called Aalen-Johanson, gives hazards estimations close to the ground truth. 
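# (With the covariates held fixed, the marginal quantities estimated by Aalen-Johansen
# coincide with the conditional ones, so its estimate should match this single pair's
# true cumulative incidence curves up to sampling noise.)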
717 | 718 | truck_failure_fc = sample_driver_truck_pairs_with_metadata(1, random_seed=3) 719 | truck_failure_fc = pd.concat([truck_failure_fc] * 300, axis="rows").reset_index(drop=True) 720 | truck_failure_fc 721 | 722 | ( 723 | truck_failure_fc_events, 724 | truck_failure_fc_events_uncensored, 725 | all_hazards_fc, 726 | ) = sample_competing_events(truck_failure_fc, random_seed=42) 727 | plot_stacked_occurrences(truck_failure_fc_events) 728 | 729 | plot_survival_function(truck_failure_fc_events, all_hazards_fc) 730 | plot_cumulative_incidence_functions(truck_failure_fc_events, all_hazards_fc) 731 | 732 | # We should see that the Aalen-Johansen method provides an accurate estimators for the unconditional competing hazards, even with few samples! 733 | -------------------------------------------------------------------------------- /notebooks/tutorial_part_1.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:percent 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.5 10 | # kernelspec: 11 | # display_name: Python 3 (ipykernel) 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% [markdown] 17 | # # Survival Analysis Tutorial Part 1 18 | # 19 | # In this tutorial we will introduce: 20 | # 21 | # - what is **right-censored time-to-event data** and why naive regression models fail on such data, 22 | # - **unconditional survival analysis** with the **Kaplan-Meier** estimator, 23 | # - **predictive survival analysis** with Cox Proportional Hazards, Survival Forests and Gradient Boosted CIF, 24 | # - how to assess the quality of survival estimators using the Integrated Brier Score and C-index metrics, 25 | # - what is **right-censored competing risks data**, 26 | # - **unconditional competing risks analysis** with the **Aalen-Johansen** estimator, 27 | # - **predictive competing risks analysis** with gradient boosted CIF. 28 | 29 | # %% [markdown] 30 | # ## Tutorial data 31 | # 32 | # To run this tutorial you can **either generate or download** the truck failure data. 33 | # 34 | # **To generate the data**, open the `truck_data.ipynb` notebook in jupyter lab and click the "run all" (fast forward arrow) button and wait a few minutes. By the end, you should see a bunch of `.parquet` and `.npz` files showing up in the `notebooks` folder. 35 | # 36 | # Alternatively, feel free **to download** a zip archive from: 37 | # 38 | # - https://github.com/soda-inria/survival-analysis-benchmark/releases/download/jupytercon-2023-tutorial-data/truck_failure.zip (a bit more than 500 MB). 39 | # 40 | # and unzip it in the `notebooks` folder. 41 | 42 | # %% [markdown] 43 | # ## What is right-censored time-to-event data? 44 | # 45 | # ### Censoring 46 | # 47 | # Survival analysis is a time-to-event regression problem, with censored data. We call censored all individuals that didn't experience the event during the range of the observation window. 48 | # 49 | # In our setting, we're mostly interested in right-censored data, meaning we that the event of interest did not occur before the end of the observation period (typically the time of collection of the dataset): 50 | # 51 | #
52 | # ![right-censored time-to-event data](censoring.png) 53 | # 54 | # *image credit: scikit-survival*
55 | # 56 | # Individuals can join the study at the same or different times, and the study may or may not be ended by the time of observation. 57 | # 58 | # Survival analysis techniques have wide applications: 59 | # 60 | # - In the **medical** landscape, events can consist of patients dying of cancer or, on the contrary, recovering from some disease. 61 | # - In **predictive maintenance**, events can consist of machine failures. 62 | # - In **insurance**, we are interested in modeling the time to the next claim for a portfolio of insurance contracts. 63 | # - In **marketing**, we can consider user churn as the event, or we could focus on users becoming premium members (users that choose to pay a subscription after having used the free version of the service for a while). 64 | # - **Economists** may be interested in modeling the time it takes unemployed people to find a new job in different contexts or for different kinds of jobs. 65 | # 66 | # 67 | # As we will see, for all those applications, it is not possible to directly train a machine learning-based regression model on such a **right-censored** time-to-event target since we only have a lower bound on the true time to event for some data points. **Naively removing such points from the dataset would cause the model predictions to be biased**. 68 | 69 | # %% [markdown] 70 | # ### Our target `y` 71 | # 72 | # For each individual $i\in[1, N]$, our survival analysis target $y_i$ is comprised of two elements: 73 | # 74 | # - The event indicator $\delta_i\in\{0, 1\}$, where $0$ marks censoring and $1$ is indicative that the event of interest has actually happened before the end of the observation window. 75 | # - The censored time-to-event $d_i=\min(t_{i}, c_i) > 0$, that is the minimum between the date of the experienced event $t_i$ and the censoring date $c_i$. In a real-world setting, we don't have direct access to $t_i$ when $\delta_i=0$. We can only record $d_i$. 76 | # 77 | # Here is how we represent our target: 78 | 79 | # %% 80 | import pandas as pd 81 | import numpy as np 82 | 83 | truck_failure_events = pd.read_parquet("truck_failure_10k_any_event.parquet") 84 | truck_failure_events 85 | 86 | # %% [markdown] 87 | # In this example, we study failures of truck-driver pairs. Censored pairs (when event is 0 or False) haven't had a mechanical failure or an accident during the study. 88 | 89 | # %% [markdown] 90 | # ### Why is it a problem to train time-to-event regression models? 91 | # 92 | # Without survival analysis, we have two naive options to deal with right-censored time to event data: 93 | # - We ignore censored data points from the dataset, only keep events that happened and perform naive regression on them. 94 | # - We consider that all censored events happen at the end of our observation window. 95 | # 96 | # **Both approaches are wrong and lead to biased results.** 97 | # 98 | # Let's compute the average and median time to event using either of those naive approaches on our truck failure dataset. We will compare them to the mean of the ground-truth event time $T$, that we would have obtained with an infinite observation window. 99 | # 100 | # Note that we have access to the random variable $T$ because we generated this synthetic dataset. With real-world data, you only have access to $Y = \min(T, C)$, where $C$ is a random variable representing the censoring time.
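# %% [markdown]
# To make the notation above concrete, here is a minimal, purely illustrative sketch
# (using hypothetical randomly generated times rather than the tutorial data) of how the
# observed target $(\delta_i, d_i)$ relates to the latent event time $t_i$ and censoring
# time $c_i$:

# %%
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
true_event_time = rng.weibull(2.0, size=5) * 1000.0  # hypothetical t_i (never observed directly)
censoring_time = rng.uniform(0.0, 1500.0, size=5)    # hypothetical c_i

toy_target = pd.DataFrame({
    "event": true_event_time <= censoring_time,               # delta_i
    "duration": np.minimum(true_event_time, censoring_time),  # d_i = min(t_i, c_i)
})
toy_target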
101 | 102 | # %% 103 | naive_stats_1 = ( 104 | truck_failure_events.query("event == True")["duration"] 105 | .apply(["mean", "median"]) 106 | ) 107 | print( 108 | f"Biased method 1 (removing censored points):\n" 109 | f"mean: {naive_stats_1['mean']:.1f} days, " 110 | f"median: {naive_stats_1['median']:.1f} days" 111 | ) 112 | 113 | # %% 114 | max_duration = truck_failure_events["duration"].max() 115 | naive_stats_2 = ( 116 | pd.Series( 117 | np.where( 118 | truck_failure_events["event"], 119 | truck_failure_events["duration"], 120 | max_duration, 121 | ) 122 | ) 123 | .apply(["mean", "median"]) 124 | ) 125 | print( 126 | f"Biased method 2 (censored events moved to the end of the window):\n" 127 | f"mean: {naive_stats_2['mean']:.1f} days, " 128 | f"median: {naive_stats_2['median']:.1f} days" 129 | ) 130 | 131 | # %% [markdown] 132 | # In our case, the **data comes from a simple truck fleet simulator** and we have **access to the uncensored times** (we can wait as long as we want to extend the observation period as needed to have all trucks fail). 133 | # 134 | # Let's have a look at the **true mean and median time-to-failure**: 135 | 136 | # %% 137 | truck_failure_events_uncensored = pd.read_parquet("truck_failure_10k_any_event_uncensored.parquet") 138 | 139 | # %% 140 | true_stats = truck_failure_events_uncensored["duration"].apply(["mean", "median"]) 141 | print( 142 | f"Ground truth (from the simulator):\n" 143 | f"mean: {true_stats['mean']:.2f} days, " 144 | f"median: {true_stats['median']:.2f} days" 145 | ) 146 | 147 | # %% [markdown] 148 | # We see that **neither of the naive ways to handle censoring gives a good estimate of the true mean or median time to event**. 149 | # 150 | # If we have access to covariates $X$ (also known as input features in machine learning), a regression method would try to estimate $\mathbb{E}[T|X]$, where $X$ are our covariates, but we only have access to $Y = \min(T, C)$ where $T$ is the true time to failure and $C$ is the censoring duration. Fitting a **conditional regression model on right-censored data** would also require a special treatment because either of the **naive preprocessing** presented above would introduce a **significant bias in the predictions**. 151 | # 152 | # 153 | # Here is structured outline of the estimators we will introduce in this tutorial: 154 | # 155 | # 156 | # | | Descriptive / unconditional: only `y`, no `X` | Predictive / conditional: `y` given `X` | 157 | # |------------------------------------------|------------------------------------------------------|-------------------------------------------------| 158 | # | Suvival Analysis (1 event type) | Kaplan-Meier | Cox PH, Survival Forests, Gradient Boosting CIF | 159 | # | Competing Risks Analysis (k event types) | Aalen-Johansen | Gradient Boosting CIF | 160 | # 161 | 162 | # %% [markdown] 163 | # Let's start with unconditional estimation of the any event survival curve. 164 | # 165 | # 166 | # ## Unconditional survival analysis with Kaplan-Meier 167 | # 168 | # We now introduce the survival analysis approach to the problem of estimating the time-to-event from censored data. For now, we ignore any information from $X$ and focus on $y$ only. 169 | # 170 | # Here our quantity of interest is the survival probability: 171 | # 172 | # $$S(t)=P(T > t)$$ 173 | # 174 | # This represents the probability that an event doesn't occur at or before some given time $t$, i.e. that it happens at some time $T > t$. 
175 | # 176 | # The most commonly used method to estimate this function is the **Kaplan-Meier** estimator. It gives us an **unbiased estimate of the survival probability**. It can be computed as follows: 177 | # 178 | # $$\hat{S}(t)=\prod_{i: t_i\leq t} (1 - \frac{d_i}{n_i})$$ 179 | # 180 | # Where: 181 | # 182 | # - $t_i$ is the time of event for individual $i$ that experienced the event, 183 | # - $d_i$ is the number of individuals having experienced the event at $t_i$, 184 | # - $n_i$ are the remaining individuals at risk at $t_i$. 185 | # 186 | # Note that **individuals that were censored before $t_i$ are no longer considered at risk at $t_i$**. 187 | # 188 | # Contrary to machine learning regressors, this estimator is **unconditional**: it only extracts information from $y$ only, and cannot model information about each individual typically provided in a feature matrix $X$. 189 | # 190 | # In a real-world application, we aim at estimating $\mathbb{E}[T]$ or $Q_{50\%}[T]$. The latter quantity represents the median survival duration i.e. the duration before 50% of our population at risk experiment the event. 191 | # 192 | # We can also be interested in estimating the survival probability after some reference time $P(T > t_{ref})$, e.g. a random clinical trial estimating the capacity of a drug to improve the survival probability after 6 months. 193 | 194 | # %% 195 | import plotly.express as px 196 | from sksurv.nonparametric import kaplan_meier_estimator 197 | 198 | 199 | times, km_survival_probabilities = kaplan_meier_estimator( 200 | truck_failure_events["event"], truck_failure_events["duration"] 201 | ) 202 | 203 | # %% 204 | times 205 | 206 | # %% 207 | km_survival_probabilities 208 | 209 | # %% 210 | km_proba = pd.DataFrame( 211 | dict( 212 | time=times, 213 | survival_curve=km_survival_probabilities 214 | ) 215 | ) 216 | fig = px.line( 217 | km_proba, 218 | x="time", 219 | y="survival_curve", 220 | title="Kaplan-Meier survival probability", 221 | ) 222 | fig.add_hline( 223 | y=0.50, 224 | annotation_text="Median", 225 | line_dash="dash", 226 | line_color="red", 227 | annotation_font_color="red", 228 | ) 229 | 230 | fig.update_layout( 231 | height=500, 232 | width=800, 233 | xaxis_title="time (days)", 234 | yaxis_title="$\hat{S}(t)$", 235 | yaxis_range=[0, 1], 236 | ) 237 | 238 | 239 | # %% [markdown] 240 | # We can read the median time to event directly from this curve: it is the time at the intersection of the estimate of the survival curve with the horizontal line for a 50% failure probility. 241 | # 242 | # Since we have censored data, $\hat{S}(t)$ doesn't reach 0 within our observation window. We would need to extend the observation window to estimate the survival function beyond this limit. **Kaplan-Meier does not attempt the extrapolate beyond the last observed event**. 243 | 244 | # %% [markdown] 245 | # ***Exercice***
246 | # Based on `times` and `km_survival_probabilities`, estimate the median survival time. 247 | # 248 | # *Hint: You can use `np.searchsorted` on sorted probabilities in increasing order (reverse the natural order of the survival probabilities).* 249 | # 250 | # *Hint: Alternatively you can "inverse" the estimate of the survival curve using `scipy.interpolate.interp1d` and take the value at probability 0.5.* 251 | 252 | # %% 253 | def compute_median_survival_time(times, survival_probabilities): 254 | """Get the closest time to a survival probability of 50%.""" 255 | ### Your code here 256 | median_survival_time = 0 257 | ### 258 | return median_survival_time 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | compute_median_survival_time(times, km_survival_probabilities) 272 | 273 | 274 | # %% [markdown] 275 | # ***Solution*** 276 | 277 | # %% 278 | def compute_median_survival_time_with_searchsorted(times, survival_probabilities): 279 | """Get the closest time to a survival probability of 50%.""" 280 | # Search sorted needs an array of ascending values: 281 | increasing_survival_probabilities = survival_probabilities[::-1] 282 | median_idx = np.searchsorted(increasing_survival_probabilities, 0.50) 283 | median_survival_time = times[-median_idx] 284 | return median_survival_time.round(decimals=1) 285 | 286 | 287 | compute_median_survival_time_with_searchsorted(times, km_survival_probabilities) 288 | 289 | # %% 290 | from scipy.interpolate import interp1d 291 | 292 | 293 | def compute_median_survival_time_with_interp1d(times, survival_probabilities): 294 | """Get the time to a survival proba of 50% via linear interpolation.""" 295 | reverse_survival_func = interp1d(survival_probabilities, times, kind="linear") 296 | return reverse_survival_func([0.5])[0].round(decimals=1) 297 | 298 | 299 | compute_median_survival_time_with_interp1d(times, km_survival_probabilities) 300 | 301 | # %% [markdown] 302 | # Here is the **true median survival time from the same data without any censoring** (generally not available in a real-life setting). 303 | 304 | # %% 305 | truck_failure_events_uncensored["duration"].median().round(decimals=1) 306 | 307 | # %% [markdown] 308 | # This empirically confirms that the median survival time estimated by post-processing the KM estimate of the survival curve is a much better way to handle censored data than either of the two naive approaches we considered at the beginning of this notebook. 309 | 310 | # %% [markdown] 311 | # ### Mathematical break 312 | # 313 | # We now introduce some quantities which are going to be at the core of many survival analysis models and Kaplan-Meier in particular. 314 | # 315 | # The most important concept is the hazard rate $\lambda(t)$. This quantity represents the "speed of failure" or **the probability that an event occurs in the next $dt$, given that it hasn't occurred yet**. This can be written as: 316 | # 317 | # $$\begin{align} 318 | # \lambda(t) &=\lim_{dt\rightarrow 0}\frac{P(t \leq T < t + dt | T \geq t)}{dt} \\ 319 | # &= \lim_{dt\rightarrow 0}\frac{P(t \leq T < t + dt)}{dt \cdot S(t)} \\ 320 | # &= \frac{f(t)}{S(t)} 321 | # \end{align} 322 | # $$ 323 | # 324 | # where $f(t)$ represents the event density function, independently of whether the event has happened before or not.
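# %% [markdown]
# As a quick, illustrative sanity check of the identity $\lambda(t) = f(t) / S(t)$ (an aside
# added here, not part of the truck data analysis), we can use a distribution with a known
# constant hazard, the exponential distribution, and verify the relation numerically:

# %%
import numpy as np
from scipy import stats

lam = 0.01  # constant hazard rate of an exponential distribution
t = np.linspace(1.0, 500.0, num=100)
f_t = stats.expon.pdf(t, scale=1 / lam)  # event density f(t)
S_t = stats.expon.sf(t, scale=1 / lam)   # survival function S(t)
np.allclose(f_t / S_t, lam)              # the ratio recovers the constant hazard rate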
325 | 326 | # %% [markdown] 327 | # If we integrate $f(t)$, we get the cumulative incidence function (CIF) $F(t)=P(T < t)$, which is the complement of the survival function $S(t)$: 328 | # 329 | # $$F(t) = 1 - S(t) = \int^t_0 f(u) du$$ 330 | 331 | # %% [markdown] 332 | # Most of the time we do not attempt to evaluate $f(t)$. Instead we usually define the cumulative hazard function by integrating the hazard function: 333 | # 334 | # $$\Lambda(t) = \int^t_0 \lambda(u) du$$ 335 | # 336 | # It can be shown that the survival function (and therefore the cumulative incidence function) can be computed as: 337 | # 338 | # $$S(t) = e^{-\Lambda(t)}$$ 339 | # 340 | # $$F(t) = 1 - e^{-\Lambda(t)}$$ 341 | # 342 | # and if we have an estimate of $S(t)$ we can derive estimates of the cumulative hazard and instantaneous hazard functions as: 343 | # 344 | # $$\Lambda(t) = - \log(S(t))$$ 345 | # 346 | # $$\lambda(t) = - \frac{S'(t)}{S(t)}$$ 347 | # 348 | # In practice, estimating the hazard function from a finite sample estimate of the survival curve can be quite challenging (from a numerical point of view). But the converse often works well. 349 | # 350 | # Since our dataset was sampled from known hazard functions (one per truck), we can compute the theoretical survival curve by integrating the hazards over time and taking the exponential of the negative cumulative hazard. Let's give this a try: 351 | 352 | # %% 353 | with np.load("truck_failure_10k_hazards.npz") as f: 354 | theoretical_hazards = f["truck_failure_10k_hazards"].sum(axis=0) # will be explained later 355 | 356 | theoretical_hazards.shape 357 | 358 | # %% 359 | import matplotlib.pyplot as plt 360 | import seaborn as sns; sns.set_style("darkgrid") 361 | 362 | 363 | theoretical_cumulated_hazards = theoretical_hazards.cumsum(axis=-1) 364 | mean_theoretical_survival_functions = np.exp(-theoretical_cumulated_hazards).mean(axis=0) 365 | n_time_steps = mean_theoretical_survival_functions.shape[0] 366 | 367 | fig, ax = plt.subplots(figsize=(12, 5)) 368 | ax.plot(times, km_survival_probabilities, label="Kaplan-Meier") 369 | ax.plot( 370 | np.arange(n_time_steps), mean_theoretical_survival_functions, 371 | linestyle="--", label="Theoretical" 372 | ) 373 | ax.set(title="Mean survival curve", xlabel="Time (days)") 374 | ax.legend(); 375 | 376 | # %% [markdown] 377 | # We observe that **Kaplan-Meier is an unbiased estimator of the survival curve** defined by the true hazard functions. 378 | # 379 | # However we observe that the **KM estimate is no longer defined after the time of the last observed failure** (day 2000 in our case). In this dataset, all events are censored past that date: as a result **the KM survival curve does not reach zero** even when the true curve does. Therefore, it is **not possible to compute the mean survival time from the KM-estimate** alone. One would need to make some further assumptions to extrapolate it if necessary. 380 | # 381 | # Furthermore, not all data generating processes necessarily need to reach the 0.0 probability. For instance, survival analysis could be used to model a "time to next snow event" in different regions of the world. We can anticipate that it will never snow in some regions of the world in the foreseeable future. 382 | 383 | # %% [markdown] 384 | # ### Kaplan-Meier on subgroups: stratification on columns of `X` 385 | 386 | # %% [markdown] 387 | # We can enrich our analysis by introducing covariates that are statistically associated with the events and durations.
388 | 389 | # %% 390 | truck_failure_features = pd.read_parquet("truck_failure_10k_features.parquet") 391 | truck_failure_features 392 | 393 | # %% 394 | truck_failure_features_and_events = pd.concat( 395 | [truck_failure_features, truck_failure_events], axis="columns" 396 | ) 397 | truck_failure_features_and_events 398 | 399 | # %% [markdown] 400 | # For exemple, let's use Kaplan Meier to get a sense of the impact of the **brand**, by stratifying on this variable. 401 | # 402 | # ***Exercice*** 403 | # 404 | # Plot the stratified Kaplan Meier of the brand, i.e. for each different brand: 405 | # 1. Filter the dataset on this brand using pandas, for instance by using boolean masking or using the `.query` method of the dataframe; 406 | # 2. Estimate the survival curve with Kaplan-Meier on each subset; 407 | # 3. Plot the survival curve for each subset. 408 | # 409 | # What are the limits of this method? 410 | 411 | # %% 412 | import matplotlib.pyplot as plt 413 | 414 | 415 | def plot_km_curve_by_brand(df): 416 | brands = df["brand"].unique() 417 | fig_data = [] 418 | for brand in brands: 419 | # TODO: replace the following by your code here: 420 | pass 421 | 422 | plt.title("Survival curves by brand") 423 | 424 | 425 | plot_km_curve_by_brand(truck_failure_features_and_events) 426 | 427 | # %% [markdown] 428 | # **Solution**: click below to expand the cell: 429 | 430 | # %% 431 | import matplotlib.pyplot as plt 432 | 433 | 434 | def plot_km_curve_by_brand(df): 435 | brands = df["brand"].unique() 436 | fig_data = [] 437 | for brand in brands: 438 | df_brand = df.query("brand == @brand") 439 | x, y = kaplan_meier_estimator(df_brand["event"], df_brand["duration"]) 440 | plt.plot(x, y, label=brand) 441 | 442 | plt.legend() 443 | plt.ylim(-0.01, 1.01) 444 | plt.title("Survival curves by brand") 445 | 446 | plot_km_curve_by_brand(truck_failure_features_and_events) 447 | 448 | # %% [markdown] 449 | # We can observe that drivers of "Cheapz" trucks seem to experiment a higher number of failures in the early days but then the cumulative number of failures for each group seem to become comparable. 450 | # 451 | # The stratified KM method is nice to compare two groups but quickly becomes impracticable as the number of covariate groups grow. We need estimator that can handle covariates. 452 | 453 | # %% [markdown] 454 | # Let's now attempt to quantify how a survival curve estimated on a training set performs on a test set. 455 | # 456 | # ## Survival model evaluation using the Integrated Brier Score (IBS) and the Concordance Index (C-index) 457 | 458 | # %% [markdown] 459 | # The Brier score and the C-index are measures that **assess the quality of a predicted survival curve** on a finite data sample. 460 | # 461 | # - **The Brier score is a proper scoring rule**, meaning that an estimate of the survival curve has minimal Brier score if and only if it matches the true survival probabilities induced by the underlying data generating process. In that respect the **Brier score** assesses both the **calibration** and the **ranking power** of a survival probability estimator. 462 | # 463 | # - On the other hand, the **C-index** only assesses the **ranking power**: it is invariant to a monotonic transform of the survival probabilities. It only focus on the ability of a predictive survival model to identify which individual is likely to fail first out of any pair of two individuals. 464 | # 465 | # 466 | # 467 | # It is comprised between 0 and 1 (lower is better). 
468 | # It answers the question "how close to the real probabilities are our estimates?". 469 | 470 | # %% [markdown] 471 | # **Mathematical formulation** 472 | # 473 | # $$\mathrm{BS}^c(t) = \frac{1}{n} \sum_{i=1}^n I(d_i \leq t \land \delta_i = 1) 474 | # \frac{(0 - \hat{S}(t | \mathbf{x}_i))^2}{\hat{G}(d_i)} + I(d_i > t) 475 | # \frac{(1 - \hat{S}(t | \mathbf{x}_i))^2}{\hat{G}(t)}$$ 476 | # 477 | # In the survival analysis context, the Brier score can be seen as the Mean Squared Error (MSE) between our predicted probability $\hat{S}(t)$ and our target label $\delta_i \in \{0, 1\}$, weighted by the inverse probability of censoring $\frac{1}{\hat{G}(t)}$. In practice we estimate $\hat{G}(t)$ using a variant of the Kaplan-Meier estimator with the event indicator swapped. 478 | # 479 | # - When no event or censoring has happened at $t$ yet, i.e. $I(d_i > t)$, we penalize a low probability of survival with $(1 - \hat{S}(t|\mathbf{x}_i))^2$. 480 | # - Conversely, when an individual has experienced an event before $t$, i.e. $I(d_i \leq t \land \delta_i = 1)$, we penalize a high probability of survival with $(0 - \hat{S}(t|\mathbf{x}_i))^2$. 481 | # 482 | #
483 | # 484 | #
485 | # 486 | #
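# %% [markdown]
# To connect the formula above to code, here is a minimal NumPy sketch of the IPCW Brier
# score at a single time horizon (added for illustration; the helper name `brier_score_at`
# is hypothetical and this is not the exact estimator implemented in scikit-survival, which
# we will rely on below):

# %%
import numpy as np
from scipy.interpolate import interp1d
from sksurv.nonparametric import kaplan_meier_estimator


def brier_score_at(t, event, duration, surv_prob_at_t):
    """IPCW Brier score at a single horizon t (educational sketch)."""
    event = np.asarray(event, dtype=bool)
    duration = np.asarray(duration, dtype=float)
    surv_prob_at_t = np.asarray(surv_prob_at_t, dtype=float)

    # Estimate the censoring survival function G(t) with Kaplan-Meier on the
    # swapped event indicator (censoring is treated as the "event").
    g_times, g_probs = kaplan_meier_estimator(~event, duration)
    G = interp1d(
        g_times, g_probs, kind="previous",
        bounds_error=False, fill_value=(1.0, g_probs[-1]),
    )

    failed_before_t = (duration <= t) & event
    still_at_risk = duration > t
    loss = np.zeros_like(surv_prob_at_t)
    # event observed before t: penalize a high predicted survival probability
    loss[failed_before_t] = (0.0 - surv_prob_at_t[failed_before_t]) ** 2 / G(duration[failed_before_t])
    # still event-free at t: penalize a low predicted survival probability
    loss[still_at_risk] = (1.0 - surv_prob_at_t[still_at_risk]) ** 2 / G(t)
    return loss.mean()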
487 | 488 | # %% [markdown] 489 | # Let's put this in practice. We first perform a train-test split so as to fit the estimator on a training sample and compute the performance metrics on a held-out test sample. Due to restrictions of some estimators in scikit-survival, we ensure that all the test data points lie well within the time range observed in the training set. 490 | 491 | # %% 492 | from sklearn.model_selection import train_test_split 493 | 494 | 495 | def train_test_split_within(X, y, idx, **kwargs): 496 | """Ensure that test data durations are within train data durations.""" 497 | X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split(X, y, idx, **kwargs) 498 | mask_duration_inliers = y_test["duration"] < y_train["duration"].max() 499 | X_test = X_test[mask_duration_inliers] 500 | y_test = y_test[mask_duration_inliers] 501 | idx_test = idx_test[mask_duration_inliers] 502 | return X_train, X_test, y_train, y_test, idx_train, idx_test 503 | 504 | 505 | X = truck_failure_features 506 | y = truck_failure_events 507 | 508 | X_train, X_test, y_train, y_test, idx_train, idx_test = train_test_split_within( 509 | X, y, np.arange(X.shape[0]), test_size=0.75, random_state=0 510 | ) 511 | 512 | # %% [markdown] 513 | # Let's estimate the survival curve on the training set: 514 | 515 | # %% 516 | km_times_train, km_survival_curve_train = kaplan_meier_estimator( 517 | y_train["event"], y_train["duration"] 518 | ) 519 | 520 | # %% [markdown] 521 | # The `km_times_train` are ordered event (or censoring) times actually observed on the training set. To be able to compare that curve with curves computed on another time grid, we can use step-wise constant interpolation: 522 | 523 | # %% 524 | from scipy.interpolate import interp1d 525 | 526 | 527 | km_predict = interp1d( 528 | km_times_train, 529 | km_survival_curve_train, 530 | kind="previous", 531 | bounds_error=False, 532 | fill_value="extrapolate", 533 | ) 534 | 535 | 536 | def make_test_time_grid(y_test, n_steps=300): 537 | """Build a time grid bounded by the observed test event durations.""" 538 | # Some survival models can fail to predict near the boundary of the range of durations 539 | # observed on the training set, hence the trimmed grid of observed test durations below. 540 | observed_duration = y_test.query("event > 0")["duration"] 541 | 542 | # trim 1% of the span, 0.5% on each end: 543 | span = observed_duration.max() - observed_duration.min() 544 | start = observed_duration.min() + 0.005 * span 545 | stop = observed_duration.max() - 0.005 * span 546 | return np.linspace(start, stop, num=n_steps) 547 | 548 | 549 | time_grid = make_test_time_grid(y_test) 550 | 551 | # %% [markdown] 552 | # Kaplan-Meier is a constant predictor: it always estimates the mean survival curve for all individuals in the (training) dataset: the estimated survival curve does not depend on the feature values of the `X_train` or `X_test` matrices. 553 | # 554 | # To be able to compare the Kaplan-Meier estimator with conditional estimators that estimate individual survival curves for each row in `X_train` or `X_test`, we treat KM as a constant predictor that always outputs the same survival curve as many times as there are rows in `X_test`: 555 | 556 | # %% 557 | km_curve = km_predict(time_grid) 558 | y_pred_km_test = np.vstack([km_curve] * X_test.shape[0]) 559 | 560 | # %% [markdown] 561 | # We can now compute one value of the Brier score for each time horizon in the test time grid using the values in `y_test` as ground truth targets using `sksurv.metrics.brier_score`.
At this time, scikit-survival expects the `y` arguments to be passed as numpy record arrays instead of pandas dataframes: 562 | 563 | # %% 564 | from sksurv.metrics import brier_score 565 | 566 | 567 | def as_sksurv_recarray(y_frame): 568 | """Return scikit-survival's specific target format.""" 569 | y_recarray = np.empty( 570 | shape=y_frame.shape[0], 571 | dtype=[("event", np.bool_), ("duration", np.float64)], 572 | ) 573 | y_recarray["event"] = y_frame["event"] 574 | y_recarray["duration"] = y_frame["duration"] 575 | return y_recarray 576 | 577 | 578 | _, km_brier_scores = brier_score( 579 | survival_train=as_sksurv_recarray(y_train), 580 | survival_test=as_sksurv_recarray(y_test), 581 | estimate=y_pred_km_test, 582 | times=time_grid, 583 | ) 584 | 585 | # %% 586 | fig, ax = plt.subplots(figsize=(12, 5)) 587 | ax.plot(time_grid, km_brier_scores, label="KM on test data"); 588 | ax.set( 589 | title="Time-varying Brier score of Kaplan Meier estimation (lower is better)", 590 | xlabel = "time (days)", 591 | ) 592 | ax.legend(); 593 | 594 | # %% [markdown] 595 | # We observed that the "prediction error" is largest for time horizons between 200 and 1500 days after the beginning of the observation period. 596 | # 597 | # 598 | # Additionnaly, we compute the Integrated Brier Score (IBS) which we will use to summarize the Brier score curve and compare the quality of different estimators of the survival curve on the same test set: 599 | # $$IBS = \frac{1}{t_{max} - t_{min}}\int^{t_{max}}_{t_{min}} BS(t) dt$$ 600 | 601 | # %% 602 | from sksurv.metrics import integrated_brier_score 603 | 604 | km_ibs_test = integrated_brier_score( 605 | survival_train=as_sksurv_recarray(y_train), 606 | survival_test=as_sksurv_recarray(y_test), 607 | estimate=y_pred_km_test, 608 | times=time_grid, 609 | ) 610 | print(f"IBS of Kaplan-Meier estimator on test set: {km_ibs_test:.3f}") 611 | 612 | # %% [markdown] 613 | # Since the KM estimator always predicts the same constant survival curve for any samples in `X_train` or `X_test`, it's quite a limited model: it cannot rank individual by estimated median time to event for instance. Still, it's an interesting baseline because it's well calibrated among all the constant survival curve predictors. 614 | # 615 | # For instance we could compare to a model that would predict a linear decrease of the survival probability over time and measure the IBS on the same test data. The KM-survival curve is hopefully better than such a dummy predictor: 616 | 617 | # %% 618 | linear_survival_curve = np.linspace(1.0, 0.0, time_grid.shape[0]) 619 | constant_linear_survival_curves = [linear_survival_curve] * y_test.shape[0] 620 | 621 | linear_survival_ibs_test = integrated_brier_score( 622 | survival_train=as_sksurv_recarray(y_train), 623 | survival_test=as_sksurv_recarray(y_test), 624 | estimate=constant_linear_survival_curves, 625 | times=time_grid, 626 | ) 627 | print(f"IBS of linear survival estimator on test set: {linear_survival_ibs_test:.3f}") 628 | 629 | # %% [markdown] 630 | # Finally, let's also **introduce the concordance index (C-index)**. This metric evaluates the ranking (or discriminative) power of a model by comparing pairs of individuals having experienced the event. The C-index of a pair $(i, j)$ is maximized when individual $i$ has experienced the event before $j$ and the estimated risk of $i$ is higher than the one of $j$. 631 | # 632 | # This metric is also comprised between 0 and 1 (higher is better), 0.5 corresponds to a random prediction. 633 | # 634 | #
**Mathematical formulation** 635 | # 636 | # $$\mathrm{C_{index}} = \frac{\sum_{i,j} I(d_i < d_j \;\land\; \delta_i = 1 \;\land\; \mu_i > \mu_j)}{\sum_{i,j} I(d_i < d_j \;\land\; \delta_i = 1)}$$ 637 | # 638 | # Let's introduce the cumulative hazard $\Lambda(t)$, which is the negative log of the survival function $S(t)$: 639 | # 640 | # $$S(t) = \exp(-\Lambda(t)) = \exp(-\int^t_0 \lambda(u)du)$$ 641 | # 642 | # Therefore: 643 | # 644 | # $$\Lambda(t) = -\log(S(t))$$ 645 | # 646 | # Finally, the risk $\mu_i$ is obtained by integrating the cumulative hazard over the entire time range: 647 | # 648 | # $$\mu_i = \int^{t_{max}}_{t_{min}} \Lambda(t, x_i) dt = \int^{t_{max}}_{t_{min}} - \log (S(t, x_i)) dt$$ 649 | # 650 | #
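# %% [markdown]
# To make the definition above concrete, here is a naive O(n²) sketch of the C-index
# (added for illustration; the function name `c_index_sketch` is hypothetical, ties and
# IPCW corrections are ignored, and we will use `sksurv.metrics.concordance_index_censored`
# below instead):

# %%
import numpy as np


def c_index_sketch(event, duration, risk_scores):
    """Share of concordant pairs among comparable pairs (educational sketch)."""
    event = np.asarray(event, dtype=bool)
    duration = np.asarray(duration, dtype=float)
    risk_scores = np.asarray(risk_scores, dtype=float)

    n_concordant, n_comparable = 0, 0
    for i in range(duration.shape[0]):
        if not event[i]:
            continue  # a pair is only comparable if the earlier individual had an event
        later = duration > duration[i]
        n_comparable += later.sum()
        # concordant if the individual failing earlier has the higher predicted risk
        n_concordant += (later & (risk_scores < risk_scores[i])).sum()
    return n_concordant / n_comparable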
651 | 652 | # %% [markdown] 653 | # To compute the C-index of our Kaplan Meier estimates, we assign every individual with the same survival probabilities given by the Kaplan Meier. 654 | 655 | # %% 656 | from sksurv.metrics import concordance_index_censored 657 | 658 | 659 | def compute_c_index(event, duration, survival_curves): 660 | survival_curves = np.asarray(survival_curves) 661 | if survival_curves.ndim != 2: 662 | raise ValueError( 663 | "`survival_probs` must be a 2d array of " 664 | f"shape (n_samples, times), got {survival_curves.shape}" 665 | ) 666 | assert event.shape[0] == duration.shape[0], survival_curves.shape[0] 667 | 668 | # Cumulative hazard is also known as risk. 669 | cumulative_hazard = survival_to_risk_estimate(survival_curves) 670 | metrics = concordance_index_censored(event, duration, cumulative_hazard) 671 | return metrics[0] 672 | 673 | 674 | def survival_to_risk_estimate(survival_probs_matrix): 675 | return -np.log(survival_probs_matrix + 1e-8).sum(axis=1) 676 | 677 | 678 | # %% 679 | km_c_index_test = compute_c_index(y_test["event"], y_test["duration"], y_pred_km_test) 680 | km_c_index_test 681 | 682 | 683 | # %% [markdown] 684 | # This is equivalent to a random prediction. Indeed, as our Kaplan Meier is a unconditional estimator: it can't be used to rank individuals predictions as it predicts the same survival curve for any row in `X_test`. 685 | # 686 | # Before moving forward, let's define a helper function that consolidates all the evaluation code together: 687 | 688 | # %% 689 | class SurvivalAnalysisEvaluator: 690 | 691 | def __init__(self, y_train, y_test, time_grid): 692 | self.model_data = {} 693 | self.y_train = as_sksurv_recarray(y_train) 694 | self.y_test = as_sksurv_recarray(y_test) 695 | self.time_grid = time_grid 696 | 697 | def add_model(self, model_name, survival_curves): 698 | survival_curves = np.asarray(survival_curves) 699 | _, brier_scores = brier_score( 700 | survival_train=self.y_train, 701 | survival_test=self.y_test, 702 | estimate=survival_curves, 703 | times=self.time_grid, 704 | ) 705 | ibs = integrated_brier_score( 706 | survival_train=self.y_train, 707 | survival_test=self.y_test, 708 | estimate=survival_curves, 709 | times=self.time_grid, 710 | ) 711 | c_index = compute_c_index( 712 | self.y_test["event"], 713 | self.y_test["duration"], 714 | survival_curves, 715 | ) 716 | self.model_data[model_name] = { 717 | "brier_scores": brier_scores, 718 | "ibs": ibs, 719 | "c_index": c_index, 720 | "survival_curves": survival_curves, 721 | } 722 | 723 | def metrics_table(self): 724 | return pd.DataFrame([ 725 | { 726 | "Model": model_name, 727 | "IBS": info["ibs"], 728 | "C-index": info["c_index"], 729 | } 730 | for model_name, info in self.model_data.items() 731 | ]).round(decimals=4) 732 | 733 | def plot(self, model_names=None): 734 | if model_names is None: 735 | model_names = list(self.model_data.keys()) 736 | fig, ax = plt.subplots(figsize=(12, 5)) 737 | self._plot_brier_scores(model_names, ax=ax) 738 | 739 | def _plot_brier_scores(self, model_names, ax): 740 | for model_name in model_names: 741 | info = self.model_data[model_name] 742 | ax.plot( 743 | self.time_grid, 744 | info["brier_scores"], 745 | label=f"{model_name}, IBS:{info['ibs']:.3f}"); 746 | ax.set( 747 | title="Time-varying Brier score (lower is better)", 748 | xlabel="time (days)", 749 | ) 750 | ax.legend() 751 | 752 | def __call__(self, model_name, survival_curves, model_names=None): 753 | self.add_model(model_name, survival_curves) 754 | 
self.plot(model_names=model_names) 755 | return self.metrics_table() 756 | 757 | evaluator = SurvivalAnalysisEvaluator(y_train, y_test, time_grid) 758 | evaluator.add_model("Constant linear", constant_linear_survival_curves) 759 | evaluator.add_model("Kaplan-Meier", y_pred_km_test) 760 | evaluator.plot() 761 | evaluator.metrics_table() 762 | 763 | # %% [markdown] 764 | # Next, we'll study how to fit survival models that make predictions that depend on the covariates $X$. 765 | # 766 | # ## Predictive survival analysis 767 | 768 | # %% [markdown] 769 | # ### Cox Proportional Hazards 770 | # 771 | # The Cox PH model is the most popular way of dealing with covariates $X$ in survival analysis. It computes a log linear regression on the target $Y = \min(T, C)$, and consists in a baseline term $\lambda_0(t)$ and a covariate term with weights $\beta$. 772 | # $$\lambda(t, x_i) = \lambda_0(t) \exp(x_i^\top \beta)$$ 773 | # 774 | # Note that only the baseline depends on the time $t$, but we can extend Cox PH to time-dependent covariate $x_i(t)$ and time-dependent weigths $\beta(t)$. We won't cover these extensions in this tutorial. 775 | # 776 | # This methods is called ***proportional*** hazards, since for two different covariate vectors $x_i$ and $x_j$, their ratio is: 777 | # $$\frac{\lambda(t, x_i)}{\lambda(t, x_j)} = \frac{\lambda_0(t) e^{x_i^\top \beta}}{\lambda_0(t) e^{x_j^\top \beta}}=\frac{e^{x_i^\top \beta}}{e^{x_j^\top \beta}}$$ 778 | # 779 | # This ratio is not dependent on time, and therefore the hazards are proportional. 780 | # 781 | # Let's run it on our truck-driver dataset using the implementation of `sksurv`. This models requires preprocessing of the categorical features using One-Hot encoding. Let's use the scikit-learn column-transformer to combine the various components of the model as a pipeline: 782 | 783 | # %% 784 | from sklearn.pipeline import make_pipeline 785 | from sklearn.compose import make_column_transformer 786 | from sklearn.preprocessing import OneHotEncoder 787 | from sksurv.linear_model import CoxPHSurvivalAnalysis 788 | 789 | simple_preprocessor = make_column_transformer( 790 | (OneHotEncoder(), ["brand", "truck_model"]), 791 | remainder="passthrough", 792 | verbose_feature_names_out=False, 793 | ) 794 | cox_ph = make_pipeline( 795 | simple_preprocessor, 796 | CoxPHSurvivalAnalysis(alpha=1e-4) 797 | ) 798 | cox_ph.fit(X_train, as_sksurv_recarray(y_train)) 799 | 800 | # %% [markdown] 801 | # Let's compute the predicted survival functions for each row of the test set `X_test` and plot the first 5 survival functions: 802 | 803 | # %% 804 | cox_ph_survival_funcs = cox_ph.predict_survival_function(X_test) 805 | 806 | fig, ax = plt.subplots() 807 | for idx, cox_ph_survival_func in enumerate(cox_ph_survival_funcs[:5]): 808 | survival_curve = cox_ph_survival_func(time_grid) 809 | ax.plot(time_grid, survival_curve, label=idx) 810 | ax.set( 811 | title="Survival probabilities $\hat{S(t)}$ of Cox PH", 812 | xlabel="time (days)", 813 | ylabel="S(t)", 814 | ) 815 | plt.legend(); 816 | 817 | # %% [markdown] 818 | # Those estimated survival curves are predicted for the following test datapoints: 819 | 820 | # %% 821 | X_test.head(5).reset_index(drop=True) 822 | 823 | # %% [markdown] 824 | # We see that predicted survival functions can vary significantly for different test samples. 
825 | # 826 | # There are two ways to read this plot: 827 | # 828 | # First we could consider our **predictive survival analysis model as a probabilistic regressor**: if we want to **consider a specific probability of survival, say 50%**, we can mentally draw an horizontal line at 0.5, and see that: 829 | # 830 | # - test data point `#0` has an estimated median survival time around 300 days, 831 | # - test data point `#1` has an estimated median survival time around 800 days, 832 | # - test data point `#2` has an estimated median survival time around 450 days... 833 | # 834 | # Secondly we could also consider our **predictive survival analysis model as a probabilistic binary classifier**: if we **consider a specific time horizon, say 1000 days**, we can see that: 835 | # 836 | # - test data point `#0` has less than a 20% chance to remain event-free at day 1000, 837 | # - test date point `#3` has around a 50% chance to remain event-free at day 1000... 838 | # 839 | # 840 | # Let's try to get some intuition about the features importance from the first 5 truck-driver pairs and their survival probabilities. 841 | 842 | # %% [markdown] 843 | # ***Exercice*** 844 | # 845 | # Find out which features have the strongest positive or negative impact on the predictions of the model by matching the fitted coefficients $\beta$ of the model (stored under `_coef`) with their names from the `get_feature_names_out()` method of the preprocessor. 846 | # 847 | # *Hint*: You can access each step of a scikit-learn pipeline as simply as `pipeline[step_idx]`. 848 | 849 | # %% 850 | cox_ph # the full pipeline 851 | 852 | # %% 853 | cox_ph[0] # the first step of the pipeline 854 | 855 | # %% 856 | cox_ph[1] # the second step of the pipeline 857 | 858 | # %% 859 | ### Your code here 860 | 861 | 862 | 863 | 864 | feature_names = [] 865 | weights = [] 866 | 867 | 868 | 869 | ### 870 | 871 | # %% [markdown] 872 | # **Solution** 873 | 874 | # %% 875 | feature_names = cox_ph[-2].get_feature_names_out() 876 | feature_names.tolist() 877 | 878 | # %% 879 | weights = cox_ph[-1].coef_ 880 | weights 881 | 882 | # %% 883 | features = ( 884 | pd.DataFrame( 885 | dict( 886 | feature_name=feature_names, 887 | weight=weights, 888 | ) 889 | ) 890 | .sort_values("weight") 891 | ) 892 | ax = sns.barplot(features, y="feature_name", x="weight", orient="h") 893 | ax.set_title("Cox PH feature importance of $\lambda(t)$"); 894 | 895 | # %% [markdown] 896 | # Finally, we compute the Brier score for our model. 897 | 898 | # %% 899 | cox_survival_curves = np.vstack( 900 | [ 901 | cox_ph_survival_func(time_grid) 902 | for cox_ph_survival_func in cox_ph_survival_funcs 903 | ] 904 | ) 905 | evaluator("Cox PH", cox_survival_curves) 906 | 907 | # %% [markdown] 908 | # So the Cox Proportional Hazard model from scikit-survival fitted as a simple pipeline with one-hot encoded categorical variables and raw numerical variables seems already significantly better than our unconditional baseline. 909 | 910 | # %% [markdown] 911 | # **Exercise** 912 | # 913 | # Let's define a more expressive polynomial feature engineering pipeline for a Cox PH model that: 914 | # 915 | # - encodes categorical variables using the `OneHotEncoder` as previously; 916 | # - transforms numerical features with `SplineTransformer()` (using the default parameters); 917 | # - transforms the resulting of the encoded categorical variables and spline-transformed numerical variables using a degree 2 polynomial kernel approximation using the Nystroem method (e.g. 
`Nystroem(kernel="poly", degree=2, n_components=300)`) 918 | 919 | # %% 920 | X.columns 921 | 922 | # %% 923 | from sklearn.preprocessing import SplineTransformer 924 | from sklearn.kernel_approximation import Nystroem 925 | 926 | # TODO: write your pipeline here. 927 | 928 | 929 | # step 1: define a column transformer to: 930 | # - one-hot encode categorical columns 931 | # - spline-transform numerical features 932 | 933 | # step 2: define a Nystroem approximate degree 2 polynomial feature expansion 934 | 935 | # step 3: assemble everything in a pipeline with a CoxPHSurvivalAnalysis 936 | # model at the end. 937 | 938 | # step 4: fit the pipeline on the training set. 939 | 940 | # step 5: predict the survival functions for each row of the test set. 941 | 942 | # step 6: compute the values of the survival function on the usual `time_grid` 943 | # and store the result in an array named `poly_cox_ph_survival_curves`. 944 | 945 | # Uncomment the following to evaluate your pipeline: 946 | 947 | # evaluator("Polynomial Cox PH", poly_cox_ph_survival_curves) 948 | 949 | # %% 950 | ### Solution: 951 | 952 | from sklearn.preprocessing import SplineTransformer 953 | from sklearn.kernel_approximation import Nystroem 954 | 955 | 956 | spline_preprocessor = make_column_transformer( 957 | (OneHotEncoder(), ["brand", "truck_model"]), 958 | (SplineTransformer(), ["driver_skill", "usage_rate"]), 959 | verbose_feature_names_out=False, 960 | ) 961 | poly_cox_ph = make_pipeline( 962 | spline_preprocessor, 963 | Nystroem(kernel="poly", degree=2, n_components=300), 964 | CoxPHSurvivalAnalysis(alpha=1e-2) 965 | ) 966 | poly_cox_ph.fit(X_train, as_sksurv_recarray(y_train)) 967 | poly_cox_ph_survival_funcs = poly_cox_ph.predict_survival_function(X_test) 968 | 969 | 970 | poly_cox_ph_survival_curves = np.vstack( 971 | [ 972 | poly_cox_ph_survival_func(time_grid) 973 | for poly_cox_ph_survival_func in poly_cox_ph_survival_funcs 974 | ] 975 | ) 976 | evaluator("Polynomial Cox PH", poly_cox_ph_survival_curves) 977 | 978 | # %% [markdown] 979 | # ### Random Survival Forest 980 | # 981 | # Random Survival Forests are non-parametric model that is potentially more expressive than Cox PH. In particular, if we expect that the shape of the time varying hazards are not the same for each individual, tree-based models such as RSF might perform better. In general they also require a large enough training set to avoid overfitting. 982 | # 983 | # Note however that they are quite computational intensive and their training time can be prohibitive on very large dataset. 984 | 985 | # %% 986 | from sksurv.ensemble import RandomSurvivalForest 987 | 988 | rsf = make_pipeline( 989 | simple_preprocessor, 990 | RandomSurvivalForest(n_estimators=10, max_depth=8, n_jobs=-1), 991 | ) 992 | rsf.fit(X_train, as_sksurv_recarray(y_train)) 993 | 994 | # %% 995 | rsf_survival_funcs = rsf.predict_survival_function(X_test) 996 | 997 | fig, ax = plt.subplots() 998 | for idx, survival_func in enumerate(rsf_survival_funcs[:5]): 999 | survival_curve = survival_func(time_grid) 1000 | ax.plot(time_grid, survival_curve, label=idx) 1001 | ax.set( 1002 | title="Survival probabilities $\hat{S}(t)$ of Random Survival Forest", 1003 | xlabel="time (days)", 1004 | ylabel="S(t)", 1005 | ) 1006 | plt.legend(); 1007 | 1008 | # %% [markdown] 1009 | # Indeed we observe that the shapes of the curves can vary more than for the Cox-PH model which is more constrained. 
Let's see if this flexibility makes it a better predictive model on aggregate on the test set: 1010 | 1011 | # %% 1012 | rsf_survival_curves = np.vstack( 1013 | [func(time_grid) for func in rsf_survival_funcs] 1014 | ) 1015 | evaluator("Random Survival Forest", rsf_survival_curves) 1016 | 1017 | # %% [markdown] 1018 | # Unfortunately this does not seem to be able to significantly improve upon the Cox PH model as a ranking model. 1019 | 1020 | # %% [markdown] 1021 | # ### GradientBoostedCIF 1022 | # 1023 | # 1024 | # We now introduce a novel survival estimator named Gradient Boosting CIF. This estimator is based on the `HistGradientBoostingClassifier` of scikit-learn under the hood. It is named `CIF` because it has the capability to estimate cause-specific Cumulative Incidence Functions in a competing risks setting by minimizing a cause specific IBS objective function. 1025 | # 1026 | # Here we first introduce it as a conditional estimator of the any-event survival function by omitting the `event_of_interest` constructor parameter. 1027 | 1028 | # %% 1029 | import sys; sys.path.append("..") 1030 | from models.gradient_boosted_cif import GradientBoostedCIF 1031 | from model_selection.wrappers import PipelineWrapper 1032 | 1033 | 1034 | gb_cif = make_pipeline( 1035 | simple_preprocessor, 1036 | GradientBoostedCIF(n_iter=100, max_leaf_nodes=5, learning_rate=0.1), 1037 | ) 1038 | gb_cif = PipelineWrapper(gb_cif) 1039 | gb_cif.fit(X_train, y_train, time_grid) 1040 | 1041 | # %% 1042 | gb_cif_survival_curves = gb_cif.predict_survival_function(X_test, time_grid) 1043 | evaluator("Gradient Boosting CIF", gb_cif_survival_curves) 1044 | 1045 | # %% 1046 | X_train.shape 1047 | 1048 | # %% [markdown] 1049 | # This model is often better than Random Survival Forest but significantly faster to train and requires few feature engineering than a Cox PH model. 1050 | 1051 | # %% [markdown] 1052 | # Let's try to improve the performance of the models that train fast on a larger dataset. As the `truck_failure_100k` dataset is a superset of the `truck_failure_10k` dataset, we reuse the sample test sampples to simplify model evaluation: 1053 | 1054 | # %% 1055 | truck_failure_100k_any_event = pd.read_parquet("truck_failure_100k_any_event.parquet") 1056 | truck_failure_100k_features = pd.read_parquet("truck_failure_100k_features.parquet") 1057 | 1058 | train_large_mask = np.full(shape=truck_failure_100k_any_event.shape[0], fill_value=True) 1059 | train_large_mask[idx_test] = False 1060 | X_train_large = truck_failure_100k_features[train_large_mask] 1061 | y_train_large = truck_failure_100k_any_event[train_large_mask] 1062 | 1063 | large_model_evaluator = SurvivalAnalysisEvaluator(y_train_large, y_test, time_grid) 1064 | 1065 | # %% 1066 | X_train_large.shape 1067 | 1068 | # %% [markdown] 1069 | # **Warning**: fitting polynomial Cox PH on the larger training set takes several minutes on a modern laptop. Feel free to skip. 
1070 | 1071 | # %% 1072 | poly_cox_ph_large_survival_curves = None 1073 | 1074 | # %% 1075 | # %%time 1076 | poly_cox_ph_large = make_pipeline( 1077 | spline_preprocessor, 1078 | Nystroem(kernel="poly", degree=2, n_components=300), 1079 | CoxPHSurvivalAnalysis(alpha=1e-4) 1080 | ) 1081 | poly_cox_ph_large.fit(X_train_large, as_sksurv_recarray(y_train_large)) 1082 | poly_cox_ph_large_survival_funcs = poly_cox_ph_large.predict_survival_function(X_test) 1083 | 1084 | 1085 | poly_cox_ph_large_survival_curves = np.vstack( 1086 | [ 1087 | f(time_grid) for f in poly_cox_ph_large_survival_funcs 1088 | ] 1089 | ) 1090 | large_model_evaluator("Polynomial Cox PH (larger training set)", poly_cox_ph_large_survival_curves) 1091 | 1092 | # %% [markdown] 1093 | # Fitting `GradientBoostedCIF` on the larger dataset should take a fraction of a minute on a modern laptop: 1094 | 1095 | # %% 1096 | # %%time 1097 | gb_cif_large = make_pipeline( 1098 | simple_preprocessor, 1099 | GradientBoostedCIF(max_leaf_nodes=31, learning_rate=0.1, n_iter=100), 1100 | ) 1101 | gb_cif_large = PipelineWrapper(gb_cif_large) 1102 | gb_cif_large.fit(X_train_large, y_train_large, time_grid) 1103 | gb_cif_large_survival_curves = gb_cif_large.predict_survival_function(X_test, time_grid) 1104 | 1105 | large_model_evaluator("Gradient Boosting CIF (larger training set)", gb_cif_large_survival_curves) 1106 | 1107 | # %% [markdown] 1108 | # ### Comparing our estimates to the theoretical survival curves 1109 | # 1110 | # Since the dataset is synthetic, we can access the underlying hazard function for each row of `X_test`: 1111 | 1112 | # %% 1113 | with np.load("truck_failure_10k_hazards.npz") as f: 1114 | theoretical_hazards = f["truck_failure_10k_hazards"] 1115 | theoretical_hazards.shape 1116 | 1117 | # %% [markdown] 1118 | # The first axis correspond to the 3 types of failures of this dataset (that will be covered in the next section). 
For now let's collapse them all together and consider the "any event" hazard functions: 1119 | 1120 | # %% 1121 | any_event_hazards = theoretical_hazards.sum(axis=0) 1122 | any_event_hazards.shape 1123 | 1124 | # %% [markdown] 1125 | # We can then extract the test records: 1126 | 1127 | # %% 1128 | any_event_hazards_test = any_event_hazards[idx_test] 1129 | any_event_hazards_test.shape 1130 | 1131 | # %% [markdown] 1132 | # and finally, do a numerical integration over the last dimension (using `cumsum(axis=-1)`) and take the exponential of the negative cumulated hazards to recover the theoretical survival curves for each sample of the test set: 1133 | 1134 | # %% 1135 | theoretical_survival_curves = np.exp(-any_event_hazards_test.cumsum(axis=-1)) 1136 | theoretical_survival_curves.shape 1137 | 1138 | # %% [markdown] 1139 | # Finally, we can evaluate the performance metrics (IBS and C-index) of the theoretical curves on the same test events and `time_grid` to be able to see how far our best predictive survival analysis models are from the optimal model: 1140 | 1141 | # %% 1142 | time_grid.shape 1143 | 1144 | # %% 1145 | n_total_days = any_event_hazards.shape[-1] 1146 | original_time_range = np.linspace(0, n_total_days, n_total_days) 1147 | 1148 | theoretical_survival_curves = np.asarray([ 1149 | interp1d( 1150 | original_time_range, 1151 | surv_curve, 1152 | kind="previous", 1153 | bounds_error=False, 1154 | fill_value="extrapolate", 1155 | )(time_grid) for surv_curve in theoretical_survival_curves 1156 | ]) 1157 | 1158 | evaluator("Data generating process", theoretical_survival_curves) 1159 | 1160 | # %% [markdown] 1161 | # The fact that the C-index of the Polynomial Cox PH model can sometimes be larger than the C-index of the theoretical curves is quite unexpected and would deserve further investigation. It could be an artifact of our evaluation on a finite-size test set and the use of partially censored test data. 1162 | # 1163 | # Let's also compare with the version of the model trained on the large dataset: 1164 | 1165 | # %% 1166 | large_model_evaluator("Data generating process", theoretical_survival_curves) 1167 | 1168 | # %% [markdown] 1169 | # We observe that our best models are quite close to the theoretical optimum but there is still some slight margin for improvement. It's possible that re-training the same model pipelines with an even larger number of training data points or a better choice of hyperparameters and feature preprocessing could help close that gap. 1170 | # 1171 | # Note that the IBS and C-index values of the theoretical survival curves are far from 0.0 and 1.0 respectively: this is expected because not all the variations of the target `y` can be explained by the values of the columns of `X`: there is still a large irreducible amount of unpredictable variability (a.k.a. "noise") in this data generating process.
1172 | 1173 | # %% [markdown] 1174 | # Since our estimators are conditional models, it's also interesting to compare the predicted survival curves for a few test samples and contrasting those to the 1175 | # theoretical survival curves: 1176 | 1177 | # %% 1178 | fig, axes = plt.subplots(nrows=5, figsize=(12, 22)) 1179 | 1180 | for sample_idx, ax in enumerate(axes): 1181 | ax.plot(time_grid, gb_cif_large_survival_curves[sample_idx], label="Gradient Boosting CIF") 1182 | if poly_cox_ph_large_survival_curves is not None: 1183 | ax.plot(time_grid, poly_cox_ph_large_survival_curves[sample_idx], label="Polynomial Cox PH") 1184 | ax.plot(time_grid, theoretical_survival_curves[sample_idx], linestyle="--", label="True survival curve") 1185 | ax.plot(time_grid, 0.5 * np.ones_like(time_grid), linestyle="--", color="black", label="50% probability") 1186 | ax.set( 1187 | title=f"Survival curve for truck #{sample_idx}", 1188 | ylim=[-.01, 1.01], 1189 | ) 1190 | ax.legend() 1191 | 1192 | # %% [markdown] 1193 | # The individual survival functions predicted by the polynomial Cox PH model are always smooth but we can observe that they do now always match the shape of the true survival curve on some test datapoints. 1194 | # 1195 | # We can also observe that the individual survival curves of the Gradient Boosting CIF model **suffer from the constant-piecewise prediction function of the underlying decision trees**. But despite this limitation, this model still yields very good approximation to the true survival curves. In particular **they can provide competitive estimates of the median survival time** for instance. 1196 | # 1197 | # Let's check this final asserion by comparing the Mean absolute error for the median survival time estimates for our various estimators. Note that we can only do this because our data is synthetic and we have access to the true median survival time derived from the data generating process. 1198 | 1199 | # %% 1200 | from sklearn.metrics import mean_absolute_error 1201 | 1202 | def quantile_survival_times(times, survival_curves, q=0.5): 1203 | increasing_survival_curves = survival_curves[:, ::-1] 1204 | median_indices = np.apply_along_axis( 1205 | lambda a: a.searchsorted(q), axis=1, arr=increasing_survival_curves 1206 | ) 1207 | return times[-median_indices] 1208 | 1209 | 1210 | def compute_quantile_metrics(evaluator): 1211 | all_metrics = [] 1212 | for model_name, info in evaluator.model_data.items(): 1213 | survival_curves = info["survival_curves"] 1214 | record = {"Model": model_name} 1215 | for q in [0.25, 0.5, 0.75]: 1216 | mae = mean_absolute_error( 1217 | quantile_survival_times(time_grid, survival_curves, q=q), 1218 | quantile_survival_times(time_grid, theoretical_survival_curves, q=q), 1219 | ) 1220 | record[f"MAE for q={np.round(q, 2)}"] = mae.round(1) 1221 | all_metrics.append(record) 1222 | return pd.merge(evaluator.metrics_table(), pd.DataFrame(all_metrics)) 1223 | 1224 | 1225 | compute_quantile_metrics(evaluator) 1226 | 1227 | # %% 1228 | compute_quantile_metrics(large_model_evaluator) 1229 | 1230 | # %% [markdown] 1231 | # This confirms that the best estimators ranked by IBS computed on a censored test sample are the most accurately modeling the uncensored time-to-event distribution. 
1232 | # 1233 | # Furthermore, we observe that a small gain in IBS can have a significant impact in terms of MAE, and Gradient Boosting CIF can reduce its prediction error significantly by increasing the size of the training set and simultaneously increasing the number of leaf nodes per tree. 1234 | # 1235 | # This is not the case for the Polynomial Cox PH model, which seems to be intrinsically limited by its core modeling assumption: the shape of the Cox PH hazard function depends on $t$ but is independent of $X$. If this assumption does not hold, no amount of additional training data will help the estimator reach the optimal IBS. 1236 | 1237 | # %% [markdown] 1238 | # ## Unconditional competing risks modeling with Aalen-Johansen 1239 | # 1240 | # So far, we've been dealing with a single kind of risk: any accident. **What if we have different, mutually exclusive types of failure?** 1241 | # 1242 | # This is the point of **competing risks modeling**. It aims at modeling the probability of incidence for different events, where these probabilities interfere with each other. Here we consider that a truck that had an accident is withdrawn from the fleet and therefore can't experience any other event. 1243 | # 1244 | 1245 | # %% [markdown] 1246 | # Let's load our dataset once more. Notice that we have 3 types of events (plus 0 for censoring): 1247 | 1248 | # %% 1249 | truck_failure_competing_events = pd.read_parquet("truck_failure_10k_competing_risks.parquet") 1250 | truck_failure_competing_events 1251 | 1252 | # %% [markdown] 1253 | # In this refined variant of the truck failure event data, the event identifiers mean the following: 1254 | # 1255 | # - 1: manufacturing defect: a failure of a truck that happens as a result of mistakes in the assembly of the components (e.g. loose bolts); 1256 | # - 2: operational failures, e.g. a driving accident; 1257 | # - 3: fatigue-induced failures, e.g. an engine breaks after heavy use over a prolonged period of time, despite good assembly and regular maintenance. 1258 | # 1259 | # 0 is still the censoring marker. 1260 | 1261 | # %% [markdown] 1262 | # Instead of estimating a survival function (probability of remaining event-free over time), a competing risks analysis model attempts to estimate a **cause-specific cumulative incidence function ($CIF_k$)**: 1263 | # 1264 | # For any event $k \in [1, K]$, the cumulative incidence function of the event $k$ becomes: 1265 | # 1266 | # $$CIF_k(t) = P(T < t, \mathrm{event}=k)$$ 1267 | # 1268 | # In the unconditional case, the estimator ignores any side information in $X$ and only models $CIF_k(t)$ from information in $y$: event types and their respective durations (often with censoring). 1269 | # 1270 | # **Aalen-Johansen estimates the CIF for each event type $k$**, by: 1271 | # - estimating the cause-specific hazards on one hand; 1272 | # - estimating the global (any event) survival probabilities using Kaplan-Meier on the other hand. 1273 | # 1274 | # The two estimates are then combined to produce an estimate of the cause-specific cumulative incidence. 1275 | # 1276 | #
#
# We first compute the cause-specific hazards $\lambda_k$, by simply counting, for each individual duration $t_i$, the number of individuals that have experienced the event $k$ at $t_i$ ($d_{k,i}$), and the number of individuals still at risk at $t_i$ ($n_i$).
#
# $$
# \hat{\lambda}_k(t_i)=\frac{d_{k,i}}{n_i}
# $$
#
# Then, we compute the any-event survival probability with the Kaplan-Meier estimator, where we can reuse the cause-specific hazards.
#
# $$
# \hat{S}(t)=\prod_{i:t_i\leq t} (1 - \frac{d_i}{n_i})=\prod_{i:t_i\leq t} (1 - \sum_k\hat{\lambda}_{k}(t_i))
# $$
#
# Finally, we compute the CIF of event $k$ as the sum of the cause-specific hazards, weighted by the any-event survival probability at the previous time point.
#
# $$\hat{F}_k(t)=\sum_{i:t_i\leq t} \hat{\lambda}_k(t_i) \hat{S}(t_{i-1})$$
#
#
#
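#
# To make these formulas concrete, here is a minimal NumPy sketch of this estimator. It is
# not part of the tutorial's library code: the helper name `aalen_johansen_cif` and its
# arguments are only illustrative, and lifelines' `AalenJohansenFitter` (used in the next
# cell) remains the reference implementation.

# %%
def aalen_johansen_cif(durations, events, event_of_interest):
    """Estimate CIF_k on the observed event times (0 encodes censoring)."""
    durations = np.asarray(durations)
    events = np.asarray(events)
    order = np.argsort(durations)
    durations, events = durations[order], events[order]

    unique_times, first_idx = np.unique(durations, return_index=True)
    n_at_risk = durations.shape[0] - first_idx  # n_i: individuals with duration >= t_i
    bins = np.searchsorted(unique_times, durations)

    d_any = np.zeros(unique_times.shape[0])  # d_i: events of any type at t_i
    d_k = np.zeros(unique_times.shape[0])    # d_{k,i}: events of type k at t_i
    np.add.at(d_any, bins, (events > 0).astype(float))
    np.add.at(d_k, bins, (events == event_of_interest).astype(float))

    hazard_k = d_k / n_at_risk                              # cause-specific hazard
    km_survival = np.cumprod(1.0 - d_any / n_at_risk)       # any-event Kaplan-Meier
    km_survival_prev = np.concatenate([[1.0], km_survival[:-1]])  # S(t_{i-1})
    return unique_times, np.cumsum(hazard_k * km_survival_prev)


times_aj, cif_1 = aalen_johansen_cif(
    truck_failure_competing_events["duration"],
    truck_failure_competing_events["event"],
    event_of_interest=1,
)
cif_1[-1]  # estimated cumulative incidence of event 1 at the last observed time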

# %% [markdown]
# Let's use lifelines to estimate the $CIF_k$ using Aalen-Johansen. We need to indicate which event to fit on, so we iteratively fit one model per event type.

# %%
from lifelines import AalenJohansenFitter

fig, ax = plt.subplots()

total_cif = None
competing_risk_ids = sorted(
    truck_failure_competing_events.query("event > 0")["event"].unique()
)
for event in competing_risk_ids:
    print(f"Fitting Aalen-Johansen for event {event}...")
    ajf = AalenJohansenFitter(calculate_variance=True)
    ajf.fit(
        truck_failure_competing_events["duration"],
        truck_failure_competing_events["event"],
        event_of_interest=event
    )
    ajf.plot(ax=ax, label=f"event {event}")
    cif_df = ajf.cumulative_density_
    cif_times = cif_df.index
    # Accumulate the sum of the cause-specific CIFs to check it stays below 1.
    if total_cif is None:
        total_cif = cif_df[cif_df.columns[0]].values
    else:
        total_cif += cif_df[cif_df.columns[0]].values

ax.plot(cif_times, total_cif, label="total", linestyle="--", color="black")
ax.set(
    title="Mean CIFs estimated by Aalen-Johansen",
    xlabel="time (days)",
    xlim=(-30, 2030),
    ylim=(-0.05, 1.05),
)
plt.legend();

# %% [markdown]
# This unconditional model helps us identify 3 types of events whose incidence ramps up at different times. As expected:
#
# - the incidence of type 1 events (failures caused by manufacturing defects) increases quickly from the start and then quickly plateaus: after 1000 days, the estimator expects almost no new type 1 event to occur: all trucks with a manufacturing defect should have failed by that time.
# - the incidence of type 2 events (failures caused by wrong operation of the truck) accumulates steadily throughout the observation period.
# - the incidence of type 3 events (fatigue-induced failures) is almost null until 500 days and then slowly increases.
#
# Note that once a truck has failed from one kind of event (e.g. a manufacturing defect), it is taken out of the pool of trucks under study and can therefore no longer experience any other kind of failure: there are therefore many operational or fatigue-induced failures that do not happen because some trucks have previously failed from a competing failure type.
#
# Finally, we can observe that, as time progresses, operational and, even more importantly, fatigue-induced failures are expected to make all the trucks in the study fail. Therefore, the sum of the cumulative incidence functions is expected to reach 100% in our study in the large time limit.
#
# However, since all events beyond 2000 days are censored in our dataset, the Aalen-Johansen estimator produces truncated incidence curves: it does not attempt to extrapolate beyond the maximum event time observed in the data.

# %% [markdown]
# ## Predictive competing risks analysis using our GradientBoostedCIF
#
# In contrast with predictive survival analysis, the open source ecosystem is not very mature for predictive competing risks analysis.
#
# However, `GradientBoostedCIF` was designed to optimize the cause-specific version of the IBS: for each type of event $k$, we fit one `GradientBoostedCIF` instance specialized for this kind of event by passing `event_of_interest=k` to the constructor.

# %%
y_train_cr = truck_failure_competing_events.loc[idx_train]
y_test_cr = truck_failure_competing_events.loc[idx_test]

cif_models = {}
for k in competing_risk_ids:
    gb_cif_k = make_pipeline(
        simple_preprocessor,
        GradientBoostedCIF(
            event_of_interest=k, max_leaf_nodes=15, n_iter=50, learning_rate=0.05
        ),
    )
    gb_cif_k = PipelineWrapper(gb_cif_k)
    gb_cif_k.fit(X_train, y_train_cr, time_grid)
    cif_models[k] = gb_cif_k

# %% [markdown]
# Alternatively, we can fit larger models on the largest version of the same dataset (this should take less than a minute):

# %%
# truck_failure_competing_events_large = pd.read_parquet("truck_failure_100k_competing_risks.parquet")
# y_train_cr_large = truck_failure_competing_events_large.loc[train_large_mask]

# cif_models = {}
# for k in competing_risk_ids:
#     gb_cif_k = make_pipeline(
#         simple_preprocessor,
#         GradientBoostedCIF(
#             event_of_interest=k, max_leaf_nodes=31, n_iter=100, learning_rate=0.1
#         ),
#     )
#     gb_cif_k = PipelineWrapper(gb_cif_k)
#     gb_cif_k.fit(X_train_large, y_train_cr_large, time_grid)
#     cif_models[k] = gb_cif_k

# %% [markdown]
# Once fit, we can use this family of models to make individual CIF predictions for each kind of event. Plotting the average CIF across individuals in the test set should recover curves similar to the Aalen-Johansen estimates (in the limit of large training and test data):

# %%
fig, ax = plt.subplots()
total_mean_cif = np.zeros(time_grid.shape[0])

gb_cif_cumulative_incidence_curves = {}
for k in competing_risk_ids:
    cif_curves_k = cif_models[k].predict_cumulative_incidence(X_test, time_grid)
    gb_cif_cumulative_incidence_curves[k] = cif_curves_k
    mean_cif_curve_k = cif_curves_k.mean(axis=0)  # average over test points
    ax.plot(time_grid, mean_cif_curve_k, label=f"event {k}")
    total_mean_cif += mean_cif_curve_k

ax.plot(time_grid, total_mean_cif, label="total", linestyle="--", color="black")
ax.set(
    title="Mean CIFs estimated by GradientBoostedCIF",
    xlabel="time in days",
    ylabel="Cumulative Incidence",
    xlim=(-30, 2030),
    ylim=(-0.05, 1.05),
)
plt.legend();

# %% [markdown]
# The average cause-specific cumulative incidence curves seem to mostly agree with the Aalen-Johansen estimates. One can observe some problematic discrepancies though (depending on the choice of the hyper-parameters):
#
# - the cumulative incidence curves of events 1 and 2 do not start at 0 as expected;
# - the cumulative incidence of event 1 seems to continue growing beyond day 1500, which is not expected either.
#
# On aggregate, we can therefore expect the total incidence to be overestimated at the edges of the time range.
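#
# A quick numeric check of those two observations (a small sketch that simply reuses the
# `gb_cif_cumulative_incidence_curves` arrays computed in the previous cell):

# %%
for k in competing_risk_ids:
    curves_k = gb_cif_cumulative_incidence_curves[k]
    print(
        f"event {k}: mean CIF at t={time_grid[0]:.0f} days: {curves_k[:, 0].mean():.3f}"
        f" / at t={time_grid[-1]:.0f} days: {curves_k[:, -1].mean():.3f}"
    )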

# %% [markdown]
# Let's also reuse the any-event survival estimates to check that:
#
# $$\hat{S}(t) \approx 1 - \sum_k \hat{CIF}_k(t)$$
#

# %%
fig, ax = plt.subplots()
mean_survival_curve = gb_cif_large_survival_curves.mean(axis=0)
ax.plot(time_grid, total_mean_cif, label="Total CIF")
ax.plot(time_grid, mean_survival_curve, label="Any-event survival")
ax.plot(
    time_grid,
    total_mean_cif + mean_survival_curve,
    color="black",
    linestyle="--",
    label="Survival + total CIF",
)
ax.legend();

# %% [markdown]
# So we see that our Gradient Boosting CIF estimator seems to be mostly unbiased, as the sum of the mean CIF curves and the mean any-event survival curve randomly fluctuates around 1.0. A more careful study would be required to see how the mean and the variance of this sum evolve when changing the size of the training set, the amount of censoring and the hyper-parameters of the estimator.
#
# In particular, it's possible that it would be beneficial to tune the hyper-parameters of each cause-specific model independently, using a validation set and early stopping.
#
# Note: we could also attempt to constrain the total CIF and survival estimates to always sum to 1 by design, but this would make it challenging (impossible?) to also constrain the model to yield monotonically increasing CIF curves, as implemented in `GradientBoostedCIF`. This is left as future work.

# %% [markdown]
# ### Model evaluation with the cause-specific Brier score
#
# At this time, neither scikit-survival nor lifelines provides an implementation of the time-dependent Brier score adjusted for censoring in a competing risks setting.
#
# Let's compute the theoretical cumulative incidence from the true hazards of the data generating process. We start from our theoretical hazards:

# %%
theoretical_hazards.shape

# %% [markdown]
# To derive the theoretical cumulative incidence curves, we need to estimate the true survival functions from the any-event hazards. Then we integrate over time the product of the cause-specific hazards with the any-event survival function to derive the cause-specific cumulative incidence.
#
# We also need to interpolate them to our evaluation time grid.
#
# Finally, we compute the IBS both for the cumulative incidence curves predicted by our `GradientBoostedCIF` models and for the curves derived from the true hazards.
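#
# In formulas, writing $\lambda^*_k$ for the true cause-specific hazards of the data generating process
# (the code below approximates both integrals by cumulative sums over the original time grid):
#
# $$S^*(t) = \exp\left(-\int_0^t \sum_k \lambda^*_k(u) \, du\right) \qquad CIF^*_k(t) = \int_0^t \lambda^*_k(u) \, S^*(u) \, du$$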

# %%
from models.gradient_boosted_cif import cif_integrated_brier_score


any_event_hazards = theoretical_hazards.sum(axis=0)
true_survival = np.exp(-any_event_hazards.cumsum(axis=-1))

for k in competing_risk_ids:
    # Compute the integrated Brier score of the CIF curves predicted by our models:
    gb_cif_ibs_k = cif_integrated_brier_score(
        y_train_cr,
        y_test_cr,
        gb_cif_cumulative_incidence_curves[k],
        time_grid,
        event_of_interest=k,
    )
    # Evaluate the interpolated cumulative incidence curve on the same
    # test set:
    theoretical_cumulated_incidence_curves_k = np.asarray([
        interp1d(
            original_time_range,
            ci_curve,
            kind="previous",
            bounds_error=False,
            fill_value="extrapolate",
        )(time_grid)
        for ci_curve in (theoretical_hazards[k - 1] * true_survival).cumsum(axis=-1)[idx_test]
    ])
    theoretical_cif_ibs_k = cif_integrated_brier_score(
        y_train_cr,
        y_test_cr,
        theoretical_cumulated_incidence_curves_k,
        time_grid,
        event_of_interest=k,
    )
    print(
        f"[event {k}] IBS for GB CIF: {gb_cif_ibs_k:.4f}, "
        f"IBS for True CIF: {theoretical_cif_ibs_k:.4f}"
    )

# %% [markdown]
# By looking at the cause-specific IBS values, it seems that our model is already quite close to optimal. Again, it's likely that this could be further improved by increasing the training set size and tuning the hyper-parameters.

# %% [markdown]
# ### Model inspection with Partial Dependence Plots (PDP)
#
# Partial dependence plots make it possible to visualize the impact of varying individual numerical features on the model predictions.

# %%
cif_models[1].estimator[:-1].get_feature_names_out()

# %%
from sklearn.inspection import PartialDependenceDisplay

for k in competing_risk_ids:
    preprocessor_k = cif_models[k].estimator[:-1]
    classifier_k = cif_models[k].estimator[-1]
    classifier_k.set_params(time_horizon=1500, show_progressbar=False)

    disp = PartialDependenceDisplay.from_estimator(
        classifier_k,
        preprocessor_k.transform(X_test),
        response_method="predict_proba",
        features=["driver_skill", "usage_rate"],
        feature_names=preprocessor_k.get_feature_names_out(),
    )
    disp.bounding_ax_.set(title=f"Partial dependence for CIF_{k}")
    for ax in disp.axes_.ravel():
        ax.set(ylim=(0, 1))

# %% [markdown]
# **Analysis**
#
# We observe the following:
#
# - a `usage_rate` increase seems to have a small yet positive effect on the incidence of events of type 1 and 2;
# - a `driver_skill` increase seems to cause a dramatic reduction in the incidence of type 2 events and a small positive effect on the incidence of type 1 and 3 events.
#
# This last point is a bit counter-intuitive. Since we have access to the data generating process, we can check that the `driver_skill` variable has no impact on the type 1 and 3 hazards (manufacturing defects and fatigue-induced failures). So the effect we observe on our models must come from the competition between events: trucks with skilled drivers have a higher relative incidence of type 1 and 3 events because they experience a dramatic reduction in type 2 events (operational failures), and hence have more opportunity to encounter the other kinds of failures. This highlights that competing risks make it even more challenging to interpret the results of such model inspections.
#
# Note that, at the time of writing, scikit-learn does not yet provide a user-friendly way to study the impact of categorical variables (this will hopefully improve in version 1.3 and later). In particular, it would be interesting to see if the estimator for type 2 events could successfully model the quadratic interaction between driver skill and truck UX (via the model variable).
#
# Finally: **we cannot conclude on causal effects from the PDP of a classifier alone**. Indeed, interventions on estimator inputs and the measured effects on estimator predictions do not necessarily reflect the true causal effects we would observe had we the opportunity to intervene in the real world. **We would need to make additional assumptions**, such as the structure of a causal graph relating the treated input, the other input covariates and the outcome variable, and the absence of hidden confounders. Furthermore, we would also need to assume independence of the treatment assignment and the other covariates. If this does not hold, we should use more suitable methods such as [doubly robust causal inference](https://matheusfacure.github.io/python-causality-handbook/12-Doubly-Robust-Estimation.html) to adjust for any such dependence.

# %% [markdown]
# ***Exercise***
#
# For each type of event and each datapoint in the test set, predict the time-to-event for a quantile of your choice.
#
# Then compare to the corresponding quantile of the time-to-event derived from the theoretical cumulative incidence curves.
#
# *Hint*: do you expect some quantile choices to yield undefined values? If so, how sensitive is each type of event to the choice of quantile?
#
# *Hint*: you can reuse the model trained as `cif_models[k]` for the event of type `k`. Each of them exposes a handy `.predict_quantile(X_test, quantile=q)` method if you wish.

# %%
# compute the quantile time-to-event predicted by the model

# TODO

# %%
# measure the quantile time-to-event on uncensored data


# TODO

# %%
# compare the predictions with the expected values, for instance using the mean absolute error metric


# TODO



# %% [markdown]
# ## Going further
#
# We encourage you to dive deeper into the documentation of the [lifelines](https://lifelines.readthedocs.io) and [scikit-survival](https://scikit-survival.readthedocs.io/) packages.
#
# You might be interested in the following notable alternatives not presented in this notebook:
#
# - XGBoost has built-in support for [survival analysis with censored data](https://xgboost.readthedocs.io/en/stable/tutorials/aft_survival_analysis.html). However, it only provides predictions at fixed time horizons and does not attempt to estimate the full survival function.
# - [XGBoost Survival Embeddings](https://loft-br.github.io/xgboost-survival-embeddings/index.html): another way to leverage gradient boosting for survival analysis.
# - [DeepHit](https://github.com/chl8856/DeepHit): neural network based, typically with good ranking power but not necessarily well calibrated. Can also handle competing risks.
# - [SurvTRACE](https://github.com/RyanWangZf/SurvTRACE): a more recent transformer-based model. Can also handle competing risks.
We did not yet evaluate how this performs from a calibration point of view (e.g. using cause-specific IBS). 1594 | -------------------------------------------------------------------------------- /notebooks/tutorial_part_2.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:percent 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.14.5 10 | # kernelspec: 11 | # display_name: Python 3 (ipykernel) 12 | # language: python 13 | # name: python3 14 | # --- 15 | 16 | # %% [markdown] 17 | # # Survival Analysis Tutorial Part 2 18 | # 19 | # 20 | # The goal of this notebook is to extract implicit failure events from a raw activity (or "heart-beat") log using [Ibis](ibis-project.org/) and [DuckDB](https://duckdb.org) or [Polars](https://www.pola.rs/). 21 | # 22 | # It is often the case that the we are dealing with a raw **activity event log** for a pool of members/patients/customers/machines... where the event of interest (e.g. churn, death, hospital transfer, failure) only appears in negative via the lack of activity event for an extended period of time: **activity events are collected somewhat regularly as long as the "failure" event has not occured**. 23 | # 24 | # Our goal is to use a common data-wrangling technique named **sessionization** to infer implicit failure events and measure the duration between the start of activity recording until a failure event (or censoring). 25 | # 26 | # We will also see how censoring naturally occur when we extract time-slices of a sessionized dataset. 27 | # 28 | # Links to the slides: 29 | # 30 | # - https://docs.google.com/presentation/d/1pAFmAFiyTA0_-ZjWG1rImAX8lYJt_UnGgqXD-4H6Aqw/edit?usp=sharing 31 | 32 | # %% 33 | import ibis 34 | 35 | ibis.options.interactive = True 36 | ibis.__version__ 37 | 38 | # %% 39 | import duckdb 40 | 41 | duckdb.__version__ 42 | 43 | # %% 44 | from urllib.request import urlretrieve 45 | from pathlib import Path 46 | 47 | data_filepath = Path("wowah_data_raw.parquet") 48 | data_url = ( 49 | "https://storage.googleapis.com/ibis-tutorial-data/wowah_data/" 50 | "wowah_data_raw.parquet" 51 | ) 52 | 53 | if not data_filepath.exists(): 54 | print(f"Downloading {data_url}...") 55 | urlretrieve(data_url, data_filepath) 56 | else: 57 | print(f"Reusing downloaded {data_filepath}") 58 | 59 | # %% 60 | conn = ibis.duckdb.connect() # in-memory DuckDB 61 | event_log = conn.read_parquet(data_filepath) 62 | event_log 63 | 64 | # %% 65 | event_log.count() 66 | 67 | # %% 68 | from ibis import deferred as c 69 | 70 | 71 | entity_window = ibis.cumulative_window( 72 | group_by=c.char, order_by=c.timestamp 73 | ) 74 | threshold = ibis.interval(minutes=30) 75 | deadline_date = c.timestamp.lag().over(entity_window) + threshold 76 | 77 | ( 78 | event_log 79 | .select([c.char, c.timestamp]) 80 | .mutate(deadline_date=deadline_date) 81 | ) 82 | 83 | # %% 84 | ( 85 | event_log 86 | .select([c.char, c.timestamp]) 87 | .mutate( 88 | is_new_session=(c.timestamp > deadline_date).fillna(False) 89 | ) 90 | ) 91 | 92 | # %% 93 | ( 94 | event_log 95 | .select([c.char, c.timestamp]) 96 | .mutate( 97 | is_new_session=(c.timestamp > deadline_date).fillna(False) 98 | ) 99 | .mutate(session_id=c.is_new_session.sum().over(entity_window)) 100 | ) 101 | 102 | # %% 103 | entity_window = ibis.cumulative_window( 104 | group_by=c.char, order_by=c.timestamp 105 | ) 106 | threshold = 
ibis.interval(minutes=30) 107 | deadline_date = c.timestamp.lag().over(entity_window) + threshold 108 | is_new_session = (c.timestamp > deadline_date).fillna(False) 109 | 110 | sessionized = ( 111 | event_log 112 | .mutate(is_new_session=is_new_session) 113 | .mutate(session_id=c.is_new_session.sum().over(entity_window)) 114 | .drop("is_new_session") 115 | ) 116 | sessions = ( 117 | sessionized 118 | .group_by([c.char, c.session_id]) 119 | .order_by(c.timestamp) 120 | .aggregate( 121 | session_start_date=c.timestamp.min(), 122 | session_end_date=c.timestamp.max(), 123 | ) 124 | .order_by([c.char, c.session_start_date]) 125 | ) 126 | sessions 127 | 128 | 129 | # %% 130 | # ibis.show_sql(sessions) 131 | 132 | # %% 133 | def sessionize(table, threshold, entity_col, date_col): 134 | entity_window = ibis.cumulative_window( 135 | group_by=entity_col, order_by=date_col 136 | ) 137 | deadline_date = date_col.lag().over(entity_window) + threshold 138 | is_new_session = (date_col > deadline_date).fillna(False) 139 | 140 | return ( 141 | table 142 | .mutate(is_new_session=is_new_session) 143 | .mutate(session_id=c.is_new_session.sum().over(entity_window)) 144 | .drop("is_new_session") 145 | ) 146 | 147 | 148 | def extract_sessions(table, entity_col, date_col, session_col): 149 | return ( 150 | table 151 | .group_by([entity_col, session_col]) 152 | .aggregate( 153 | session_start_date=date_col.min(), 154 | session_end_date=date_col.max(), 155 | ) 156 | # XXX: we would like to compute session duration here but 157 | # it seems broken with Ibis + DuckDB at the moment... 158 | .order_by([entity_col, c.session_start_date]) 159 | ) 160 | 161 | 162 | def preprocess_event_log(event_log): 163 | return ( 164 | event_log 165 | .pipe( 166 | sessionize, 167 | threshold=ibis.interval(minutes=30), 168 | entity_col=c.char, 169 | date_col=c.timestamp, 170 | ) 171 | .pipe( 172 | extract_sessions, 173 | entity_col=c.char, 174 | date_col=c.timestamp, 175 | session_col=c.session_id, 176 | ) 177 | ) 178 | 179 | 180 | # %% 181 | # %time sessions = preprocess_event_log(event_log).cache() 182 | 183 | # %% 184 | # %time sessions.count() 185 | 186 | # %% 187 | sessions 188 | 189 | # %% 190 | first_observed_date = event_log.timestamp.max().execute() 191 | first_observed_date 192 | 193 | # %% 194 | last_observed_date = event_log.timestamp.max().execute() 195 | last_observed_date 196 | 197 | 198 | # %% 199 | def censor(sessions, censoring_date, threshold=ibis.interval(minutes=30), observation_duration=None): 200 | if observation_duration is not None: 201 | sessions = sessions.filter(c.session_start_date > censoring_date - observation_duration) 202 | return ( 203 | sessions 204 | .filter(c.session_start_date < censoring_date) 205 | .mutate( 206 | is_censored=censoring_date < (c.session_end_date + threshold), 207 | session_end_date=ibis.ifelse(c.session_end_date < censoring_date, c.session_end_date, censoring_date), 208 | ) 209 | # remove sessions that are two short 210 | .filter(c.session_end_date > c.session_start_date + ibis.interval(minutes=1)) 211 | .order_by(c.session_start_date) 212 | ) 213 | 214 | censor(sessions, last_observed_date).is_censored.sum() 215 | 216 | # %% 217 | from datetime import timedelta 218 | 219 | censor(sessions, last_observed_date - timedelta(days=54)).count() 220 | 221 | # %% 222 | censor(sessions, last_observed_date - timedelta(days=54), observation_duration=timedelta(days=5)).to_pandas() 223 | 224 | # %% 225 | # ibis.show_sql(preprocess_event_log(event_log)) 226 | 227 | # %% 228 | import polars as pl 
229 | 230 | 231 | pl.__version__ 232 | 233 | # %% 234 | event_log_df = pl.read_parquet(data_filepath) 235 | event_log_df.head(5) 236 | 237 | # %% 238 | event_log_lazy_df = pl.scan_parquet(data_filepath) 239 | event_log_lazy_df.head(10) 240 | 241 | # %% 242 | event_log_lazy_df.head(10).collect() 243 | 244 | 245 | # %% 246 | def sessionize_pl(df, entity_col, date_col, threshold_minutes): 247 | sessionized = ( 248 | df.sort([entity_col, date_col]) 249 | .with_columns( 250 | [ 251 | (pl.col(date_col).diff().over(entity_col).dt.minutes() > threshold_minutes) 252 | .fill_null(False) 253 | .alias("is_new_session"), 254 | ] 255 | ) 256 | .with_columns( 257 | [ 258 | pl.col("is_new_session").cumsum().over(entity_col).alias("session_id"), 259 | ] 260 | ) 261 | .drop(["is_new_session"]) 262 | ) 263 | return sessionized 264 | 265 | 266 | def extract_sessions_pl( 267 | df, 268 | entity_col, 269 | date_col, 270 | session_col, 271 | metadata_cols=["race", "zone", "charclass", "guild"] 272 | ): 273 | sessions = ( 274 | df 275 | .sort(date_col) 276 | .groupby([entity_col, session_col]) 277 | .agg( 278 | [pl.col(mc).first().alias(mc) for mc in metadata_cols] 279 | + [ 280 | pl.col(date_col).min().alias("session_start_date"), 281 | pl.col(date_col).max().alias("session_end_date"), 282 | ] 283 | ) 284 | .with_columns( 285 | [ 286 | (pl.col("session_end_date") - pl.col("session_start_date")).alias("session_duration"), 287 | ] 288 | ) 289 | .sort([entity_col, "session_start_date"]) 290 | ) 291 | return sessions 292 | 293 | 294 | def preprocess_event_log_pl(df): 295 | return ( 296 | df 297 | .pipe( 298 | sessionize_pl, 299 | entity_col="char", 300 | date_col="timestamp", 301 | threshold_minutes=30, 302 | ) 303 | .pipe( 304 | extract_sessions_pl, 305 | entity_col="char", 306 | date_col="timestamp", 307 | session_col="session_id", 308 | ) 309 | ) 310 | 311 | 312 | # %time sessions_collected = preprocess_event_log_pl(event_log_lazy_df).collect() 313 | sessions_collected 314 | 315 | # %% 316 | first_observed_date = event_log_lazy_df.select("timestamp").min().collect().item() 317 | first_observed_date 318 | 319 | # %% 320 | last_observed_date = event_log_lazy_df.select("timestamp").max().collect().item() 321 | last_observed_date 322 | 323 | 324 | # %% 325 | def censor_pl(sessions, censoring_date, threshold_minutes=30, observation_days=None): 326 | if observation_days: 327 | start_date = censoring_date - timedelta(days=observation_days) 328 | sessions = sessions.filter(pl.col("session_start_date") > start_date) 329 | return ( 330 | sessions 331 | .filter(pl.col("session_start_date") < censoring_date) 332 | .with_columns( 333 | [ 334 | (((censoring_date - pl.col("session_end_date")).dt.minutes()) < threshold_minutes).alias("is_censored"), 335 | pl.min(pl.col("session_end_date"), censoring_date).alias("session_end_date"), 336 | ] 337 | ) 338 | .with_columns( 339 | [ 340 | (pl.col("session_end_date") - pl.col("session_start_date")).dt.minutes().alias("duration"), 341 | (pl.col("is_censored") == False).alias("event"), 342 | ] 343 | ) 344 | .filter(pl.col("duration") > 0) 345 | .sort("session_start_date") 346 | ) 347 | 348 | 349 | censor_pl(sessions_collected, last_observed_date) 350 | 351 | # %% 352 | censor_pl(sessions_collected, last_observed_date).select("is_censored").sum() 353 | 354 | # %% 355 | censor_pl(sessions_collected, last_observed_date - timedelta(days=42), observation_days=5) 356 | 357 | # %% [markdown] 358 | # ***Wrap-up exercise*** 359 | # 360 | # - Select 10 dates randomly from the beginning of January to 
the end of November (increment the first date with random number of days). For each sample date, define an observation window of 5 days: extract the censored data and concatenate those sessions into a training set; 361 | # 362 | # - Estimate and plot the average survival function using a Kaplan-Meier estimator; 363 | # 364 | # - Reiterate the KM estimation, but stratified on the `race` or the `charclass` features; 365 | # 366 | # - Fit a predictive survival model of your choice with adequate feature engineering on this training set; 367 | # 368 | # - Extract censored data from the last month of the original dataset and use it to measure the performance of your estimator with the metrics of your choice. Compare this to the Kaplan-Meier baseline. 369 | # 370 | # - Inspect which features are the most predictive, one way or another. 371 | 372 | # %% 373 | # TODO: write your code here 374 | -------------------------------------------------------------------------------- /notebooks/variables.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soda-inria/survival-analysis-benchmark/2633aa5e4e6433c73bb87e99a9f3aad1033726e7/notebooks/variables.png -------------------------------------------------------------------------------- /plot/brier_score.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import seaborn as sns 3 | 4 | sns.set_theme() 5 | sns.set_context("paper") 6 | 7 | 8 | def plot_brier_scores(df_lines): 9 | """Plot brier score curve for each estimator. 10 | 11 | Parameters 12 | ---------- 13 | df_lines: pd.DataFrame 14 | model | times | brier_scores 15 | 'times' and 'brier_scores' are numpy array 16 | 17 | Notes 18 | ----- 19 | 'df_tables' and 'df_lines' are loaded with 20 | `model_selection.cross_validation.get_all_results()` after 21 | cross validation with `model_selection.cross_validation.run_cv()` 22 | """ 23 | 24 | cols = df_lines.columns 25 | col_to_idx = dict(zip(cols, range(len(cols)))) 26 | 27 | fig, ax = plt.subplots() #figsize=(14, 5), dpi=300) 28 | for row in df_lines.values: 29 | ax.plot( 30 | row[col_to_idx["times"]], 31 | row[col_to_idx["brier_scores"]], 32 | label=row[col_to_idx["model"]], 33 | ) 34 | plt.xlabel("Duration (days)") 35 | plt.ylabel("Brier score") 36 | legend = plt.legend( 37 | bbox_to_anchor=(0, 1.02, 1, 0.2), 38 | loc="lower left", 39 | mode="expand", 40 | borderaxespad=0, 41 | ncol=6, 42 | facecolor='white', 43 | ) 44 | frame = legend.get_frame() 45 | frame.set_linewidth(2) 46 | 47 | return fig 48 | -------------------------------------------------------------------------------- /plot/individuals.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | 4 | 5 | def plot_individuals_survival_curve(df_tables, df_lines, y, n_indiv=5): 6 | """For each estimator, plot individual survival curves with 7 | event or censoring marker, for a sample of rows. 8 | 9 | Parameters 10 | ---------- 11 | df_tables: pd.DataFrame, 12 | | Method | IBS 13 | Useful to rank estimators plot by mean IBS. 14 | 15 | df_lines: pd.DataFrame 16 | | model | times | survival_probs 17 | Individual survival_probs to sample from and plot 18 | 19 | y: np.ndarray, 20 | Target vector, containing survival or 21 | censoring times, useful to plot markers 22 | 23 | n_indiv: int, 24 | Number of individual curves to display. 
25 | 26 | Notes 27 | ----- 28 | 'df_tables' and 'df_lines' are loaded with 29 | `model_selection.cross_validation.get_all_results()` after 30 | cross validation with `model_selection.cross_validation.run_cv()` 31 | """ 32 | # use df_tables to sort df_lines 33 | df_tables["mean_IBS"] = df_tables["IBS"].str.split("±") \ 34 | .str[0].str.strip().astype(float) 35 | 36 | df_lines = df_lines.merge( 37 | df_tables[["Method", "mean_IBS"]], 38 | left_on="model", 39 | right_on="Method", 40 | ) 41 | df_lines.sort_values("mean_IBS", inplace=True) 42 | 43 | n_rows = df_lines.shape[0] 44 | 45 | fig, axes = plt.subplots( 46 | nrows=n_rows, 47 | ncols=1, 48 | #figsize=(10, 17), 49 | #dpi=300, 50 | constrained_layout=True, 51 | ) 52 | 53 | cols = df_lines.columns 54 | col_to_idx = dict(zip(cols, range(len(cols)))) 55 | 56 | # Some models has been tested on a small datasets, 57 | # so we need to use the first `max_indiv_id` rows to compare all models 58 | max_indiv_id = min([len(el) for el in df_lines["survival_probs"].values]) 59 | idxs_indiv = np.random.uniform(high=max_indiv_id, size=n_indiv).astype(int) 60 | 61 | for idx, row in enumerate(df_lines.values): 62 | 63 | if df_lines.shape[0] == 1: 64 | ax = axes 65 | else: 66 | ax = axes[idx] 67 | 68 | for jdx in idxs_indiv: 69 | times = row[col_to_idx["times"]] 70 | surv_probs = row[col_to_idx["survival_probs"]][jdx, :] 71 | ax.plot(times, surv_probs) 72 | color = ax.lines[-1].get_color() 73 | # place the dot on the curve for viz purposes 74 | is_event = y["event"][jdx] 75 | event_time = y["duration"][jdx] 76 | surv_prob_projected = surv_probs[ 77 | np.argmin( 78 | np.abs(times - event_time) 79 | ) 80 | ] 81 | ax.plot( 82 | event_time, 83 | surv_prob_projected, 84 | "^" if is_event else "o", 85 | color=color, 86 | ) 87 | ax.set_title(row[col_to_idx["model"]]) 88 | 89 | return fig --------------------------------------------------------------------------------