├── README.md ├── generate_clusters.py └── run_mortality_prediction.py /README.md: -------------------------------------------------------------------------------- 1 | ## Learning Tasks for Multitask Learning 2 | 3 | The code in this repository implements the models described in the paper *Learning Tasks for Multitask Learning: Heterogenous Patient Populations in the ICU* (KDD 2018). There are two files: 4 | 5 | 1. generate_clusters.py, which trains a sequence-to-sequence autoencoder on patient timeseries data to produce a dense representation, and then fits a Gaussian Mixture Model to the samples in this new space. 6 | 7 | 2. run_mortality_prediction.py, which contains methods to preprocess data, as well as train and run a predictive model to predict in-hospital mortality after a certain point, given patients' physiological timeseries data. 8 | 9 | For more information on the arguments required to run each of these files, use the --help flag. 10 | 11 | ### Data 12 | 13 | Without any modification, this code assumes that you have the following files in a 'data/' folder: 14 | 1. X.h5: an hdf file containing one row per patient per hour. Each row should include the columns {'subject_id', 'icustay_id', 'hours_in', 'hadm_id'} along with any additional features. 15 | 2. static.csv: a CSV file containing one row per patient. Should include {'subject_id', 'hadm_id', 'icustay_id', 'gender', 'age', 'ethnicity', 'first_careunit'}. 16 | 3. saps.csv: a CSV file containing one row per patient. Should include {'subject_id', 'hadm_id', 'icustay_id', 'sapsii'}. This data is found in the saps table in MIMIC III. 17 | 4. code_status.csv: a CSV file containing one row per patient. Should include {'subject_id', 'hadm_id', 'icustay_id', 'timecmo_chart', 'timecmo_nursingnote'}. This data is found in the code_status table of MIMIC III. 18 | -------------------------------------------------------------------------------- /generate_clusters.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"]="2" 3 | 4 | 5 | from numpy.random import seed 6 | seed(1) 7 | import numpy as np 8 | import argparse 9 | from keras.models import Model 10 | from keras.layers import Input, LSTM, RepeatVector 11 | from keras.optimizers import Adam 12 | from keras.callbacks import EarlyStopping 13 | from run_mortality_prediction import stratified_split, load_processed_data 14 | from sklearn.mixture import GaussianMixture 15 | import pickle 16 | 17 | def get_args(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--latent_dim", type=int, default=50, \ 20 | help='The embedding size, or latent dimension of the autoencoder. Type: int. Default: 50.') 21 | parser.add_argument("--ae_epochs", type=int, default=100, \ 22 | help='Number of epochs to train autoencoder. Type: int. Default: 100.') 23 | parser.add_argument("--ae_learning_rate", type=float, default=0.0001, \ 24 | help='Learning rate for autoencoder. Type: float. Default: 0.0001.') 25 | parser.add_argument("--num_clusters", type=int, default=3, \ 26 | help='Number of clusters for GMM. Type: int. Default: 3.') 27 | parser.add_argument("--gmm_tol", type=float, default=0.0001, 28 | help='The convergence threshold for the GMM. Type: float. Default: 0.0001.') 29 | parser.add_argument("--data_hours", type=int, default=24, \ 30 | help='The number of hours of data to use. \ 31 | Type: int. 
Default: 24.') 32 | parser.add_argument("--gap_time", type=int, default=12, help="Gap between data and when predictions are made. Type: int. Default: 12.") 33 | parser.add_argument("--save_to_fname", type=str, default='test_clusters.npy', \ 34 | help="Filename to save cluster memberships to. Type: String. Default: 'test_clusters.npy'") 35 | parser.add_argument("--train_val_random_seed", type=int, default=0, \ 36 | help="Random seed to use during train / val / split process. Type: int. Default: 0.") 37 | args = parser.parse_args() 38 | print(args) 39 | return args 40 | 41 | 42 | ########## CREATE AE MODEL ############################################################### 43 | ########################################################################################## 44 | 45 | def create_seq_ae(X_train, X_val, latent_dim, learning_rate): 46 | """ 47 | Build sequence autoencoder. 48 | Args: 49 | X_train (Numpy array): training data. (shape = n_samples x n_timesteps x n_features) 50 | X_val (Numpy array): validation data. 51 | latent_dim (int): hidden representation dimension. 52 | learning_rate (float): learning rate for training. 53 | Returns: 54 | encoder (Keras model): compiled model that takes original data as input and produces representation. 55 | sequence_autoencoder (Keras model): compiled autoencoder model. 56 | """ 57 | 58 | timesteps = X_train.shape[1] 59 | input_dim = X_train.shape[2] 60 | latent_dim = latent_dim 61 | 62 | inputs = Input(shape=(timesteps, input_dim)) 63 | encoded = LSTM(latent_dim)(inputs) 64 | 65 | decoded = RepeatVector(timesteps)(encoded) 66 | decoded = LSTM(input_dim, return_sequences=True)(decoded) 67 | 68 | sequence_autoencoder = Model(inputs, decoded) 69 | encoder = Model(inputs, encoded) 70 | 71 | 72 | sequence_autoencoder.compile(optimizer=Adam(lr=learning_rate), 73 | loss='mse') 74 | 75 | return encoder, sequence_autoencoder 76 | 77 | 78 | ########## RUN AE MODEL ################################################################## 79 | ########################################################################################## 80 | 81 | def train_seq_ae(X_train, X_val, FLAGS): 82 | """ 83 | Train a sequence to sequence autoencoder. 84 | Args: 85 | X_train (Numpy array): training data. (shape = n_samples x n_timesteps x n_features) 86 | X_val (Numpy array): validation data. 87 | FLAGS (dictionary): all provided arguments. 88 | Returns: 89 | encoder (Keras model): trained model to encode to latent space. 90 | sequence autoencoer (Keras model): trained autoencoder. 91 | """ 92 | encoder, sequence_autoencoder = create_seq_ae(X_train, X_val, FLAGS.latent_dim, FLAGS.ae_learning_rate) 93 | early_stopping = EarlyStopping(monitor='val_loss', patience=3) 94 | 95 | # fit the model 96 | print("Fitting Sequence Autoencoder ... 
") 97 | sequence_autoencoder.fit(X_train, X_train, 98 | epochs=FLAGS.ae_epochs, 99 | batch_size=128, 100 | shuffle=True, 101 | callbacks=[early_stopping], 102 | validation_data=(X_val, X_val)) 103 | 104 | 105 | if not os.path.exists('clustering_models/'): 106 | os.makedirs('clustering_models/') 107 | 108 | encoder.save('clustering_models/encoder_' + str(FLAGS.data_hours)) 109 | sequence_autoencoder.save('clustering_models/seq_ae_' + str(FLAGS.data_hours)) 110 | return encoder, sequence_autoencoder 111 | 112 | ########## MAIN ########################################################################## 113 | ########################################################################################## 114 | 115 | if __name__ == "__main__": 116 | 117 | FLAGS = get_args() 118 | 119 | # Load Data 120 | X, Y, careunits, saps_quartile, subject_ids = load_processed_data(FLAGS.data_hours, FLAGS.gap_time) 121 | Y = Y.astype(int) 122 | cohort_col = careunits 123 | 124 | # Train, val, test split 125 | X_train, X_val, X_test, \ 126 | y_train, y_val, y_test, \ 127 | cohorts_train, cohorts_val, cohorts_test = stratified_split(X, Y, cohort_col, train_val_random_seed=FLAGS.train_val_random_seed) 128 | 129 | # Train autoencoder 130 | encoder, sequence_autoencoder = train_seq_ae(X_train, X_val, FLAGS) 131 | 132 | # Get Embeddings 133 | embedded_train = encoder.predict(X_train) 134 | embedded_all = encoder.predict(X) 135 | 136 | # Train GMM 137 | print("Fitting GMM ...") 138 | gm = GaussianMixture(n_components=FLAGS.num_clusters, tol=FLAGS.gmm_tol, verbose=True) 139 | gm.fit(embedded_train) 140 | pickle.dump(gm, open('clustering_models/gmm_' + str(FLAGS.data_hours), 'wb')) 141 | 142 | # Get cluster membership 143 | cluster_preds = gm.predict(embedded_all) 144 | 145 | if not os.path.exists('cluster_membership/'): 146 | os.makedirs('cluster_membership/') 147 | np.save('cluster_membership/' + FLAGS.save_to_fname, cluster_preds) 148 | -------------------------------------------------------------------------------- /run_mortality_prediction.py: -------------------------------------------------------------------------------- 1 | # Import things 2 | from numpy.random import seed 3 | seed(1) 4 | from tensorflow import set_random_seed 5 | set_random_seed(2) 6 | 7 | import os 8 | import sys 9 | import argparse 10 | import numpy as np 11 | import pandas as pd 12 | from sklearn.model_selection import train_test_split 13 | from sklearn.metrics import roc_auc_score 14 | import tensorflow as tf 15 | from keras.optimizers import Adam 16 | from keras import backend as K 17 | from keras.models import Sequential, load_model 18 | from keras.layers import Dense, Input 19 | from keras.callbacks import ModelCheckpoint, EarlyStopping 20 | from keras.models import Model 21 | from keras.layers.recurrent import LSTM 22 | 23 | np.set_printoptions(threshold=np.nan) 24 | 25 | INDEX_COLS = ['subject_id', 'icustay_id', 'hours_in', 'hadm_id'] 26 | 27 | # where the X, Y, static raw data is 28 | data_path = 'data/' 29 | 30 | # where you will save the processed data matrices 31 | save_data_path = 'data/mortality/' 32 | 33 | 34 | def get_args(): 35 | parser = argparse.ArgumentParser() 36 | 37 | parser.add_argument("--experiment_name", type=str, default='mortality_test', 38 | help="This will become the name of the folder where are the models and results \ 39 | are stored. Type: String. Default: 'mortality_test'.") 40 | parser.add_argument("--data_hours", type=int, default=24, 41 | help="The number of hours of data to use in making the prediction. 
\ 42 | Type: int. Default: 24.") 43 | parser.add_argument("--gap_time", type=int, default=12, \ 44 | help="The gap between data and when we are making predictions. Type: int. Default: 12.") 45 | parser.add_argument("--model_type", type=str, default='GLOBAL', 46 | help="One of {'GLOBAL', MULTITASK', 'SEPARATE'} indicating \ 47 | which type of model to run. Type: String.") 48 | parser.add_argument("--num_lstm_layers", type=int, default=1, 49 | help="Number of beginning LSTM layers, applies to all model types. \ 50 | Type: int. Default: 1.") 51 | parser.add_argument("--lstm_layer_size", type=int, default=16, 52 | help="Number of units in beginning LSTM layers, applies to all model types. \ 53 | Type: int. Default: 16.") 54 | parser.add_argument("--num_dense_shared_layers", type=int, default=0, 55 | help="Number of shared dense layers following LSTM layer(s), applies to \ 56 | all model types. Type: int. Default: 0.") 57 | parser.add_argument("--dense_shared_layer_size", type=int, default=0, 58 | help="Number of units in shared dense layers, applies to all model types. \ 59 | Type: int. Default: 0.") 60 | parser.add_argument("--num_multi_layers", type=int, default=0, 61 | help="Number of separate-task dense layers, only applies to multitask models. Currently \ 62 | only 0 or 1 separate-task dense layers are supported. Type: int. Default: 0.") 63 | parser.add_argument("--multi_layer_size", type=int, default=0, 64 | help="Number of units in separate-task dense layers, only applies to multitask \ 65 | models. Type: int. Default: 0.") 66 | parser.add_argument("--cohorts", type=str, default='careunit', 67 | help="One of {'careunit', 'saps', 'custom'}. Indicates whether to use pre-defined cohorts \ 68 | (careunits or saps quartile) or use a custom cohort membership (i.e. result of clustering). \ 69 | Type: String. Default: 'careunit'. ") 70 | parser.add_argument("--cohort_filepath", type=str, help="This is the filename containing a numpy \ 71 | array of length len(X), containing the cohort membership for each example in X. This file should be \ 72 | saved in the folder 'cluster_membership'. Only applies to cohorts == 'custom'. Type: str.") 73 | parser.add_argument("--sample_weights", action="store_true", default=False, help="This is an indicator \ 74 | flag to weight samples during training by their cohort's inverse frequency (i.e. smaller cohorts will be \ 75 | more highly weighted during training).") 76 | parser.add_argument("--include_cohort_as_feature", action="store_true", default=False, 77 | help="This is an indicator flag to include cohort membership as an additional feature in the matrix.") 78 | parser.add_argument("--epochs", type=int, default=30, 79 | help="Number of epochs to train for. Type: int. Default: 30.") 80 | parser.add_argument("--train_val_random_seed", type=int, default=0, 81 | help="Random seed to use during train / val / split process. Type: int. Default: 0.") 82 | parser.add_argument("--repeats_allowed", action="store_true", default=False, 83 | help="Indicator flag allowing training and evaluating of existing models. Without this flag, \ 84 | if you run a configuration for which you've already saved models & results, it will be skipped.") 85 | parser.add_argument("--no_val_bootstrap", action="store_true", default=False, 86 | help="Indicator flag turning off bootstrapping evaluation on the validation set. Without this flag, \ 87 | minimum, maximum and average AUCs on bootstrapped samples of the validation dataset are saved. 
With the flag, \ 88 | just one AUC on the actual validation set is saved.") 89 | parser.add_argument("--num_val_bootstrap_samples", type=int, default=100, 90 | help="Number of bootstrapping samples to evaluate on for the validation set. Type: int. Default: 100. ") 91 | parser.add_argument("--test_time", action="store_true", default=False, 92 | help="Indicator flag of whether we are in testing time. With this flag, we will load in the already trained model \ 93 | of the specified configuration, and evaluate it on the test set. ") 94 | parser.add_argument("--test_bootstrap", action="store_true", default=False, 95 | help="Indicator flag of whether to evaluate on bootstrapped samples of the test set, or just the single \ 96 | test set. Adding the flag will result in saving minimum, maximum and average AUCs on bo6otstrapped samples of the validation dataset. ") 97 | parser.add_argument("--num_test_bootstrap_samples", type=int, default=100, 98 | help="Number of bootstrapping samples to evaluate on for the test set. Type: int. Default: 100. ") 99 | parser.add_argument("--gpu_num", type=str, default='0', 100 | help="Limit GPU usage to specific GPUs. Specify multiple GPUs with the format '0,1,2'. Type: String. Default: '0'.") 101 | 102 | args = parser.parse_args() 103 | print(args) 104 | return args 105 | 106 | ################ HELPER FUNCTIONS ############################################### 107 | #################################################################################### 108 | 109 | 110 | def load_phys_data(): 111 | """ 112 | Loads X, Y, and static matrices into Pandas DataFrames 113 | Returns: 114 | X: Pandas DataFrame containing one row per patient per hour. 115 | Each row should include the columns {'subject_id', 'icustay_id', 'hours_in', 'hadm_id'} 116 | along with any additional features. 117 | static: Pandas DataFrame containing one row per patient. 118 | Should include {'subject_id', 'hadm_id', 'icustay_id'}. 119 | """ 120 | 121 | X = pd.read_hdf(data_path + 'X.h5', 'X') 122 | # Y = pd.read_hdf(data_path + 'Y.h5', 'Y') 123 | static = pd.DataFrame.from_csv(data_path + 'static.csv') 124 | 125 | if 'subject_id' not in X.columns: 126 | X = X.reset_index() 127 | X.columns = [fix_byte_data(c) for c in X.columns] 128 | # if 'subject_id' not in Y.columns: 129 | # Y = Y.reset_index() 130 | # Y.columns = [fix_byte_data(c) for c in Y.columns] 131 | 132 | static = static[static.subject_id.isin(np.unique(X.subject_id))] 133 | return X, static 134 | 135 | 136 | def categorize_ethnicity(ethnicity): 137 | """ 138 | Groups ethnicity sub-categories into 5 major categories. 139 | Args: 140 | ethnicity (str): string indicating patient ethnicity. 141 | Returns: 142 | string: ethnicity. Categorized into 5 main categories. 143 | """ 144 | 145 | if 'ASIAN' in ethnicity: 146 | ethnicity = 'ASIAN' 147 | elif 'WHITE' in ethnicity: 148 | ethnicity = 'WHITE' 149 | elif 'HISPANIC' in ethnicity: 150 | ethnicity = 'HISPANIC/LATINO' 151 | elif 'BLACK' in ethnicity: 152 | ethnicity = 'BLACK' 153 | else: 154 | ethnicity = 'OTHER' 155 | return ethnicity 156 | 157 | 158 | def make_discrete_values(mat): 159 | """ 160 | Converts numerical values into one-hot vectors of number of z-scores 161 | above/below the mean, aka physiological words (see Suresh et al 2017). 162 | Args: 163 | mat (Pandas DataFrame): Matrix of feature values including columns in 164 | INDEX_COLS as the first columns. 165 | Returns: 166 | DataFrame: X_categorized. 
A DataFrame where each feature is a set of 167 | indicator columns signifying number of z-scores above or below the mean. 168 | """ 169 | 170 | normal_dict = mat.groupby(['subject_id']).mean().mean().to_dict() 171 | std_dict = mat.std().to_dict() 172 | feature_cols = mat.columns[len(INDEX_COLS):] 173 | print(feature_cols) 174 | X_words = mat.loc[:, feature_cols].apply( 175 | lambda x: transform_vals(x, normal_dict, std_dict), axis=0) 176 | mat.loc[:, feature_cols] = X_words 177 | X_categorized = pd.get_dummies(mat, columns=mat.columns[len(INDEX_COLS):]) 178 | na_columns = [col for col in X_categorized.columns if '_9' in col] 179 | X_categorized.drop(na_columns, axis=1, inplace=True) 180 | return X_categorized 181 | 182 | 183 | def transform_vals(x, normal_dict, std_dict): 184 | """ 185 | Helper function to convert a column of feature values to rounded z-scores between -4 and 4. 186 | Missing values are assigned 9. 187 | Args: 188 | x (Pandas Series): column of raw values for one feature; x.name is the feature name. 189 | normal_dict (dict): maps feature name to its mean value. std_dict (dict): maps feature name to its standard deviation. 190 | Returns: 191 | Series: integer z-scores clipped to [-4, 4], with 9 in place of missing values. 192 | """ 193 | 194 | x = 1.0*(x - normal_dict[x.name])/std_dict[x.name] 195 | x = x.round() 196 | x = x.clip(-4, 4) 197 | x = x.fillna(9) 198 | x = x.round(0).astype(int) 199 | return x 200 | 201 | 202 | def categorize_age(age): 203 | """ 204 | Categorize age into windows. 205 | Args: 206 | age (int): age in years. 207 | Returns: 208 | int: cat. The age category. 209 | """ 210 | 211 | if age > 10 and age <= 30: 212 | cat = 1 213 | elif age > 30 and age <= 50: 214 | cat = 2 215 | elif age > 50 and age <= 70: 216 | cat = 3 217 | else: 218 | cat = 4 219 | return cat 220 | 221 | 222 | def _pad_df(df, max_hr, pad_value=np.nan): 223 | """ Return the set of missing hours (up to max_hr) for a patient's dataframe, or 0 if no padding is needed. """ 224 | 225 | existing = set(df.index.get_level_values(1)) 226 | fill_hrs = set(range(max_hr)) - existing 227 | if len(fill_hrs) > 0: 228 | return fill_hrs 229 | else: 230 | return 0 231 | 232 | 233 | def fix_byte_data(s): 234 | """ Python 2/3 fix """ 235 | 236 | try: 237 | s = s.decode() 238 | except AttributeError: 239 | pass 240 | return s 241 | 242 | 243 | def stratified_split(X, Y, cohorts, train_val_random_seed=0): 244 | """ 245 | Return stratified split of X, Y, and a cohort membership array, stratified by outcome. 246 | Args: 247 | X (Numpy array): X matrix, shape = num patients x num timesteps x num features. 248 | Y (Numpy array): Y matrix, shape = num_patients. 249 | cohorts (Numpy array): array of cohort membership, shape = num_patients. 250 | train_val_random_seed (int): random seed for splitting. 251 | Returns: 252 | Numpy arrays: X_train, X_val, X_test, y_train, y_val, y_test, 253 | cohorts_train, cohorts_val, cohorts_test. 254 | """ 255 | 256 | X_train_val, X_test, y_train_val, y_test, \ 257 | cohorts_train_val, cohorts_test = \ 258 | train_test_split(X, Y, cohorts, test_size=0.2, 259 | random_state=train_val_random_seed, stratify=Y) 260 | 261 | X_train, X_val, y_train, y_val, \ 262 | cohorts_train, cohorts_val = \ 263 | train_test_split(X_train_val, y_train_val, cohorts_train_val, test_size=0.125, 264 | random_state=train_val_random_seed, stratify=y_train_val) 265 | 266 | return X_train, X_val, X_test, \ 267 | y_train, y_val, y_test, \ 268 | cohorts_train, cohorts_val, cohorts_test 269 | 270 | 271 | def generate_bootstrap_indices(X, y, split, num_bootstrap_samples=100): 272 | """ 273 | Generates and saves to file sets of indices for val or test bootstrapping.
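The index sets are written (via np.save) to '<split>_pos_bootstrap_samples_<num_bootstrap_samples>.npy' and '<split>_neg_bootstrap_samples_<num_bootstrap_samples>.npy' in the working directory, so that get_bootstrapped_dataset can reuse the same resamples across repeated evaluations.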
274 | Args: 275 | X (Numpy array): X matrix, shape = num patients x num timesteps x num features. 276 | y (Numpy array): Y matrix, shape = num_patients. 277 | split (string): 'val' or 'test' indicating for which split to generate indices. 278 | num_bootstrap_samples (int): number indicating how many sets of bootstrap samples to generate. 279 | Returns: 280 | Numpy arrays: all_pos_samples, all_neg_samples. Contains num_bootstrap_samples indices 281 | of positive and negative examples. 282 | """ 283 | 284 | positive_X = X[np.where(y == 1)] 285 | negative_X = X[np.where(y == 0)] 286 | all_pos_samples = [] 287 | all_neg_samples = [] 288 | for i in range(num_bootstrap_samples): 289 | pos_samples = np.random.choice( 290 | len(positive_X), replace=True, size=len(positive_X)) 291 | neg_samples = np.random.choice( 292 | len(negative_X), replace=True, size=len(negative_X)) 293 | all_pos_samples.append(pos_samples) 294 | all_neg_samples.append(neg_samples) 295 | 296 | np.save(split + '_pos_bootstrap_samples_' + 297 | str(num_bootstrap_samples), np.array(all_pos_samples)) 298 | np.save(split + '_neg_bootstrap_samples_' + 299 | str(num_bootstrap_samples), np.array(all_neg_samples)) 300 | return all_pos_samples, all_neg_samples 301 | 302 | 303 | def get_bootstrapped_dataset(X, y, cohorts, index=0, test=False, num_bootstrap_samples=100): 304 | """ 305 | Returns a bootstrapped (sampled with replacement) dataset. 306 | Args: 307 | X (Numpy array): X matrix, shape = num patients x num timesteps x num features. 308 | y (Numpy array): Y matrix, shape = num_patients. 309 | cohorts (Numpy array): array of cohort membership, shape = num_patients. 310 | index (int): which bootstrap sample to look at. 311 | test (bool): if True, use the test-split bootstrap indices; otherwise use the val-split indices. 312 | num_bootstrap_samples (int): total number of saved bootstrap index sets. 313 | Returns: 314 | Numpy arrays: X, y, and cohorts resampled with replacement according to bootstrap index set 315 | number `index` (the original arrays are returned when index == 0). 316 | """ 317 | 318 | if index == 0: 319 | return X, y, cohorts 320 | 321 | positive_X = X[np.where(y == 1)] 322 | negative_X = X[np.where(y == 0)] 323 | positive_cohorts = cohorts[np.where(y == 1)] 324 | negative_cohorts = cohorts[np.where(y == 0)] 325 | positive_y = y[np.where(y == 1)] 326 | negative_y = y[np.where(y == 0)] 327 | 328 | split = 'test' if test else 'val' 329 | try: 330 | pos_samples = np.load( 331 | split + '_pos_bootstrap_samples_' + str(num_bootstrap_samples) + '.npy')[index] 332 | neg_samples = np.load( 333 | split + '_neg_bootstrap_samples_' + str(num_bootstrap_samples) + '.npy')[index] 334 | except: 335 | all_pos_samples, all_neg_samples = generate_bootstrap_indices( 336 | X, y, split, num_bootstrap_samples) 337 | pos_samples = all_pos_samples[index] 338 | neg_samples = all_neg_samples[index] 339 | 340 | positive_X_bootstrapped = positive_X[pos_samples] 341 | negative_X_bootstrapped = negative_X[neg_samples] 342 | all_X_bootstrapped = np.concatenate( 343 | (positive_X_bootstrapped, negative_X_bootstrapped)) 344 | all_y_bootstrapped = np.concatenate( 345 | (positive_y[pos_samples], negative_y[neg_samples])) 346 | all_cohorts_bootstrapped = np.concatenate( 347 | (positive_cohorts[pos_samples], negative_cohorts[neg_samples])) 348 | 349 | return all_X_bootstrapped, all_y_bootstrapped, all_cohorts_bootstrapped 350 | 351 | 352 | def bootstrap_predict(X_orig, y_orig, cohorts_orig, task, model, return_everything=False, test=False, all_tasks=[], num_bootstrap_samples=100): 353 | """ 354 | Evaluates model on each of the num_bootstrap_samples sets. 355 | Args: 356 | X_orig (Numpy array): The X matrix.
357 | y_orig (Numpy array): The y matrix. 358 | cohorts_orig (Numpy array): List of cohort membership for each X example. 359 | task (String/Int): task to evalute on (either 'all' to evalute on the entire dataset, or a specific task). 360 | model (Keras model): the model to evaluate. 361 | return_everything (bool): if True, return list of AUCs on all bootstrapped samples. If False, return [min auc, max auc, avg auc]. 362 | test (bool): if True, use the test bootstrap indices. 363 | all_tasks (list): list of the tasks (used for evaluating multitask model). 364 | num_bootstrap_samples (int): number of bootstrapped samples to evalute on. 365 | Returns: 366 | all_aucs OR min_auc, max_auc, avg_auc depending on the value of return_everything. 367 | """ 368 | 369 | all_aucs = [] 370 | 371 | for i in range(num_bootstrap_samples): 372 | X_bootstrap_sample, y_bootstrap_sample, cohorts_bootstrap_sample = get_bootstrapped_dataset( 373 | X_orig, y_orig, cohorts_orig, index=i, test=test, num_bootstrap_samples=num_bootstrap_samples) 374 | if task != 'all': 375 | X_bootstrap_sample_task = X_bootstrap_sample[cohorts_bootstrap_sample == task] 376 | y_bootstrap_sample_task = y_bootstrap_sample[cohorts_bootstrap_sample == task] 377 | cohorts_bootstrap_sample_task = cohorts_bootstrap_sample[cohorts_bootstrap_sample == task] 378 | else: 379 | X_bootstrap_sample_task = X_bootstrap_sample 380 | y_bootstrap_sample_task = y_bootstrap_sample 381 | cohorts_bootstrap_sample_task = cohorts_bootstrap_sample 382 | 383 | preds = model.predict(X_bootstrap_sample_task, batch_size=128) 384 | if len(preds) < len(y_bootstrap_sample_task): 385 | preds = get_correct_task_mtl_outputs( 386 | preds, cohorts_bootstrap_sample_task, all_tasks) 387 | 388 | try: 389 | auc = roc_auc_score(y_bootstrap_sample_task, preds) 390 | all_aucs.append(auc) 391 | except Exception as e: 392 | print(e) 393 | print('Skipped this sample.') 394 | 395 | avg_auc = np.mean(all_aucs) 396 | min_auc = min(all_aucs) 397 | max_auc = max(all_aucs) 398 | 399 | if return_everything: 400 | return all_aucs 401 | else: 402 | return min_auc, max_auc, avg_auc 403 | 404 | ################ MODEL DEFINITIONS ############################################### 405 | #################################################################################### 406 | 407 | 408 | def create_single_task_model(n_layers, units, num_dense_shared_layers, dense_shared_layer_size, input_dim, output_dim): 409 | """ 410 | Create a single task model with LSTM layer(s), shared dense layer(s), and sigmoided output. 411 | Args: 412 | n_layers (int): Number of initial LSTM layers. 413 | units (int): Number of units in each LSTM layer. 414 | num_dense_shared_layers (int): Number of dense layers following LSTM layer(s). 415 | dense_shared_layer_size (int): Number of units in each dense layer. 416 | input_dim (int): Number of features in the input. 417 | output_dim (int): Number of outputs (1 for binary tasks). 418 | Returns: 419 | final_model (Keras model): A compiled model with the provided architecture. 
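Example (matching the call used for the global and separate models later in this file): create_single_task_model(FLAGS.num_lstm_layers, FLAGS.lstm_layer_size, FLAGS.num_dense_shared_layers, FLAGS.dense_shared_layer_size, X_train.shape[1:], 1), where X_train.shape[1:] is (n_timesteps, n_features).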
420 | """ 421 | 422 | # global model 423 | model = Sequential() 424 | 425 | # first layer 426 | if n_layers > 1: 427 | return_seq = True 428 | else: 429 | return_seq = False 430 | 431 | model.add(LSTM(units=units, activation='relu', 432 | input_shape=input_dim, return_sequences=return_seq)) 433 | 434 | # additional hidden layers 435 | for l in range(n_layers - 1): 436 | model.add(LSTM(units=units, activation='relu', return_sequences=(l < n_layers - 2))) 437 | 438 | # additional dense layers 439 | for l in range(num_dense_shared_layers): 440 | model.add(Dense(units=dense_shared_layer_size, activation='relu')) 441 | 442 | # output layer 443 | model.add(Dense(units=output_dim, activation='sigmoid')) 444 | 445 | model.compile(loss='binary_crossentropy', 446 | optimizer=Adam(lr=.0001), 447 | metrics=['accuracy']) 448 | 449 | return model 450 | 451 | 452 | def create_multitask_model(input_dim, n_layers, units, num_dense_shared_layers, dense_shared_layer_size, n_multi_layers, multi_units, output_dim, tasks): 453 | """ 454 | Create a multitask model with LSTM layer(s), shared dense layer(s), separate dense layer(s) 455 | and separate sigmoid outputs. 456 | Args: 457 | input_dim (tuple): Shape of the input, (n_timesteps, n_features). 458 | n_layers (int): Number of initial LSTM layers. 459 | units (int): Number of units in each LSTM layer. 460 | num_dense_shared_layers (int): Number of dense layers following LSTM layer(s). 461 | dense_shared_layer_size (int): Number of units in each dense layer. 462 | n_multi_layers (int): Number of task-specific dense layers. 463 | multi_units (int): Number of units in each task-specific dense layer. 464 | output_dim (int): Number of outputs (1 for binary tasks). 465 | tasks (list): list of the tasks. 466 | Returns: 467 | final_model (Keras model): A compiled model with the provided architecture. 468 | """ 469 | 470 | tasks = [str(t) for t in tasks] 471 | n_tasks = len(tasks) 472 | 473 | # Input layer 474 | x_inputs = Input(shape=input_dim) 475 | 476 | # first layer 477 | if n_layers > 1: 478 | return_seq = True 479 | else: 480 | return_seq = False 481 | 482 | # Shared layers 483 | combined_model = LSTM(units, activation='relu', 484 | input_shape=input_dim, 485 | name='combined', return_sequences=return_seq)(x_inputs) 486 | 487 | for l in range(n_layers - 1): 488 | combined_model = LSTM(units, activation='relu', return_sequences=(l < n_layers - 2))(combined_model) 489 | 490 | for l in range(num_dense_shared_layers): 491 | combined_model = Dense(dense_shared_layer_size, 492 | activation='relu')(combined_model) 493 | 494 | # Individual task layers 495 | if n_multi_layers == 0: 496 | # Only create task-specific output layer. 497 | output_layers = [] 498 | for task_num in range(n_tasks): 499 | output_layers.append(Dense(output_dim, activation='sigmoid', 500 | name=tasks[task_num])(combined_model)) 501 | 502 | else: 503 | # Also create task-specific dense layer.
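# Each task gets its own dense layer (named after the task), followed by its own sigmoid output head ('<task>_output').
# During training, get_mtl_sample_weights() masks each head so a sample only contributes loss to the output for its own cohort.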
504 | task_layers = [] 505 | for task_num in range(n_tasks): 506 | task_layers.append(Dense(multi_units, activation='relu', 507 | name=tasks[task_num])(combined_model)) 508 | 509 | output_layers = [] 510 | for task_layer_num in range(len(task_layers)): 511 | output_layers.append(Dense(output_dim, activation='sigmoid', 512 | name=str(tasks[task_layer_num]) + '_output')(task_layers[task_layer_num])) 513 | 514 | loss_fn = 'binary_crossentropy' 515 | learning_rate = 0.0001 516 | final_model = Model(inputs=x_inputs, outputs=output_layers) 517 | final_model.compile(loss=loss_fn, 518 | optimizer=Adam(lr=learning_rate), 519 | metrics=['accuracy']) 520 | 521 | return final_model 522 | 523 | 524 | def get_mtl_sample_weights(y, cohorts, all_tasks, sample_weights=None): 525 | """ 526 | Generates a dictionary of sample weights for the multitask model that masks out 527 | (and prevents training on) outputs corresponding to cohorts to which a given sample doesn't belong. 528 | Args: 529 | y (Numpy array): The y matrix. 530 | cohorts (Numpy array): cohort membership corresponding to each example, in the same order as y. 531 | all_tasks (list/Numpy array): list of all unique tasks. 532 | sample_weights (list/Numpy array): if samples should be weighted differently during training, 533 | provide a list w len == num_samples where each value is how much 534 | that value should be weighted. 535 | Returns: 536 | sw_dict (dictionary): Dictionary mapping task to list w len == num_samples, where each value is 0 if 537 | the corresponding example does not belong to that task, and either 1 or a sample weight 538 | value (if sample_weights != None) if it does. 539 | """ 540 | 541 | sw_dict = {} 542 | for task in all_tasks: 543 | task_indicator_col = (cohorts == task).astype(int) 544 | if sample_weights: 545 | task_indicator_col = np.array( 546 | task_indicator_col) * np.array(sample_weights) 547 | sw_dict[task] = task_indicator_col 548 | return sw_dict 549 | 550 | 551 | def get_correct_task_mtl_outputs(mtl_output, cohorts, all_tasks): 552 | """ 553 | Gets the output corresponding to the right task given the multitask output. Necessary since 554 | the MTL model will produce an output for each cohort's output, but we only care about the one the example 555 | actually belongs to. 556 | Args: 557 | mtl_output (Numpy array/list): the output of the multitask model. Should be of size n_tasks x n_samples. 558 | cohorts (Numpy array): list of cohort membership for each sample. 559 | all_tasks (list): unique list of tasks (should be in the same order that corresponds with that of the MTL model output.) 560 | Returns: 561 | mtl_output (Numpy array): an array of size n_samples x 1 where each value corresponds to the MTL model's 562 | prediction for the task that that sample belongs to. 563 | """ 564 | 565 | n_tasks = len(all_tasks) 566 | cohort_key = dict(zip(all_tasks, range(n_tasks))) 567 | mtl_output = np.array(mtl_output) 568 | mtl_output = mtl_output[[cohort_key[c] 569 | for c in cohorts], np.arange(len(cohorts))] 570 | return mtl_output 571 | 572 | ################ RUNNING MODELS ############################################### 573 | #################################################################################### 574 | 575 | 576 | def run_separate_models(X_train, y_train, cohorts_train, 577 | X_val, y_val, cohorts_val, 578 | X_test, y_test, cohorts_test, 579 | all_tasks, fname_keys, fname_results, 580 | FLAGS): 581 | """ 582 | Train and evaluate separate models for each task. 
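One model is trained per cohort, using only the training and validation examples whose cohort membership equals that task.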
583 | Results are saved in FLAGS.experiment_name/results: 584 | - The numpy file ending in '_keys' contains the parameters for the model, 585 | and the numpy file ending in '_results' contains the validation AUCs for that 586 | configuration. 587 | - If you run multiple configurations for the same experiment name, 588 | those parameters and results will append to the same files. 589 | - At test time, results are saved into the file beginning 'test_auc_on_separate_'. 590 | The format of results will depend on whether you use bootstrapping or not. With bootstrapping, 591 | minimum, maximum and average AUCs are saved. Without, just the single AUC on the actual 592 | val / test dataset is saved. 593 | Args: 594 | X_train (Numpy array): The X matrix w training examples. 595 | y_train (Numpy array): The y matrix w training examples. 596 | cohorts_train (Numpy array): List of cohort membership for each validation example. 597 | X_val (Numpy array): The X matrix w validation examples. 598 | y_val (Numpy array): The y matrix w validation examples. 599 | cohorts_val (Numpy array): List of cohort membership for each validation example. 600 | X_test (Numpy array): The X matrix w testing examples. 601 | y_test (Numpy array): The y matrix w testing examples. 602 | cohorts_test (Numpy array): List of cohort membership for each testing example. 603 | all_tasks (Numpy array/list): List of tasks. 604 | fname_keys (String): filename where the model parameters will be saved. 605 | fname_results (String): filename where the model AUCs will be saved. 606 | FLAGS (dictionary): all the arguments. 607 | """ 608 | 609 | cohort_aucs = [] 610 | 611 | # if we're testing, just load the model and save results 612 | if FLAGS.test_time: 613 | for task in all_tasks: 614 | model_fname_parts = ['separate', str(task), 'lstm_shared', str(FLAGS.num_lstm_layers), 'layers', str(FLAGS.lstm_layer_size), 'units', 615 | str(FLAGS.num_dense_shared_layers), 'dense_shared', str(FLAGS.dense_shared_layer_size), 'dense_units', 'mortality'] 616 | model_path = FLAGS.experiment_name + \ 617 | '/models/' + "_".join(model_fname_parts) 618 | model = load_model(model_path) 619 | 620 | if FLAGS.test_bootstrap: 621 | all_aucs = bootstrap_predict(X_test, y_test, cohorts_test, task, model, return_everything=True, 622 | test=True, num_bootstrap_samples=FLAGS.num_test_bootstrap_samples) 623 | cohort_aucs.append(np.array(all_aucs)) 624 | 625 | else: 626 | x_test_in_task = X_test[cohorts_test == task] 627 | y_test_in_task = y_test[cohorts_test == task] 628 | 629 | y_pred = model.predict(x_test_in_task) 630 | auc = roc_auc_score(y_test_in_task, y_pred) 631 | cohort_aucs.append(auc) 632 | 633 | suffix = 'single' if not FLAGS.test_bootstrap else 'all' 634 | test_auc_fname = 'test_auc_on_separate_' + suffix 635 | np.save(FLAGS.experiment_name + '/results/' + 636 | test_auc_fname, cohort_aucs) 637 | return 638 | 639 | # otherwise, create and train a model 640 | for task in all_tasks: 641 | 642 | # get training data from cohort 643 | x_train_in_task = X_train[cohorts_train == task] 644 | y_train_in_task = y_train[cohorts_train == task] 645 | 646 | x_val_in_task = X_val[cohorts_val == task] 647 | y_val_in_task = y_val[cohorts_val == task] 648 | 649 | # create & fit model 650 | model = create_single_task_model(FLAGS.num_lstm_layers, FLAGS.lstm_layer_size, 651 | FLAGS.num_dense_shared_layers, FLAGS.dense_shared_layer_size, X_train.shape[1:], 1) 652 | model_fname_parts = ['separate', str(task), 'lstm_shared', str(FLAGS.num_lstm_layers), 'layers', 
str(FLAGS.lstm_layer_size), 'units', 653 | str(FLAGS.num_dense_shared_layers), 'dense_shared', str(FLAGS.dense_shared_layer_size), 'dense_units', 'mortality'] 654 | model_dir = FLAGS.experiment_name + \ 655 | '/checkpoints/' + "_".join(model_fname_parts) 656 | if not os.path.exists(model_dir): 657 | os.makedirs(model_dir) 658 | model_fname = model_dir + '/{epoch:02d}-{val_loss:.2f}.hdf5' 659 | checkpointer = ModelCheckpoint( 660 | model_fname, monitor='val_loss', verbose=1) 661 | early_stopping = EarlyStopping(monitor='val_loss', patience=4) 662 | model.fit(x_train_in_task, y_train_in_task, epochs=FLAGS.epochs, batch_size=100, 663 | callbacks=[checkpointer, early_stopping], 664 | validation_data=(x_val_in_task, y_val_in_task)) 665 | 666 | # make validation predictions & evaluate 667 | preds_for_cohort = model.predict(x_val_in_task, batch_size=128) 668 | 669 | print('AUC of separate model for ', task, ':') 670 | if FLAGS.no_val_bootstrap: 671 | try: 672 | auc = roc_auc_score(y_val_in_task, preds_for_cohort) 673 | except: 674 | auc = np.nan 675 | 676 | cohort_aucs.append(auc) 677 | else: 678 | min_auc, max_auc, avg_auc = bootstrap_predict( 679 | X_val, y_val, cohorts_val, task, model, return_everything=False, num_bootstrap_samples=FLAGS.num_val_bootstrap_samples) 680 | cohort_aucs.append(np.array([min_auc, max_auc, avg_auc])) 681 | auc = avg_auc 682 | print("(min/max/average):") 683 | 684 | print(cohort_aucs[-1]) 685 | 686 | model.save(FLAGS.experiment_name + '/models/' + 687 | "_".join(model_fname_parts)) 688 | 689 | # save results to a file 690 | current_run_params = [FLAGS.num_lstm_layers, FLAGS.lstm_layer_size, 691 | FLAGS.num_dense_shared_layers, FLAGS.dense_shared_layer_size] 692 | try: 693 | separate_model_results = np.load(fname_results) 694 | separate_model_key = np.load(fname_keys) 695 | separate_model_results = np.concatenate( 696 | (separate_model_results, np.expand_dims(cohort_aucs, 0))) 697 | separate_model_key = np.concatenate( 698 | (separate_model_key, np.array([current_run_params]))) 699 | 700 | except: 701 | separate_model_results = np.expand_dims(cohort_aucs, 0) 702 | separate_model_key = np.array([current_run_params]) 703 | 704 | np.save(fname_results, separate_model_results) 705 | np.save(fname_keys, separate_model_key) 706 | print('Saved separate results.') 707 | 708 | 709 | def run_global_model(X_train, y_train, cohorts_train, 710 | X_val, y_val, cohorts_val, 711 | X_test, y_test, cohorts_test, 712 | all_tasks, fname_keys, fname_results, 713 | FLAGS): 714 | """ 715 | Train and evaluate global model. 716 | Results are saved in FLAGS.experiment_name/results: 717 | - The numpy file ending in '_keys' contains the parameters for the model, 718 | and the numpy file ending in '_results' contains the validation AUCs for that 719 | configuration. 720 | - If you run multiple configurations for the same experiment name, 721 | those parameters and results will append to the same files. 722 | - At test time, results are saved into the file beginning 'test_auc_on_global_'. 723 | The format of results will depend on whether you use bootstrapping or not. With bootstrapping, 724 | minimum, maximum and average AUCs are saved. Without, just the single AUC on the actual 725 | val / test dataset is saved. 726 | Args: 727 | X_train (Numpy array): The X matrix w training examples. 728 | y_train (Numpy array): The y matrix w training examples. 729 | cohorts_train (Numpy array): List of cohort membership for each validation example. 
730 | X_val (Numpy array): The X matrix w validation examples. 731 | y_val (Numpy array): The y matrix w validation examples. 732 | cohorts_val (Numpy array): List of cohort membership for each validation example. 733 | X_test (Numpy array): The X matrix w testing examples. 734 | y_test (Numpy array): The y matrix w testing examples. 735 | cohorts_test (Numpy array): List of cohort membership for each testing example. 736 | all_tasks (Numpy array/list): List of tasks. 737 | fname_keys (String): filename where the model parameters will be saved. 738 | fname_results (String): filename where the model AUCs will be saved. 739 | FLAGS (dictionary): all the arguments. 740 | """ 741 | 742 | model_fname_parts = ['global', 'lstm_shared', str(FLAGS.num_lstm_layers), 'layers', str(FLAGS.lstm_layer_size), 'units', 743 | str(FLAGS.num_dense_shared_layers), 'dense_shared', str(FLAGS.dense_shared_layer_size), 'dense_units', 'mortality'] 744 | 745 | if FLAGS.test_time: 746 | model_path = FLAGS.experiment_name + \ 747 | '/models/' + "_".join(model_fname_parts) 748 | model = load_model(model_path) 749 | cohort_aucs = [] 750 | y_pred = model.predict(X_test) 751 | 752 | # all bootstrapped AUCs 753 | for task in all_tasks: 754 | if FLAGS.test_bootstrap: 755 | all_aucs = bootstrap_predict(X_test, y_test, cohorts_test, task, model, return_everything=True, 756 | test=True, num_bootstrap_samples=FLAGS.num_test_bootstrap_samples) 757 | cohort_aucs.append(np.array(all_aucs)) 758 | else: 759 | y_pred_in_cohort = y_pred[cohorts_test == task] 760 | y_true_in_cohort = y_test[cohorts_test == task] 761 | auc = roc_auc_score(y_true_in_cohort, y_pred_in_cohort) 762 | cohort_aucs.append(auc) 763 | 764 | if FLAGS.test_bootstrap: 765 | # Macro AUC 766 | cohort_aucs = np.array(cohort_aucs) 767 | cohort_aucs = np.concatenate( 768 | (cohort_aucs, np.expand_dims(np.mean(cohort_aucs, axis=0), 0))) 769 | 770 | # Micro AUC 771 | all_micro_aucs = bootstrap_predict(X_test, y_test, cohorts_test, 'all', model, 772 | return_everything=True, test=True, num_bootstrap_samples=FLAGS.num_test_bootstrap_samples) 773 | cohort_aucs = np.concatenate( 774 | (cohort_aucs, np.array([all_micro_aucs]))) 775 | 776 | else: 777 | # Macro AUC 778 | macro_auc = np.mean(cohort_aucs) 779 | cohort_aucs.append(macro_auc) 780 | 781 | # Micro AUC 782 | micro_auc = roc_auc_score(y_test, y_pred) 783 | cohort_aucs.append(micro_auc) 784 | 785 | suffix = 'single' if not FLAGS.test_bootstrap else 'all' 786 | test_auc_fname = 'test_auc_on_global_' + suffix 787 | np.save(FLAGS.experiment_name + '/results/' + 788 | test_auc_fname, cohort_aucs) 789 | return 790 | 791 | model = create_single_task_model(FLAGS.num_lstm_layers, FLAGS.lstm_layer_size, 792 | FLAGS.num_dense_shared_layers, FLAGS.dense_shared_layer_size, X_train.shape[1:], 1) 793 | early_stopping = EarlyStopping(monitor='val_loss', patience=4) 794 | model_dir = FLAGS.experiment_name + \ 795 | '/checkpoints/' + "_".join(model_fname_parts) 796 | if not os.path.exists(model_dir): 797 | os.makedirs(model_dir) 798 | model_fname = model_dir + '/{epoch:02d}-{val_loss:.2f}.hdf5' 799 | checkpointer = ModelCheckpoint(model_fname, monitor='val_loss', verbose=1) 800 | 801 | model.fit(X_train, y_train, 802 | epochs=FLAGS.epochs, batch_size=100, 803 | sample_weight=samp_weights, 804 | callbacks=[checkpointer, early_stopping], 805 | validation_data=(X_val, y_val)) 806 | 807 | model.save(FLAGS.experiment_name + '/models/' + 808 | "_".join(model_fname_parts)) 809 | 810 | cohort_aucs = [] 811 | y_pred = model.predict(X_val) 812 | 
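# Evaluate the single global model separately on each cohort's validation examples;
# macro (mean across cohorts) and micro (all samples pooled) AUCs are appended after this loop.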
for task in all_tasks: 813 | print('Global Model AUC on ', task, ':') 814 | if FLAGS.no_val_bootstrap: 815 | try: 816 | auc = roc_auc_score( 817 | y_val[cohorts_val == task], y_pred[cohorts_val == task]) 818 | except: 819 | auc = np.nan 820 | cohort_aucs.append(auc) 821 | else: 822 | min_auc, max_auc, avg_auc = bootstrap_predict( 823 | X_val, y_val, cohorts_val, task, model, num_bootstrap_samples=FLAGS.num_val_bootstrap_samples) 824 | cohort_aucs.append(np.array([min_auc, max_auc, avg_auc])) 825 | print ("(min/max/average): ") 826 | 827 | print(cohort_aucs[-1]) 828 | 829 | cohort_aucs = np.array(cohort_aucs) 830 | 831 | # Add Macro AUC 832 | cohort_aucs = np.concatenate( 833 | (cohort_aucs, np.expand_dims(np.nanmean(cohort_aucs, axis=0), 0))) 834 | 835 | # Add Micro AUC 836 | if FLAGS.no_val_bootstrap: 837 | micro_auc = roc_auc_score(y_val, y_pred) 838 | cohort_aucs = np.concatenate((cohort_aucs, np.array([micro_auc]))) 839 | else: 840 | min_auc, max_auc, avg_auc = bootstrap_predict( 841 | X_val, y_val, cohorts_val, 'all', model, num_bootstrap_samples=FLAGS.num_val_bootstrap_samples) 842 | cohort_aucs = np.concatenate( 843 | (cohort_aucs, np.array([[min_auc, max_auc, avg_auc]]))) 844 | 845 | # Save Results 846 | current_run_params = [FLAGS.num_lstm_layers, FLAGS.lstm_layer_size, 847 | FLAGS.num_dense_shared_layers, FLAGS.dense_shared_layer_size] 848 | try: 849 | print('appending results.') 850 | global_model_results = np.load(fname_results) 851 | global_model_key = np.load(fname_keys) 852 | global_model_results = np.concatenate( 853 | (global_model_results, np.expand_dims(cohort_aucs, 0))) 854 | global_model_key = np.concatenate( 855 | (global_model_key, np.array([current_run_params]))) 856 | 857 | except Exception as e: 858 | global_model_results = np.expand_dims(cohort_aucs, 0) 859 | global_model_key = np.array([current_run_params]) 860 | 861 | np.save(fname_results, global_model_results) 862 | np.save(fname_keys, global_model_key) 863 | print('Saved global results.') 864 | 865 | 866 | def run_multitask_model(X_train, y_train, cohorts_train, 867 | X_val, y_val, cohorts_val, 868 | X_test, y_test, cohorts_test, 869 | all_tasks, fname_keys, fname_results, 870 | FLAGS): 871 | """ 872 | Train and evaluate multitask model. 873 | Results are saved in FLAGS.experiment_name/results: 874 | - The numpy file ending in '_keys' contains the parameters for the model, 875 | and the numpy file ending in '_results' contains the validation AUCs for that 876 | configuration. 877 | - If you run multiple configurations for the same experiment name, 878 | those parameters and results will append to the same files. 879 | - At test time, results are saved into the file beginning 'test_auc_on_separate_'. 880 | The format of results will depend on whether you use bootstrapping or not. With bootstrapping, 881 | minimum, maximum and average AUCs are saved. Without, just the single AUC on the actual 882 | val / test dataset is saved. 883 | Args: 884 | X_train (Numpy array): The X matrix w training examples. 885 | y_train (Numpy array): The y matrix w training examples. 886 | cohorts_train (Numpy array): List of cohort membership for each validation example. 887 | X_val (Numpy array): The X matrix w validation examples. 888 | y_val (Numpy array): The y matrix w validation examples. 889 | cohorts_val (Numpy array): List of cohort membership for each validation example. 890 | X_test (Numpy array): The X matrix w testing examples. 891 | y_test (Numpy array): The y matrix w testing examples. 
892 | cohorts_test (Numpy array): List of cohort membership for each testing example. 893 | all_tasks (Numpy array/list): List of tasks. 894 | fname_keys (String): filename where the model parameters will be saved. 895 | fname_results (String): filename where the model AUCs will be saved. 896 | FLAGS (dictionary): all the arguments. 897 | """ 898 | 899 | model_fname_parts = ['mtl', 'lstm_shared', str(FLAGS.num_lstm_layers), 'layers', str(FLAGS.lstm_layer_size), 'units', 900 | 'dense_shared', str(FLAGS.num_dense_shared_layers), 'layers', str( 901 | FLAGS.dense_shared_layer_size), 'dense_units', 902 | 'specific', str(FLAGS.num_multi_layers), 'layers', str(FLAGS.multi_layer_size), 'specific_units', 'mortality'] 903 | 904 | n_tasks = len(np.unique(cohorts_train)) 905 | cohort_key = dict(zip(all_tasks, range(n_tasks))) 906 | 907 | if FLAGS.test_time: 908 | model_path = FLAGS.experiment_name + \ 909 | '/models/' + "_".join(model_fname_parts) 910 | model = load_model(model_path) 911 | y_pred = model.predict(X_test) 912 | 913 | cohort_aucs = [] 914 | for task in all_tasks: 915 | if FLAGS.test_bootstrap: 916 | all_aucs = bootstrap_predict(X_test, y_test, cohorts_test, 917 | task=task, model=model, return_everything=True, test=True, 918 | all_tasks=all_tasks, 919 | num_bootstrap_samples=FLAGS.num_test_bootstrap_samples) 920 | cohort_aucs.append(np.array(all_aucs)) 921 | else: 922 | y_pred_in_cohort = y_pred[cohorts_test == 923 | task, cohort_key[task]] 924 | y_true_in_cohort = y_test[cohorts_test == task] 925 | auc = roc_auc_score(y_true_in_cohort, y_pred_in_cohort) 926 | cohort_aucs.append(auc) 927 | 928 | if FLAGS.test_bootstrap: 929 | cohort_aucs = np.array(cohort_aucs) 930 | cohort_aucs = np.concatenate( 931 | (cohort_aucs, np.expand_dims(np.mean(cohort_aucs, axis=0), 0))) 932 | 933 | all_micro_aucs = bootstrap_predict(X_test, y_test, cohorts_test, 'all', model, return_everything=True, test=True, 934 | all_tasks=all_tasks, num_bootstrap_samples=FLAGS.num_test_bootstrap_samples) 935 | cohort_aucs = np.concatenate( 936 | (cohort_aucs, np.array([all_micro_aucs]))) 937 | 938 | else: 939 | macro_auc = np.mean(cohort_aucs) 940 | cohort_aucs.append(macro_auc) 941 | micro_auc = roc_auc_score(y_test, y_pred[np.arange(len(y_test)), [ 942 | cohort_key[c] for c in cohorts_test]]) 943 | cohort_aucs.append(micro_auc) 944 | 945 | suffix = 'single' if not FLAGS.test_bootstrap else 'all' 946 | test_auc_fname = 'test_auc_on_multitask_' + suffix 947 | np.save(FLAGS.experiment_name + '/results/' + 948 | test_auc_fname, cohort_aucs) 949 | return 950 | 951 | # model 952 | mtl_model = create_multitask_model(X_train.shape[1:], FLAGS.num_lstm_layers, 953 | FLAGS.lstm_layer_size, FLAGS.num_dense_shared_layers, FLAGS.dense_shared_layer_size, 954 | FLAGS.num_multi_layers, FLAGS.multi_layer_size, output_dim=1, tasks=all_tasks) 955 | 956 | early_stopping = EarlyStopping(monitor='val_loss', patience=4) 957 | 958 | model_dir = FLAGS.experiment_name + \ 959 | '/checkpoints/' + "_".join(model_fname_parts) 960 | if not os.path.exists(model_dir): 961 | os.makedirs(model_dir) 962 | model_fname = model_dir + '/{epoch:02d}-{val_loss:.2f}.hdf5' 963 | checkpointer = ModelCheckpoint(model_fname, monitor='val_loss', verbose=1) 964 | mtl_model.fit(X_train, [y_train for i in range(n_tasks)], 965 | batch_size=100, 966 | epochs=FLAGS.epochs, 967 | verbose=1, 968 | sample_weight=get_mtl_sample_weights( 969 | y_train, cohorts_train, all_tasks, sample_weights=samp_weights), 970 | callbacks=[early_stopping, checkpointer], 971 | 
validation_data=(X_val, [y_val for i in range(n_tasks)])) 972 | 973 | mtl_model.save(FLAGS.experiment_name + '/models/' + 974 | "_".join(model_fname_parts)) 975 | 976 | cohort_aucs = [] 977 | 978 | y_pred = get_correct_task_mtl_outputs( 979 | mtl_model.predict(X_val), cohorts_val, all_tasks) 980 | 981 | # task aucs 982 | for task in all_tasks: 983 | print('Multitask AUC on', task, ': ') 984 | if FLAGS.no_val_bootstrap: 985 | y_pred_in_task = y_pred[cohorts_val == task] 986 | try: 987 | auc = roc_auc_score(y_val[cohorts_val == task], y_pred_in_task) 988 | except: 989 | auc = np.nan 990 | cohort_aucs.append(auc) 991 | else: 992 | min_auc, max_auc, avg_auc = bootstrap_predict( 993 | X_val, y_val, cohorts_val, task, mtl_model, all_tasks=all_tasks, num_bootstrap_samples=FLAGS.num_val_bootstrap_samples) 994 | cohort_aucs.append(np.array([min_auc, max_auc, avg_auc])) 995 | print("(min/max/average):") 996 | 997 | print(cohort_aucs[-1]) 998 | 999 | # macro average 1000 | cohort_aucs = np.array(cohort_aucs) 1001 | cohort_aucs = np.concatenate( 1002 | (cohort_aucs, np.expand_dims(np.nanmean(cohort_aucs, axis=0), 0))) 1003 | 1004 | # micro average 1005 | if FLAGS.no_val_bootstrap: 1006 | cohort_aucs = np.concatenate( 1007 | (cohort_aucs, np.array([roc_auc_score(y_val, y_pred)]))) 1008 | else: 1009 | min_auc, max_auc, avg_auc = bootstrap_predict( 1010 | X_val, y_val, cohorts_val, 'all', mtl_model, all_tasks=all_tasks, num_bootstrap_samples=FLAGS.num_val_bootstrap_samples) 1011 | cohort_aucs = np.concatenate( 1012 | (cohort_aucs, np.array([[min_auc, max_auc, avg_auc]]))) 1013 | 1014 | current_run_params = [FLAGS.num_lstm_layers, FLAGS.lstm_layer_size, FLAGS.num_dense_shared_layers, 1015 | FLAGS.dense_shared_layer_size, FLAGS.num_multi_layers, FLAGS.multi_layer_size] 1016 | 1017 | try: 1018 | multitask_model_results = np.load(fname_results) 1019 | multitask_model_key = np.load(fname_keys) 1020 | multitask_model_results = np.concatenate( 1021 | (multitask_model_results, np.expand_dims(cohort_aucs, 0))) 1022 | multitask_model_key = np.concatenate( 1023 | (multitask_model_key, np.array([current_run_params]))) 1024 | 1025 | except: 1026 | multitask_model_results = np.expand_dims(cohort_aucs, 0) 1027 | multitask_model_key = np.array([current_run_params]) 1028 | 1029 | np.save(fname_results, multitask_model_results) 1030 | np.save(fname_keys, multitask_model_key) 1031 | print('Saved multitask results.') 1032 | 1033 | ################ LOAD & PROCESS DATA ############################################### 1034 | #################################################################################### 1035 | 1036 | 1037 | def load_processed_data(data_hours=24, gap_time=12): 1038 | """ 1039 | Either read pre-processed data from a saved folder, or load in the raw data and preprocess it. 1040 | Should have the files 'saps.csv' (with columns 'subject_id', 'hadm_id', 'icustay_id', 'sapsii') 1041 | and 'code_status.csv' (with columns 'subject_id', 'hadm_id', 'icustay_id', timednr_chart, 'timecmo_chart', 'timecmo_nursingnote') 1042 | in the local directory. 1043 | 1044 | Args: 1045 | data_hours (int): hours of data to use for predictions. 1046 | gap_time (int): gap between last data hour and time of prediction. 1047 | Returns: 1048 | X (Numpy array): matrix of data of size n_samples x n_timesteps x n_features. 1049 | Y (Numpy array): binary array of len n_samples corresponding to in hospital mortality after the gap time. 1050 | careunits (Numpy array): array w careunit membership of each sample. 
1051 | saps_quartile (Numpy array): array w SAPS quartile of each sample. 1052 | subject_ids (Numpy array): subject_ids corresponding to each row of the X/Y/careunits/saps_quartile arrays. 1053 | """ 1054 | save_data_path = 'data/mortality_' + str(data_hours) + '/' 1055 | 1056 | # see if we already have the data matrices saved 1057 | try: 1058 | X = np.load(save_data_path + 'X.npy') 1059 | careunits = np.load(save_data_path + 'careunits.npy') 1060 | saps_quartile = np.load(save_data_path + 'saps_quartile.npy') 1061 | subject_ids = np.load(save_data_path + 'subject_ids.npy') 1062 | Y = np.load(save_data_path + 'Y.npy') 1063 | print('Loaded data from ' + save_data_path) 1064 | print('shape of X: ', X.shape) 1065 | 1066 | # otherwise create them 1067 | except Exception as e: 1068 | data_cutoff = data_hours 1069 | mort_cutoff = data_hours + gap_time 1070 | 1071 | X, static = load_phys_data() 1072 | 1073 | # Add SAPS Score to static matrix 1074 | saps = pd.read_csv('data/saps.csv') 1075 | ser, bins = pd.qcut(saps.sapsii, 4, retbins=True, labels=False) 1076 | saps['sapsii_quartile'] = pd.cut( 1077 | saps.sapsii, bins=bins, labels=False, include_lowest=True) 1078 | saps = saps[['subject_id', 'hadm_id', 'icustay_id', 'sapsii_quartile']] 1079 | static = pd.merge(static, saps, how='left', on=[ 1080 | 'subject_id', 'hadm_id', 'icustay_id']) 1081 | 1082 | # Add Mortality Outcome 1083 | deathtimes = static[['subject_id', 'hadm_id', 1084 | 'icustay_id', 'deathtime', 'dischtime']].dropna() 1085 | deathtimes_valid = deathtimes[deathtimes.dischtime >= 1086 | deathtimes.deathtime] 1087 | deathtimes_valid['mort_hosp_valid'] = True 1088 | cmo = pd.read_csv('data/code_status.csv') 1089 | cmo = cmo[cmo.cmo > 0] 1090 | cmo['timednr_chart'] = pd.to_datetime(cmo.timednr_chart) 1091 | cmo['timecmo_chart'] = pd.to_datetime(cmo.timecmo_chart) 1092 | cmo['timecmo_nursingnote'] = pd.to_datetime(cmo.timecmo_nursingnote) 1093 | cmo['cmo_min_time'] = cmo.loc[:, [ 1094 | 'timednr_chart', 'timecmo_chart', 'timecmo_nursingnote']].min(axis=1) 1095 | all_mort_times = pd.merge(deathtimes_valid, cmo, on=['subject_id', 'hadm_id', 'icustay_id'], how='outer')[ 1096 | ['subject_id', 'hadm_id', 'icustay_id', 'deathtime', 'dischtime', 'cmo_min_time']] 1097 | all_mort_times['deathtime'] = pd.to_datetime(all_mort_times.deathtime) 1098 | all_mort_times['cmo_min_time'] = pd.to_datetime( 1099 | all_mort_times.cmo_min_time) 1100 | all_mort_times['min_mort_time'] = all_mort_times.loc[:, 1101 | ['deathtime', 'cmo_min_time']].min(axis=1) 1102 | min_mort_time = all_mort_times[[ 1103 | 'subject_id', 'hadm_id', 'icustay_id', 'min_mort_time']] 1104 | static = pd.merge(static, min_mort_time, on=[ 1105 | 'subject_id', 'hadm_id', 'icustay_id'], how='left') 1106 | static['mort_hosp_valid'] = np.invert(np.isnat(static.min_mort_time)) 1107 | 1108 | # For those who died, filter for at least 36 hours of data 1109 | static['time_til_mort'] = pd.to_datetime( 1110 | static.min_mort_time) - pd.to_datetime(static.intime) 1111 | static['time_til_mort'] = static.time_til_mort.apply( 1112 | lambda x: x.total_seconds()/3600) 1113 | 1114 | static['time_in_icu'] = pd.to_datetime( 1115 | static.dischtime) - pd.to_datetime(static.intime) 1116 | static['time_in_icu'] = static.time_in_icu.apply( 1117 | lambda x: x.total_seconds()/3600) 1118 | 1119 | static = static[((static.time_in_icu >= data_cutoff) & ( 1120 | static.mort_hosp_valid == False)) | (static.time_til_mort >= mort_cutoff)] 1121 | 1122 | # Make discrete values and cut off stay at 24 hours 1123 | X_discrete = 
1124 |         X_discrete = X_discrete[X_discrete.hours_in < data_cutoff]
1125 |         X_discrete = X_discrete[[
1126 |             c for c in X_discrete.columns if c not in ['hadm_id', 'icustay_id']]]
1127 | 
1128 |         # Pad people whose records stop early
1129 |         test = X_discrete.set_index(['subject_id', 'hours_in'])
1130 |         extra_hours = test.groupby(level=0).apply(_pad_df, data_cutoff)
1131 |         extra_hours = extra_hours[extra_hours != 0].reset_index()
1132 |         extra_hours.columns = ['subject_id', 'pad_hrs']
1133 |         pad_tuples = []
1134 |         for s in extra_hours.subject_id:
1135 |             for hr in list(extra_hours[extra_hours.subject_id == s].pad_hrs)[0]:
1136 |                 pad_tuples.append((s, hr))
1137 |         pad_df = pd.DataFrame(0, index=pd.MultiIndex.from_tuples(
1138 |             pad_tuples, names=('subject_id', 'hours_in')), columns=test.columns)
1139 |         new_df = pd.concat([test, pad_df], axis=0)
1140 | 
1141 |         # get the static vars we want, make them discrete columns
1142 |         static_to_keep = static[['subject_id', 'gender', 'age', 'ethnicity',
1143 |             'sapsii_quartile', 'first_careunit', 'mort_hosp_valid']]
1144 |         static_to_keep.loc[:, 'ethnicity'] = static_to_keep['ethnicity'].apply(
1145 |             categorize_ethnicity)
1146 |         static_to_keep.loc[:, 'age'] = static_to_keep['age'].apply(
1147 |             categorize_age)
1148 |         static_to_keep = pd.get_dummies(static_to_keep, columns=[
1149 |             'gender', 'age', 'ethnicity'])
1150 | 
1151 |         # merge the phys with static
1152 |         X_full = pd.merge(new_df.reset_index(), static_to_keep,
1153 |             on='subject_id', how='inner')
1154 |         X_full = X_full.set_index(['subject_id', 'hours_in'])
1155 | 
1156 |         # print mortality per careunit
1157 |         mort_by_careunit = X_full.groupby(
1158 |             'subject_id')['first_careunit', 'mort_hosp_valid'].max()
1159 |         for cu in mort_by_careunit.first_careunit.unique():
1160 |             print(cu + ": " + str(np.sum(mort_by_careunit[mort_by_careunit.first_careunit == cu].mort_hosp_valid)) + ' out of ' + str(
1161 |                 len(mort_by_careunit[mort_by_careunit.first_careunit == cu])))
1162 | 
1163 |         # create Y and cohort matrices
1164 |         subject_ids = X_full.index.get_level_values(0).unique()
1165 |         Y = X_full[['mort_hosp_valid']].groupby(level=0).max()
1166 |         careunits = X_full[['first_careunit']].groupby(level=0).max()
1167 |         saps_quartile = X_full[['sapsii_quartile']].groupby(level=0).max()
1168 |         Y = Y.reindex(subject_ids)
1169 |         careunits = careunits.reindex(subject_ids)
1170 |         saps_quartile = saps_quartile.reindex(subject_ids)
1171 | 
1172 |         # delete those columns from the X matrix
1173 |         X_full = X_full.loc[:, X_full.columns != 'mort_hosp_valid']
1174 |         X_full = X_full.loc[:, X_full.columns != 'sapsii_quartile']
1175 |         X_full = X_full.loc[:, X_full.columns != 'first_careunit']
1176 | 
1177 |         feature_names = X_full.columns
1178 |         np.save('feature_names.npy', feature_names)
1179 | 
1180 |         # get the data as a np matrix of size num_examples x timesteps x features
1181 |         X_full_matrix = np.reshape(
1182 |             X_full.as_matrix(), (len(subject_ids), data_cutoff, -1))
1183 |         print("Shape of X: ")
1184 |         print(X_full_matrix.shape)
1185 | 
1186 |         # print feature names
1187 |         print("Features : ")
1188 |         print(np.array(X_full.columns))
1189 | 
1190 |         print(subject_ids)
1191 |         print(Y.index)
1192 |         print(careunits.index)
1193 | 
1194 |         print("Number of positive examples : ", len(Y[Y == 1]))
1195 | 
1196 |         if not os.path.exists(save_data_path):
1197 |             os.makedirs(save_data_path)
1198 | 
1199 |         np.save(save_data_path + 'X.npy', X_full_matrix)
1200 |         np.save(save_data_path + 'careunits.npy',
1201 |             np.squeeze(careunits.as_matrix(), axis=1))
1202 |         np.save(save_data_path + 'saps_quartile.npy',
1203 |             np.squeeze(saps_quartile.as_matrix(), axis=1))
1204 |         np.save(save_data_path + 'subject_ids.npy', np.array(subject_ids))
1205 |         np.save(save_data_path + 'Y.npy', np.squeeze(Y.as_matrix(), axis=1))
1206 | 
1207 |         X = X_full_matrix
1208 | 
1209 |     return X, Y, careunits, saps_quartile, subject_ids
1210 | 
1211 | 
1212 | ################ RUN THINGS ####################################################
1213 | ####################################################################################
1214 | if __name__ == "__main__":
1215 | 
1216 |     FLAGS = get_args()
1217 | 
1218 |     # Limit GPU usage.
1219 |     os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_num
1220 |     config = tf.ConfigProto()
1221 |     config.gpu_options.allow_growth = True  # Don't grab all GPU memory up front
1222 |     config.allow_soft_placement = True  # Fall back to CPU if an op has no GPU kernel
1223 |     K.tensorflow_backend.set_session(tf.Session(config=config))
1224 | 
1225 |     # Make folders for the results & models
1226 |     for folder in ['results', 'models', 'checkpoints']:
1227 |         if not os.path.exists(os.path.join(FLAGS.experiment_name, folder)):
1228 |             os.makedirs(os.path.join(FLAGS.experiment_name, folder))
1229 | 
1230 |     # The files that we'll save model configurations and results to
1231 |     sw = 'with_sample_weights' if FLAGS.sample_weights else 'no_sample_weights'
1232 |     sw = '' if FLAGS.model_type == 'SEPARATE' else sw
1233 |     fname_keys = FLAGS.experiment_name + '/results/' + \
1234 |         '_'.join([FLAGS.model_type.lower(), 'model_keys', sw]) + '.npy'
1235 |     fname_results = FLAGS.experiment_name + '/results/' + \
1236 |         '_'.join([FLAGS.model_type.lower(), 'model_results', sw]) + '.npy'
1237 | 
1238 |     # Check that we haven't already run this configuration
1239 |     if os.path.exists(fname_keys) and not FLAGS.repeats_allowed:
1240 |         model_key = np.load(fname_keys)
1241 |         current_run = [FLAGS.num_lstm_layers, FLAGS.lstm_layer_size,
1242 |             FLAGS.num_dense_shared_layers, FLAGS.dense_shared_layer_size]
1243 |         if FLAGS.model_type == "MULTITASK":
1244 |             current_run = current_run + \
1245 |                 [FLAGS.num_multi_layers, FLAGS.multi_layer_size]
1246 |         print('Now running :', current_run)
1247 |         print('Have already run: ', model_key.tolist())
1248 |         if current_run in model_key.tolist():
1249 |             print('Have already run this configuration. Now skipping this one.')
1250 |             sys.exit(0)
1251 | 
1252 |     # Load Data
1253 |     X, Y, careunits, saps_quartile, subject_ids = load_processed_data(
1254 |         FLAGS.data_hours, FLAGS.gap_time)
1255 |     Y = Y.astype(int)
1256 | 
1257 |     # Define cohort membership (the tasks)
1258 |     if FLAGS.cohorts == 'careunit':
1259 |         cohort_col = careunits
1260 |     elif FLAGS.cohorts == 'saps':
1261 |         cohort_col = saps_quartile
1262 |     elif FLAGS.cohorts == 'custom':
1263 |         cohort_col = np.load('cluster_membership/' + FLAGS.cohort_filepath)
1264 |         cohort_col = np.array([str(c) for c in cohort_col])
1265 | 
1266 |     # Include cohort membership as an additional feature
1267 |     if FLAGS.include_cohort_as_feature:
1268 |         cohort_col_onehot = pd.get_dummies(cohort_col).as_matrix()
1269 |         cohort_col_onehot = np.expand_dims(cohort_col_onehot, axis=1)
1270 |         cohort_col_onehot = np.tile(cohort_col_onehot, (1, FLAGS.data_hours, 1))  # one copy per timestep
1271 |         X = np.concatenate((X, cohort_col_onehot), axis=-1)
1272 | 
1273 |     # Train, val, test split
1274 |     X_train, X_val, X_test, \
1275 |         y_train, y_val, y_test, \
1276 |         cohorts_train, cohorts_val, cohorts_test = stratified_split(
1277 |             X, Y, cohort_col, train_val_random_seed=FLAGS.train_val_random_seed)
1278 | 
1279 |     # Sample Weights
1280 |     task_weights = dict()
1281 |     all_tasks = np.unique(cohorts_train)
1282 |     for cohort in all_tasks:
1283 |         num_in_cohort = len(np.where(cohorts_train == cohort)[0])
1284 |         print("Number of people in cohort " +
1285 |             str(cohort) + ": " + str(num_in_cohort))
1286 |         task_weights[cohort] = len(X_train)*1.0/num_in_cohort
1287 | 
1288 |     if FLAGS.sample_weights:
1289 |         samp_weights = np.array([task_weights[cohort]
1290 |             for cohort in cohorts_train])
1291 | 
1292 |     else:
1293 |         samp_weights = None
1294 | 
1295 |     # Run model
1296 |     run_model_args = [X_train, y_train, cohorts_train,
1297 |         X_val, y_val, cohorts_val,
1298 |         X_test, y_test, cohorts_test,
1299 |         all_tasks, fname_keys, fname_results,
1300 |         FLAGS]
1301 | 
1302 |     if FLAGS.model_type == 'SEPARATE':
1303 |         run_separate_models(*run_model_args)
1304 |     elif FLAGS.model_type == 'GLOBAL':
1305 |         run_global_model(*run_model_args)
1306 |     elif FLAGS.model_type == 'MULTITASK':
1307 |         run_multitask_model(*run_model_args)
1308 | 
--------------------------------------------------------------------------------
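
A minimal usage sketch (an editorial illustration, not a file in the repository): it mirrors the __main__ block of run_mortality_prediction.py by building or loading the processed matrices and then producing a careunit-stratified train/validation/test split. load_processed_data, stratified_split, and the returned names come from the code above; the argument values 24, 12, and the seed 0 are illustrative, and it assumes the repository root is the working directory with the required data/ files in place.

# Sketch only: drive the preprocessing and split programmatically.
from run_mortality_prediction import load_processed_data, stratified_split

# Build (or load cached) matrices; X has shape (num_examples, data_hours, num_features).
X, Y, careunits, saps_quartile, subject_ids = load_processed_data(24, 12)
Y = Y.astype(int)

# Stratify the split by first care unit (the 'careunit' cohort option in __main__).
X_train, X_val, X_test, \
    y_train, y_val, y_test, \
    cohorts_train, cohorts_val, cohorts_test = stratified_split(
        X, Y, careunits, train_val_random_seed=0)

print('Train / val / test examples:', len(X_train), len(X_val), len(X_test))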