├── latent_patient_trajectories ├── __init__.py ├── BERT │ ├── __init__.py │ ├── constants.py │ ├── data_processor.py │ ├── load_and_yield_embeddings.py │ ├── model.py │ └── continuous_pretraining_data_processor.py ├── representation_learner │ ├── __init__.py │ ├── constants.py │ ├── sample_task_generalizability_config.json │ ├── sample_hyperopt_config.json │ ├── utils.py │ ├── fts_decoder.py │ ├── fine_tune.py │ ├── single_task.py │ ├── model.py │ ├── task_generalizability.py │ ├── validate_weights.py │ ├── adapted_model.py │ └── run_model.py ├── pytorch_helpers.py ├── utils.py └── constants.py ├── Sample Args ├── mimic_MT_FT │ ├── task_generalizability_config.json │ ├── task_generalizability_exp_args.json │ └── task_generalizability_model_base_args.json ├── cmo_mimic_ST │ ├── example_command.sh │ └── args.json └── eicu_masked_PT │ └── args.json ├── Scripts ├── run_model.py ├── fine_tune_task.py ├── evaluate.py ├── task_generalizability.py └── hyperopt_model.py ├── .gitignore ├── env.yml └── README.md /latent_patient_trajectories/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /latent_patient_trajectories/BERT/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/constants.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Sample Args/mimic_MT_FT/task_generalizability_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpus_per_model": 1, 3 | "models_per_gpu": 1, 4 | "gpus": [0, 0, 0, 1, 1, 2, 2, 3, 3, 3] 5 | } 6 | -------------------------------------------------------------------------------- /Sample Args/cmo_mimic_ST/example_command.sh: -------------------------------------------------------------------------------- 1 | /crimea/conda_envs/latent_patient_trajectories/bin/python run_model.py --do_load_from_dir --run_dir ../Sample\ Args/cmo_mimic_ST 2 | -------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/sample_task_generalizability_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "gpus_per_model": 1, 3 | "models_per_gpu": 2, 4 | "gpus": [0, 1, 3] 5 | } 6 | -------------------------------------------------------------------------------- /Scripts/run_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | 4 | from tqdm import tqdm 5 | 6 | from latent_patient_trajectories.constants import * 7 | from latent_patient_trajectories.representation_learner.args import * 8 | from latent_patient_trajectories.representation_learner.run_model import * 9 | 10 | 11 | if __name__=="__main__": 12 | args = Args.from_commandline() 13 | 14 | main(args, tqdm=tqdm) 15 | -------------------------------------------------------------------------------- /Scripts/fine_tune_task.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | from latent_patient_trajectories.constants import * 8 | from latent_patient_trajectories.representation_learner.args import FineTuneArgs 9 | from latent_patient_trajectories.representation_learner.fine_tune import * 10 | 11 | 12 | if __name__=="__main__": 13 | args = FineTuneArgs.from_commandline() 14 | main(args, tqdm) 15 | -------------------------------------------------------------------------------- /Scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | 4 | from tqdm import tqdm 5 | 6 | from latent_patient_trajectories.constants import * 7 | from latent_patient_trajectories.representation_learner.args import * 8 | from latent_patient_trajectories.representation_learner.evaluator import * 9 | 10 | 11 | if __name__=="__main__": 12 | args = EvalArgs.from_commandline() 13 | print(args.run_dir) 14 | main(args, tqdm=tqdm, datasets=None) 15 | -------------------------------------------------------------------------------- /Scripts/task_generalizability.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | from latent_patient_trajectories.constants import * 8 | from latent_patient_trajectories.representation_learner.args import * 9 | from latent_patient_trajectories.representation_learner.task_generalizability import * 10 | 11 | 12 | if __name__=="__main__": 13 | args = TaskGeneralizabilityArgs.from_commandline() 14 | main(args) 15 | -------------------------------------------------------------------------------- /latent_patient_trajectories/pytorch_helpers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class MultitaskHead(nn.Module): 4 | def __init__(self, config, task_dims): 5 | super().__init__() 6 | self.task_layers = nn.ModuleDict({ 7 | task: nn.Linear(config.hidden_size, task_dim) for task, task_dim in task_dims.items() 8 | }) 9 | 10 | def forward(self, pooled_output): 11 | return {task: layer(pooled_output) for task, layer in self.task_layers.items()} 12 | -------------------------------------------------------------------------------- /Sample Args/mimic_MT_FT/task_generalizability_exp_args.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_dir": "/crimea/mmd/comprehensive_MTL_EHR/Sample Args/mimic_MT_FT", 3 | "do_eicu": false, 4 | "rotation": 0, 5 | "do_eval": false, 6 | "do_train": false, 7 | "do_fine_tune": true, 8 | "do_fine_tune_eval": true, 9 | "do_frozen_representation": true, 10 | "do_free_representation": true, 11 | "do_match_FT_train_windows": false, 12 | "slurm": false, 13 | "partition": "p100", 14 | "slurm_args": "", 15 | "do_small_data": true, 16 | "do_imbalanced_sex_data": false, 17 | "do_imbalanced_race_data": false, 18 | "train_embedding_after": -1, 19 | "do_single_task": false 20 | } 21 | -------------------------------------------------------------------------------- /Scripts/hyperopt_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../') 3 | 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | from latent_patient_trajectories.constants import * 8 | from 
latent_patient_trajectories.representation_learner.args import * 9 | from latent_patient_trajectories.representation_learner.hyperparameter_search import * 10 | 11 | if __name__=="__main__": 12 | args = HyperparameterSearchArgs.from_commandline() 13 | 14 | trials = main(args, tqdm=tqdm) 15 | 16 | with open(os.path.join(args.search_dir, str(args.rotation), 'trials.pkl'), mode='wb') as f: 17 | pickle.dump(trials, f) 18 | 19 | print(trials.best_trial['result']) 20 | print(trials.best_trial['misc']['vals']) 21 | -------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/sample_hyperopt_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "do_add_cls_analog": {"method": "choice", "params": [true, false]}, 3 | "notes": {"method": "constant", "params": "no_notes"}, 4 | "batch_size": {"method": "constant", "params": 512}, 5 | "epochs": {"method": "constant", "params": 1}, 6 | "batches_per_gradient": {"method": "quniform", "params": [1, 10]}, 7 | "do_train_note_bert": {"method": "choice", "params": [true, false]}, 8 | "in_dim": {"method": "quniform", "params": [16, 128]}, 9 | "hidden_size_multiplier": {"method": "quniform", "params": [4, 32]}, 10 | "intermediate_size": {"method": "quniform", "params": [16, 256]}, 11 | "num_attention_heads": {"method": "quniform", "params": [2, 8]}, 12 | "num_hidden_layers": {"method": "quniform", "params": [2, 8]}, 13 | "learning_rate": {"method": "loguniform", "params": [-5, -1]}, 14 | "note_bert_lr_reduce": {"method": "uniform", "params": [1, 4]} 15 | } 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | ./**/train_means.pickle 107 | ./Scripts/timeseries_encoder/train_means.pickle 108 | ./Scripts/timeseries_encoder/train_stds.pickle 109 | 110 | 111 | ./Scripts/timeseries_encoder/tmp_output/* 112 | ./**/*.pickle 113 | 114 | ./*/*/Archive/* 115 | 116 | # tags 117 | tags 118 | 119 | # results 120 | results 121 | -------------------------------------------------------------------------------- /latent_patient_trajectories/BERT/constants.py: -------------------------------------------------------------------------------- 1 | import enum, os 2 | 3 | from ..constants import * 4 | 5 | #BERT_MODEL_LOCATION = os.path.join( 6 | # os.environ['ML4H_BASE'], 7 | # 'pretrained_models', 8 | # 'pretrained_bert_tf', 9 | # 'biobert_pretrain_output_all_notes_150000' 10 | #) 11 | 12 | BERT_MODEL_LOCATION_ALT = os.path.join( 13 | RUNS_DIR, 14 | 'BERT', 15 | 'biobert_pretrain_output_all_notes_150000' 16 | ) 17 | BERT_MODEL_LOCATION = BERT_MODEL_LOCATION_ALT 18 | 19 | 20 | PRETRAINED_BERT_HIDDEN_DIM = 768 21 | 22 | NOTES_AS_SINGLE_SENTENCE_FILENAME = 'notes_single_sentence.hdf' 23 | NOTES_AS_SENTENCE_SEQS_FILENAME = 'notes_split_sentences.hdf' 24 | NOTES_AS_SENTENCE_VEC_SEQS_FILENAME = 'all_sentences_as_vecs.hdf' 25 | BERT_RUNS_DIR = os.path.join(RUNS_DIR, 'BERT') 26 | 27 | # TODO: why do we need PAD or UNK? 28 | UNK, SEP, PAD, CLS, MASK = "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]" 29 | CONTROL_TOKENS = [UNK, SEP, PAD, CLS, MASK] 30 | 31 | NUM_CONTROL_TOKENS = len(CONTROL_TOKENS) 32 | def set_at(i, v=1): 33 | def f(arr): 34 | arr[0, i] = v # TODO: Maybe generalize? 
35 | return arr 36 | return f 37 | CONTROL_VECTOR_PREFIXES = { 38 | token: set_at(i, 1) for i, token in enumerate(CONTROL_TOKENS) 39 | #token: np.array(([0]*i) + [1] + ([0]*(len(CONTROL_TOKENS)-i))) for i, token in enumerate(CONTROL_TOKENS) 40 | } 41 | 42 | # For reading notes as sentence seqs: 43 | ALL = 'all' 44 | 45 | # Notes 46 | NOTE_ORDER_COLS = ['chartdate', 'charttime'] 47 | NOTE_ID_COLS = [ICUSTAY_ID, HADM_ID, 'category'] + NOTE_ORDER_COLS 48 | NOTE_ID = 'note_id' 49 | 50 | # NoteBERT 51 | STATIC_CLINICAL_BERT_RUNS_DIR = os.path.join(BERT_RUNS_DIR, 'static_clinical_BERT') 52 | DATASET_FILENAME = 'train_dataset.torch' 53 | PROCESSOR_FILENAME = 'processor.pkl' 54 | EXAMPLES_FILENAME = 'examples.pkl' 55 | 56 | SEQUENCES_ORDERED = 'Sequence Order' 57 | class SequenceOrderType(enum.Enum): 58 | ORDERED_AND_FROM_SAME_STREAM = 0 59 | NOT_ORDERED_BUT_FROM_SAME_STREAM = 1 60 | NOT_FROM_SAME_STREAM = 2 61 | 62 | LABEL_ENUMS = { 63 | SEQUENCES_ORDERED: SequenceOrderType, 64 | } 65 | -------------------------------------------------------------------------------- /Sample Args/mimic_MT_FT/task_generalizability_model_base_args.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_seq_len": 24, 3 | "modeltype": "GRU", 4 | "run_dir": "/crimea/mmd/comprehensive_MTL_EHR/Sample Args/mimic_MT_FT", 5 | "model_file_template": "model", 6 | "do_overwrite": true, 7 | "rotation": 0, 8 | "dataset_dir": "/crimea/latent_patient_trajectories/dataset/rotations/no_notes/0", 9 | "num_dataloader_workers": 6, 10 | "do_eicu": false, 11 | "epochs": 29, 12 | "do_train": true, 13 | "do_eval_train": true, 14 | "do_eval_tuning": true, 15 | "do_eval_test": true, 16 | "train_save_every": 1, 17 | "batches_per_gradient": 1, 18 | "set_to_eval_mode": "", 19 | "notes": "no_notes", 20 | "do_train_note_bert": true, 21 | "in_dim": 32, 22 | "hidden_size": 239, 23 | "intermediate_size": 128, 24 | "num_attention_heads": 4, 25 | "num_hidden_layers": 2, 26 | "batch_size": 253, 27 | "learning_rate": 0.0012814754999746147, 28 | "do_learning_rate_decay": false, 29 | "learning_rate_decay": 0.13723968469651246, 30 | "learning_rate_step": 1, 31 | "note_bert_lr_reduce": 1, 32 | "kernel_sizes": [ 33 | 7, 34 | 7, 35 | 5, 36 | 3 37 | ], 38 | "num_filters": [ 39 | 10, 40 | 100, 41 | 100, 42 | 5 43 | ], 44 | "dropout": 0.5, 45 | "gru_num_hidden": 2, 46 | "gru_hidden_layer_size": 255, 47 | "gru_pooling_method": "last", 48 | "task_weights_filepath": "", 49 | "regression_task_weight": 0, 50 | "do_add_cls_analog": false, 51 | "do_masked_imputation": false, 52 | "imputation_mask_rate": 0.0, 53 | "hidden_dropout_prob": 0.21803518086570073, 54 | "pooling_method": "max", 55 | "pooling_kernel_size": 4, 56 | "pooling_stride": null, 57 | "conv_layers_per_pool": 1, 58 | "do_bidirectional": true, 59 | "fc_layer_sizes": [ 60 | 256 61 | ], 62 | "do_weight_decay": false, 63 | "weight_decay": 0.9374042825319748, 64 | "gru_fc_layer_sizes": [ 65 | 134 66 | ], 67 | "ablate": [], 68 | "frac_data": 1.0, 69 | "frac_data_seed": 0, 70 | "frac_female": 1.0, 71 | "frac_black": 1.0, 72 | "balanced_race": false, 73 | "do_test_run": false, 74 | "do_detect_anomaly": false 75 | } 76 | -------------------------------------------------------------------------------- /Sample Args/cmo_mimic_ST/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_seq_len": 24, 3 | "modeltype": "GRU", 4 | "run_dir": "/crimea/mmd/comprehensive_MTL_EHR/Sample Args/cmo_mimic_ST", 5 | 
"model_file_template": "model", 6 | "do_overwrite": true, 7 | "rotation": 0, 8 | "dataset_dir": "/crimea/latent_patient_trajectories/dataset/rotations/no_notes/0", 9 | "num_dataloader_workers": 6, 10 | "do_eicu": false, 11 | "epochs": 33, 12 | "do_train": true, 13 | "do_eval_train": true, 14 | "do_eval_tuning": true, 15 | "do_eval_test": true, 16 | "train_save_every": 1, 17 | "batches_per_gradient": 1, 18 | "set_to_eval_mode": "", 19 | "notes": "no_notes", 20 | "do_train_note_bert": true, 21 | "in_dim": 32, 22 | "hidden_size": 186, 23 | "intermediate_size": 128, 24 | "num_attention_heads": 4, 25 | "num_hidden_layers": 2, 26 | "batch_size": 456, 27 | "learning_rate": 0.0006027742749367915, 28 | "do_learning_rate_decay": false, 29 | "learning_rate_decay": 0.18358687739255003, 30 | "learning_rate_step": 1, 31 | "note_bert_lr_reduce": 1, 32 | "kernel_sizes": [ 33 | 7, 34 | 7, 35 | 5, 36 | 3 37 | ], 38 | "num_filters": [ 39 | 10, 40 | 100, 41 | 100, 42 | 5 43 | ], 44 | "dropout": 0.5, 45 | "gru_num_hidden": 4, 46 | "gru_hidden_layer_size": 327, 47 | "gru_pooling_method": "last", 48 | "task_weights_filepath": "", 49 | "regression_task_weight": 0, 50 | "do_add_cls_analog": false, 51 | "do_masked_imputation": false, 52 | "do_fake_masked_imputation_shape": false, 53 | "imputation_mask_rate": 0.0, 54 | "hidden_dropout_prob": 0.2124292039511747, 55 | "pooling_method": "max", 56 | "pooling_kernel_size": 4, 57 | "pooling_stride": null, 58 | "conv_layers_per_pool": 1, 59 | "do_bidirectional": false, 60 | "fc_layer_sizes": [ 61 | 256 62 | ], 63 | "do_weight_decay": false, 64 | "weight_decay": 0.9841213242009911, 65 | "gru_fc_layer_sizes": [ 66 | 355, 67 | 410 68 | ], 69 | "ablate": [ 70 | "icd10", 71 | "discharge", 72 | "mortality", 73 | "los", 74 | "readmission", 75 | "future_treatment_sequence", 76 | "acuity", 77 | "next_timepoint_info", 78 | "dnr" 79 | ], 80 | "frac_data": 1.0, 81 | "frac_data_seed": 0, 82 | "frac_female": 1.0, 83 | "frac_black": 1.0, 84 | "balanced_race": false, 85 | "do_test_run": false, 86 | "do_detect_anomaly": false 87 | } -------------------------------------------------------------------------------- /Sample Args/eicu_masked_PT/args.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_seq_len": 48, 3 | "modeltype": "GRU", 4 | "run_dir": "/crimea/mmd/comprehensive_MTL_EHR/Sample Args/eicu_masked_PT", 5 | "model_file_template": "model", 6 | "do_overwrite": true, 7 | "rotation": 0, 8 | "dataset_dir": "/crimea/latent_patient_trajectories/dataset_eicu/rotations/no_notes/0", 9 | "num_dataloader_workers": 6, 10 | "do_eicu": true, 11 | "epochs": 24, 12 | "do_train": true, 13 | "do_eval_train": true, 14 | "do_eval_tuning": true, 15 | "do_eval_test": true, 16 | "train_save_every": 1, 17 | "batches_per_gradient": 1, 18 | "set_to_eval_mode": "", 19 | "notes": "no_notes", 20 | "do_train_note_bert": true, 21 | "in_dim": 32, 22 | "hidden_size": 16, 23 | "intermediate_size": 128, 24 | "num_attention_heads": 4, 25 | "num_hidden_layers": 2, 26 | "batch_size": 83, 27 | "learning_rate": 0.000719710118270521, 28 | "do_learning_rate_decay": false, 29 | "learning_rate_decay": 0.15787933224069312, 30 | "learning_rate_step": 1, 31 | "note_bert_lr_reduce": 1, 32 | "kernel_sizes": [ 33 | 7, 34 | 7, 35 | 5, 36 | 3 37 | ], 38 | "num_filters": [ 39 | 10, 40 | 100, 41 | 100, 42 | 5 43 | ], 44 | "dropout": 0.5, 45 | "gru_num_hidden": 1, 46 | "gru_hidden_layer_size": 375, 47 | "gru_pooling_method": "max", 48 | "task_weights_filepath": "", 49 | 
"regression_task_weight": 0, 50 | "do_add_cls_analog": false, 51 | "do_masked_imputation": true, 52 | "do_fake_masked_imputation_shape": false, 53 | "imputation_mask_rate": 0.15, 54 | "hidden_dropout_prob": 0.19207323120675984, 55 | "pooling_method": "max", 56 | "pooling_kernel_size": 4, 57 | "pooling_stride": null, 58 | "conv_layers_per_pool": 1, 59 | "do_bidirectional": false, 60 | "fc_layer_sizes": [ 61 | 256 62 | ], 63 | "do_weight_decay": false, 64 | "weight_decay": 0.3391976095387189, 65 | "gru_fc_layer_sizes": [ 66 | 503, 67 | 697 68 | ], 69 | "ablate": [ 70 | "icd10", 71 | "discharge", 72 | "mortality", 73 | "los", 74 | "readmission", 75 | "future_treatment_sequence", 76 | "acuity", 77 | "next_timepoint_info", 78 | "dnr", 79 | "cmo" 80 | ], 81 | "frac_data": 1.0, 82 | "frac_data_seed": 0, 83 | "frac_female": 1.0, 84 | "frac_black": 1.0, 85 | "balanced_race": false, 86 | "do_test_run": false, 87 | "do_detect_anomaly": false 88 | } 89 | -------------------------------------------------------------------------------- /latent_patient_trajectories/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib, pickle, numpy as np#, ml_toolkit.pandas_constructions as pdc 2 | 3 | import pandas as pd 4 | import re 5 | 6 | def get_index_levels(df, levels, make_objects_categories=True): 7 | #todo: make this package a requirement so env doesn't break 8 | #from ml_toolkit 9 | df_2 = pd.DataFrame(index=df.index) 10 | for level in levels: df_2[level] = df_2.index.get_level_values(level) 11 | if make_objects_categories: 12 | for column in df_2.columns: 13 | if df_2[column].dtype == object: df_2[column] = df_2[column].astype('category') 14 | return df_2 15 | 16 | def freq_or_count(x, type_at_1=int): 17 | #assert type(x) in (float, int), "x must be either a float or an integer." 18 | x = float(x) 19 | assert x >= 0, "x must be nonnegative" 20 | if x < 1: return x 21 | elif x > 1 and int(x) == x: return int(x) 22 | elif x == 1: return type_at_1(x) 23 | else: raise NotImplementedError("x cannot be coerced to frequency or count!") 24 | 25 | def __nested_sorted_repr(c): 26 | if type(c) in (set, frozenset): return tuple(sorted(c)) 27 | if type(c) is dict: return tuple(sorted([(k, __nested_sorted_repr(v)) for k, v in c.items()])) 28 | if type(c) in (tuple, list): return tuple([__nested_sorted_repr(e) for e in c]) 29 | else: return c 30 | 31 | def hash_dict(d): return hash_repr(__nested_sorted_repr(d)) 32 | def hash_repr(tup): 33 | m = hashlib.new('md5') 34 | m.update(repr(tup).encode('utf-8')) 35 | return m.hexdigest() 36 | 37 | def pad(l, max_len, pad_value = 0): 38 | # only df is tested. 39 | if type(l) is list: return l + ([pad_value]*(max_len - len(l))) 40 | elif type(l) is np.ndarray: 41 | try: return l.resize((max_len - len(l), pad_value)) 42 | except ValueError as e: pass 43 | 44 | return np.concatenate((l, pad_value * np.ones([max_len - l.shape[0]] + list(l.shape[1:])))) 45 | raise NotImplementedError("Only supports lists or numpy arrays at present.") 46 | 47 | # TODO(mmd): Consider moving to pdc 48 | def add_id_col(df, id_idxs, id_col_name): 49 | assert len(id_idxs) > 1, "Too few id idx columns." 
50 | 
51 |     df[id_col_name] = [hash_repr(tuple(str(x) for x in a)) for a in get_index_levels(df, id_idxs).values]
52 |     df.set_index(id_col_name, append=True, inplace=True)
53 | 
54 | def depickle(filepath):
55 |     with open(filepath, mode='rb') as f: return pickle.load(f)
56 | def read_txt(filepath):
57 |     with open(filepath, mode='r') as f: return f.read()
58 | 
59 | # For debugging pandas apply/transform ops.
60 | def print_and_raise(*args, **kwargs):
61 |     print(args, kwargs)
62 |     raise NotImplementedError
63 | 
64 | def zip_dicts_assert(*dcts):
65 |     d0 = dcts[0]
66 |     s0 = set(d0.keys())
67 |     for d in dcts[1:]:
68 |         s1 = set(d.keys())
69 |         assert d0.keys() == d.keys(), f"Keys Disagree! d0 - d1 = {s0 - s1}, d1 - d0 = {s1 - s0}"
70 | 
71 |     for i in set(dcts[0]).intersection(*dcts[1:]): yield (i,) + tuple(d[i] for d in dcts)
72 | 
73 | def zip_dicts(*dcts):
74 |     for i in set(dcts[0]).intersection(*dcts[1:]): yield (i,) + tuple(d[i] for d in dcts)
75 | 
76 | def zip_dicts_union(*dcts):
77 |     keys = set(dcts[0].keys())
78 |     for d in dcts[1:]: keys.update(d.keys())
79 | 
80 |     for k in keys: yield (k,) + tuple(d[k] if k in d else np.NaN for d in dcts)
81 | 
82 | def tokenize_str(str_):
83 |     # keep only alphanumeric characters and punctuation
84 |     str_ = re.sub(r'[^A-Za-z0-9(),.!?\'`]', ' ', str_)
85 |     # remove multiple whitespace characters
86 |     str_ = re.sub(r'\s{2,}', ' ', str_)
87 |     # punctuation to tokens
88 |     str_ = re.sub(r'\(', ' ( ', str_)
89 |     str_ = re.sub(r'\)', ' ) ', str_)
90 |     str_ = re.sub(r',', ' , ', str_)
91 |     str_ = re.sub(r'\.', ' . ', str_)
92 |     str_ = re.sub(r'!', ' ! ', str_)
93 |     str_ = re.sub(r'\?', ' ? ', str_)
94 |     # split contractions into multiple tokens
95 |     str_ = re.sub(r'\'s', ' \'s', str_)
96 |     str_ = re.sub(r'\'ve', ' \'ve', str_)
97 |     str_ = re.sub(r'n\'t', ' n\'t', str_)
98 |     str_ = re.sub(r'\'re', ' \'re', str_)
99 |     str_ = re.sub(r'\'d', ' \'d', str_)
100 |     str_ = re.sub(r'\'ll', ' \'ll', str_)
101 |     # lower case
102 |     return str_.strip().lower().split()
103 | 
104 | def not_none(stuff):
105 |     return stuff is not None
106 | 
107 | def is_none(stuff):
108 |     return stuff is None
109 | 
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
1 | name: /crimea/conda_envs/latent_patient_trajectories
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - anaconda
6 | - defaults
7 | dependencies:
8 | - asn1crypto=0.24.0=py37_0
9 | - attrs=19.1.0=py_0
10 | - backcall=0.1.0=py_0
11 | - bcrypt=3.1.6=py37h7b6447c_0
12 | - blas=1.0=mkl
13 | - bleach=3.1.0=py_0
14 | - blosc=1.15.0=hd408876_0
15 | - bokeh=1.1.0=py37_0
16 | - boto=2.49.0=py37_0
17 | - boto3=1.9.134=py_0
18 | - botocore=1.12.134=py_0
19 | - bottleneck=1.2.1=py37h035aef0_1
20 | - bz2file=0.98=py37_1
21 | - bzip2=1.0.6=h14c3975_5
22 | - ca-certificates=2020.10.14=0
23 | - certifi=2020.6.20=pyhd3eb1b0_3
24 | - cffi=1.12.2=py37h2e261b9_1
25 | - cftime=1.0.3.4=py37hdd07704_0
26 | - chardet=3.0.4=py37_1
27 | - click=7.0=py37_0
28 | - cloudpickle=0.8.1=py_0
29 | - colorama=0.4.1=py37_0
30 | - cryptography=2.6.1=py37h1ba5d50_0
31 | - cudatoolkit=10.0.130=0
32 | - curl=7.64.1=hbc83047_0
33 | - cycler=0.10.0=py37_0
34 | - cytoolz=0.9.0.1=py37h14c3975_1
35 | - dask=1.2.0=py_0
36 | - dask-core=1.2.0=py_0
37 | - dbus=1.13.2=h714fa37_1
38 | - decorator=4.4.0=py_0
39 | - defusedxml=0.5.0=py_1
40 | - distributed=1.27.1=py37_0
41 | - docutils=0.14=py37_0
42 | - entrypoints=0.3=py37_1000
43 | -
expat=2.2.6=he6710b0_0 44 | - fontconfig=2.13.1=he4413a7_1000 45 | - freetype=2.9.1=h8a8886c_1 46 | - gensim=3.4.0=py37h14c3975_0 47 | - gettext=0.19.8.1=hc5be6a0_1002 48 | - glib=2.56.2=had28632_1001 49 | - gst-plugins-base=1.14.0=hbbd80ab_1 50 | - gstreamer=1.14.0=hb453b48_1 51 | - hdf4=4.2.13=h3ca952b_2 52 | - hdf5=1.10.4=hb1b8bf9_0 53 | - heapdict=1.0.0=py37_2 54 | - icu=58.2=hf484d3e_1000 55 | - idna=2.8=py37_0 56 | - intel-openmp=2019.3=199 57 | - ipykernel=5.1.0=py37h24bf2e0_1002 58 | - ipython=7.4.0=py37h24bf2e0_0 59 | - ipython_genutils=0.2.0=py_1 60 | - ipywidgets=7.4.2=py_0 61 | - jedi=0.13.3=py37_0 62 | - jinja2=2.10.1=py_0 63 | - jmespath=0.9.4=py_0 64 | - jpeg=9b=h024ee3a_2 65 | - jsonschema=3.0.1=py37_0 66 | - jupyter=1.0.0=py_2 67 | - jupyter_client=5.2.4=py_3 68 | - jupyter_console=6.0.0=py_0 69 | - jupyter_contrib_core=0.3.3=py_2 70 | - jupyter_core=4.4.0=py_0 71 | - kiwisolver=1.0.1=py37hf484d3e_0 72 | - krb5=1.16.1=h173b8e3_7 73 | - libcurl=7.64.1=h20c2e04_0 74 | - libedit=3.1.20181209=hc058e9b_0 75 | - libffi=3.2.1=hd88cf55_4 76 | - libgcc-ng=8.2.0=hdf63c60_1 77 | - libgfortran-ng=7.3.0=hdf63c60_0 78 | - libiconv=1.15=h516909a_1005 79 | - libnetcdf=4.6.1=h11d0813_2 80 | - libpng=1.6.36=hbc83047_0 81 | - libpq=11.2=h20c2e04_0 82 | - libsodium=1.0.16=h14c3975_1001 83 | - libssh2=1.8.2=h1ba5d50_0 84 | - libstdcxx-ng=8.2.0=hdf63c60_1 85 | - libtiff=4.0.10=h2733197_2 86 | - libuuid=2.32.1=h14c3975_1000 87 | - libxcb=1.13=h14c3975_1002 88 | - libxml2=2.9.9=h13577e0_0 89 | - libxslt=1.1.32=h4785a14_1002 90 | - locket=0.2.0=py37_1 91 | - lxml=4.3.3=py37h7ec2d77_0 92 | - lzo=2.10=h49e0be7_2 93 | - markupsafe=1.1.1=py37h14c3975_0 94 | - matplotlib=3.0.3=py37h5429711_0 95 | - mistune=0.8.4=py37h14c3975_1000 96 | - mkl=2019.3=199 97 | - mkl_fft=1.0.10=py37ha843d7b_0 98 | - mkl_random=1.0.2=py37hd81dba3_0 99 | - msgpack-python=0.6.1=py37hfd86e86_1 100 | - nbconvert=5.4.1=py_2 101 | - nbformat=4.4.0=py_1 102 | - ncurses=6.1=he6710b0_1 103 | - netcdf4=1.4.2=py37h808af73_0 104 | - ninja=1.9.0=py37hfd86e86_0 105 | - notebook=5.7.8=py37_0 106 | - numexpr=2.6.9=py37h9e4a6bb_0 107 | - numpy=1.16.2=py37h7e9f1db_0 108 | - numpy-base=1.16.2=py37hde5b4d6_0 109 | - olefile=0.46=py37_0 110 | - openssl=1.1.1h=h7b6447c_0 111 | - packaging=19.0=py37_0 112 | - pandas=0.24.2=py37he6710b0_0 113 | - pandoc=2.7.2=0 114 | - pandocfilters=1.4.2=py_1 115 | - paramiko=2.4.2=py37_0 116 | - parso=0.4.0=py_0 117 | - partd=0.3.10=py37_1 118 | - pcre=8.43=he6710b0_0 119 | - pexpect=4.7.0=py37_0 120 | - pickleshare=0.7.5=py37_1000 121 | - pillow=6.0.0=py37h34e0f95_0 122 | - pip=19.0.3=py37_0 123 | - prometheus_client=0.6.0=py_0 124 | - prompt_toolkit=2.0.9=py_0 125 | - psutil=5.6.2=py37h7b6447c_0 126 | - psycopg2=2.7.6.1=py37h1ba5d50_0 127 | - pthread-stubs=0.4=h14c3975_1001 128 | - ptyprocess=0.6.0=py_1001 129 | - pyasn1=0.4.5=py_0 130 | - pycparser=2.19=py37_0 131 | - pygments=2.3.1=py_0 132 | - pynacl=1.3.0=py37h7b6447c_0 133 | - pyopenssl=19.0.0=py37_0 134 | - pyparsing=2.4.0=py_0 135 | - pyqt=5.9.2=py37h05f1152_2 136 | - pyrsistent=0.14.11=py37h14c3975_0 137 | - pysocks=1.6.8=py37_0 138 | - pytables=3.5.1=py37h71ec239_0 139 | - python=3.7.3=h0371630_0 140 | - python-dateutil=2.8.0=py37_0 141 | - pytorch=1.0.1=py3.7_cuda10.0.130_cudnn7.4.2_2 142 | - pytz=2019.1=py_0 143 | - pyyaml=5.1=py37h14c3975_0 144 | - pyzmq=18.0.1=py37hc4ba49a_1 145 | - qt=5.9.7=h5867ecd_1 146 | - qtconsole=4.4.3=py_0 147 | - readline=7.0=h7b6447c_5 148 | - requests=2.21.0=py37_0 149 | - s3transfer=0.2.0=py37_0 150 | - 
scikit-learn=0.20.3=py37hd81dba3_0 151 | - scipy=1.2.1=py37h7c811a0_0 152 | - seaborn=0.11.0=py_0 153 | - send2trash=1.5.0=py_0 154 | - setuptools=41.0.0=py37_0 155 | - sip=4.19.8=py37hf484d3e_0 156 | - six=1.12.0=py37_0 157 | - smart_open=1.8.2=py_0 158 | - snappy=1.1.7=hbae5bb6_3 159 | - sortedcontainers=2.1.0=py37_0 160 | - sqlite=3.27.2=h7b6447c_0 161 | - tblib=1.3.2=py37_0 162 | - terminado=0.8.2=py37_0 163 | - testpath=0.4.2=py_1001 164 | - tk=8.6.8=hbc83047_0 165 | - toolz=0.9.0=py37_0 166 | - torchvision=0.2.2=py_3 167 | - tornado=6.0.2=py37h516909a_0 168 | - tqdm=4.31.1=py37_1 169 | - traitlets=4.3.2=py37_1000 170 | - urllib3=1.24.2=py37_0 171 | - wcwidth=0.1.7=py_1 172 | - webencodings=0.5.1=py_1 173 | - wheel=0.33.1=py37_0 174 | - widgetsnbextension=3.4.2=py37_0 175 | - xarray=0.12.1=py_0 176 | - xorg-libxau=1.0.9=h14c3975_0 177 | - xorg-libxdmcp=1.1.3=h516909a_0 178 | - xz=5.2.4=h14c3975_4 179 | - yaml=0.1.7=h14c3975_1001 180 | - zeromq=4.3.1=hf484d3e_1000 181 | - zict=0.1.4=py37_0 182 | - zlib=1.2.11=h7b6447c_3 183 | - zstd=1.3.7=h0b5b093_0 184 | - pip: 185 | - future==0.18.0 186 | - hyperopt==0.1.2 187 | - networkx==2.3 188 | - pymongo==3.9.0 189 | - pytorch-pretrained-bert==0.6.2 190 | - regex==2019.8.19 191 | - upsetplot==0.4.0 192 | prefix: /crimea/conda_envs/latent_patient_trajectories 193 | -------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/utils.py: -------------------------------------------------------------------------------- 1 | import os, random, sys, torch, numpy as np, pandas as pd, torch.nn as nn 2 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SubsetRandomSampler 3 | idx=pd.IndexSlice 4 | 5 | from ..constants import * 6 | 7 | from tqdm import tqdm 8 | 9 | def one_hot_encode( 10 | cols_to_encode, 11 | df, 12 | vocab, 13 | inplace=False, 14 | ): 15 | # TODO(mmd): Write unit tests for this function! 16 | if not inplace: df = df.copy() 17 | 18 | for col in cols_to_encode: 19 | max_idx = len(vocab[col]) 20 | one_hot_c = np.zeros((len(df), max_idx)) 21 | one_hot_c[list(range(len(df))), df[col]] = 1 22 | new_cols = [(col, t) for t in vocab[col]] 23 | df[new_cols] = pd.DataFrame(one_hot_c, index=df.index) 24 | 25 | # Now remove the old 26 | df.drop(columns=col, inplace=True) 27 | return df 28 | 29 | #def one_hot_encode(cols_to_encode, df, vocab, inplace=False): 30 | # """ 31 | # """ 32 | # # TODO(mmd): Write unit tests for this function! 33 | # if not inplace: df = df.copy() 34 | # 35 | # 36 | # # i.e. 
comfort measures ordered
37 | #
38 | #     # for each encoded column
39 | #     for col in tqdm(cols_to_encode):
40 | #         print("\n", col)
41 | #         print( vocab[col])
42 | #         print(list(set(df[col].values.tolist())))
43 | #         # find the number of words for that column
44 | #         max_idx = len(vocab[col])
45 | #         # create an empty df for the number of words in that column
46 | #         one_hot_c = np.zeros((len(df), max_idx))
47 | #         # why not just use [:, df[col]]? this line is unintuitive
48 | #         one_hot_c[list(range(len(df))), df[col]] = 1
49 | #         # create a new column for each t in vocab
50 | #         new_cols = [(col, t) for t in vocab[col]]
51 | #
52 | #         new_df=pd.get_dummies(df[col])
53 | #         new_df.index=df.index # set one hot df index to be same as original index
54 | #         # df[new_cols]=new_df
55 | #
56 | #         print(one_hot_c[:5, :])
57 | #
58 | #         for i, col in enumerate(new_cols):
59 | #             df[col]=one_hot_c[:, i].ravel()
60 | #         # df.loc[:, new_cols] = one_hot_c # only assign with .loc to suppress the assign-on-copy warning
61 | #
62 | #         # Now remove the old
63 | #         df.drop(columns=col, inplace=True)
64 | #     return df
65 | 
66 | def add_time_since_measured(
67 |     df, init_time_since_measured=100, max_time_since_measured=100, hour_aggregation = 1,
68 | ):
69 |     idx = pd.IndexSlice
70 |     df = df.copy()
71 | 
72 |     df.loc[:, idx[:, 'count']] = df.loc[:, idx[:, 'count']].fillna(0)
73 | 
74 |     is_absent = (df.loc[:, idx[:, 'count']] == 0).astype(int)
75 |     hours_of_absence = is_absent.groupby(ID_COLS).cumsum()
76 |     time_since_measured = hours_of_absence - hours_of_absence[is_absent==0].groupby(ID_COLS).fillna(method='ffill')
77 |     time_since_measured.fillna(init_time_since_measured, inplace=True)
78 |     time_since_measured[time_since_measured > max_time_since_measured] = max_time_since_measured
79 | 
80 |     # Somehow, prior to the rename above and here, the columns index lost its name, so we use level=1 here.
81 |     # But note that this is more brittle.
82 |     time_since_measured.rename(
83 |         columns={'count': 'time_since_measured'}, level=1, inplace=True
84 |     )
85 | 
86 |     if hour_aggregation !=1:
87 |         time_since_measured.loc[:, :] = time_since_measured.values // hour_aggregation * int(hour_aggregation)
88 | 
89 |     df_out = pd.concat((df, time_since_measured), axis=1)
90 |     df_out.sort_index(axis=1, inplace=True)
91 |     return df_out
92 | 
93 | def FullyConnectedNet(
94 |     in_dim, out_dim,
95 |     hidden_sizes = [],
96 |     inner_activation = nn.ReLU,
97 |     out_activation = None,
98 |     dropout_prob = 0.1,
99 |     dropout_layer = nn.Dropout,
100 | ):
101 |     layers = []
102 |     if dropout_prob is not None: layers.append(dropout_layer(p=dropout_prob))
103 |     for hs in hidden_sizes:
104 |         layers.extend([nn.Linear(in_dim, hs), inner_activation()])
105 |         in_dim = hs
106 |     layers.append(nn.Linear(in_dim, out_dim))
107 |     if out_activation is not None: layers.append(out_activation())
108 | 
109 |     return nn.Sequential(*layers)
110 | 
111 | def get_loss_if_labeled(loss_fct):
112 |     def f(out, labels, **kwargs): return (out, None) if labels is None else loss_fct(out, labels, **kwargs)
113 |     return f
114 | 
115 | # TODO(mmd): This looks like it is doing all the same thing....
116 | def weight_init(m):
117 |     '''
118 |     Usage:
119 |         model = Model()
120 |         model.apply(weight_init)
121 |     '''
122 |     if isinstance(m, nn.Conv1d):
123 |         torch.nn.init.xavier_normal_(m.weight.data)
124 |         torch.nn.init.normal_(m.bias.data)
125 |     if isinstance(m, nn.Conv2d):
126 |         torch.nn.init.xavier_normal_(m.weight.data)
127 |         torch.nn.init.normal_(m.bias.data)
128 |     if isinstance(m, nn.Linear):
129 |         torch.nn.init.xavier_normal_(m.weight.data)
130 |         torch.nn.init.normal_(m.bias.data)
131 | 
132 | 
133 | # TODO(mmd): This is probably not needed--can just omit heads or data...
134 | def ablate(all_outputs, ablations):
135 |     """
136 |     all_outputs (dict): {
137 |         'will_be_measured': (will_be_measured_pred, will_be_measured, will_be_measured_loss),
138 |         'next_timepoint': (next_values_pred, next_values, reconstruction_loss),
139 |         'mort_icu': (mort_icu_pred, mort_icu, mort_icu_loss),
140 |         'mort_24': (mort_24_pred, mort_24, mort_24h_loss),
141 |         'disch_24': (disch_24_pred, disch_24, disch_24h_loss),
142 |         'los_left': (los_left_pred, los_left, LOS_left_loss),
143 |         'fts_decoding': (fts_logits, fts_labels, fts_loss),
144 |     }
145 | 
146 |     args.ablate (list): a list containing the keys of all_outputs whose losses should be ablated
147 |     """
148 |     losses=all_outputs.keys()
149 | 
150 |     losses=[l for l in losses if l not in ablations]
151 | 
152 |     # total_loss=Variable(torch.tensor(0)).float()
153 |     # for l in losses:
154 |     #     # print([item.shape for item in all_outputs[l]])
155 |     #     # label, target, loss=all_outputs[l]
156 |     #     # print(label.shape)
157 |     #     # print(target.shape)
158 |     #     # print(loss.shape)
159 |     #     # print(loss)
160 |     #     total_loss+=all_outputs[l][-1].sum() # sum necessary for multi gpu jobs
161 | 
162 |     return torch.stack([all_outputs[l][-1].sum() for l in losses]).sum()
163 | 
--------------------------------------------------------------------------------
/latent_patient_trajectories/BERT/data_processor.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from datetime import datetime
3 | from ..constants import *
4 | 
5 | class InputFeatures(object):
6 |     """A single set of features of data."""
7 | 
8 |     def __init__(self, input_ids, input_mask, segment_ids):
9 |         self.input_ids = input_ids
10 |         self.input_mask = input_mask
11 |         self.segment_ids = segment_ids
12 | 
13 | class InputExample(object):
14 |     """A single training/test example for simple sequence classification."""
15 | 
16 |     def __init__(self, guid, text_a, text_b=None, label=None):
17 |         """Constructs an InputExample.
18 | 
19 |         Args:
20 |             guid: Unique id for the example.
21 |             text_a: string. The untokenized text of the first sequence. For single
22 |             sequence tasks, only this sequence must be specified.
23 |             text_b: (Optional) string. The untokenized text of the second sequence.
24 |             Only must be specified for sequence pair tasks.
25 |             label: (Optional) string. The label of the example. This should be
26 |             specified for train and dev examples, but not for test examples.
27 | """ 28 | self.guid = guid 29 | self.text_a = text_a 30 | self.text_b = text_b 31 | self.label = label 32 | 33 | class DfDataProcessor(object): 34 | def __init__(self, sentence_a_col, sentence_b_col = None): 35 | self.sentence_a_col = sentence_a_col 36 | self.sentence_b_col = sentence_b_col 37 | 38 | def get_examples(self, df, folds=None): 39 | fold_idx = df.index.get_level_values(FOLD_IDX_LVL) 40 | if folds is not None: 41 | df = df[fold_idx.isin(folds)] 42 | 43 | return [InputExample( 44 | guid = str(idx), text_a = r[self.sentence_a_col], 45 | text_b = None if self.sentence_b_col is None else r[self.sentence_b_col], 46 | ) for idx, r in df.iterrows()] 47 | 48 | def convert_example_to_tokens(example, tokenizer, max_len, max_seq_length): 49 | tokens_a = tokenizer.tokenize(example.text_a) 50 | 51 | tokens_b = None 52 | if example.text_b: 53 | tokens_b = tokenizer.tokenize(example.text_b) 54 | seq_len = len(tokens_a) + len(tokens_b) 55 | 56 | # Modifies `tokens_a` and `tokens_b` in place so that the total 57 | # length is less than the specified length. 58 | # Account for [CLS], [SEP], [SEP] with "- 3" 59 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 60 | else: 61 | seq_len = len(tokens_a) 62 | # Account for [CLS] and [SEP] with "- 2" 63 | if len(tokens_a) > max_seq_length - 2: 64 | tokens_a = tokens_a[:(max_seq_length - 2)] 65 | 66 | if seq_len > max_len: 67 | max_len = seq_len 68 | # The convention in BERT is: 69 | # (a) For sequence pairs: 70 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 71 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 72 | # (b) For single sequences: 73 | # tokens: [CLS] the dog is hairy . [SEP] 74 | # type_ids: 0 0 0 0 0 0 0 75 | # 76 | # Where "type_ids" are used to indicate whether this is the first 77 | # sequence or the second sequence. The embedding vectors for `type=0` and 78 | # `type=1` were learned during pre-training and are added to the wordpiece 79 | # embedding vector (and position vector). This is not *strictly* necessary 80 | # since the [SEP] token unambigiously separates the sequences, but it makes 81 | # it easier for the model to learn the concept of sequences. 82 | # 83 | # For classification tasks, the first vector (corresponding to [CLS]) is 84 | # used as as the "sentence vector". Note that this only makes sense because 85 | # the entire model is fine-tuned. 86 | tokens = ["[CLS]"] + tokens_a + ["[SEP]"] 87 | 88 | return (tokens, tokens_b, max_len) 89 | 90 | def convert_tokens_to_features(tokens, tokens_b, tokenizer, max_seq_length): 91 | assert not tokens_b, "Not supported." 92 | if isinstance(tokens, str): 93 | tokens = [tokens] 94 | 95 | if len(tokens) > max_seq_length: 96 | tokens = ["[CLS]"] + tokens[-max_seq_length + 1:] 97 | 98 | segment_ids = [0] * len(tokens) 99 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 100 | 101 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 102 | # tokens are attended to. 103 | input_mask = [1] * len(input_ids) 104 | 105 | # Zero-pad up to the sequence length. 
106 | padding = [0] * (max_seq_length - len(input_ids)) 107 | input_ids += padding 108 | input_mask += padding 109 | segment_ids += padding 110 | 111 | assert len(input_ids) == max_seq_length, \ 112 | "len(x) [%d] != max_seq_length [%d]" % (len(input_ids), max_seq_length) 113 | assert len(input_mask) == max_seq_length, \ 114 | "len(x) [%d] != max_seq_length [%d]" % (len(input_mask), max_seq_length) 115 | assert len(segment_ids) == max_seq_length, \ 116 | "len(x) [%d] != max_seq_length [%d]" % (len(segment_ids), max_seq_length) 117 | 118 | return InputFeatures(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids) 119 | 120 | 121 | def convert_examples_to_features(examples, task_list, max_seq_length, tokenizer): 122 | """Loads a data file into a list of `InputBatch`s.""" 123 | features = [] 124 | max_len = 0 125 | for (ex_index, example) in enumerate(examples): 126 | tokens_for_example, tokens_b, max_len = convert_example_to_tokens( 127 | example, tokenizer, max_len, max_seq_length 128 | ) 129 | 130 | features_for_example = convert_tokens_to_features( 131 | tokens_for_example, tokens_b, tokenizer, max_seq_length 132 | ) 133 | 134 | features.append(features_for_example) 135 | 136 | print('Max Sequence Length: %d' %max_len) 137 | 138 | return features 139 | 140 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 141 | """Truncates a sequence pair in place to the maximum length.""" 142 | 143 | # This is a simple heuristic which will always truncate the longer sequence 144 | # one token at a time. This makes more sense than truncating an equal percent 145 | # of tokens from each, since if one sequence is very short then each token 146 | # that's truncated likely contains more information than a longer sequence. 147 | while True: 148 | total_length = len(tokens_a) + len(tokens_b) 149 | if total_length <= max_length: 150 | break 151 | if len(tokens_a) > len(tokens_b): 152 | tokens_a.pop() 153 | else: 154 | tokens_b.pop() 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Comprehensive EHR Timeseries Pre-training Benchmark 2 | Source code for our paper (https://dl.acm.org/doi/pdf/10.1145/3450439.3451877) defining a pre-training benchmark system for EHR timeseries data. 3 | Contact mmd@mit.edu and bretnestor@cs.toronto.edu with any questions. Pending interest from the community, we're eager to make this as usable as possible, and will respond promptly to any issues or questions. 4 | 5 | # Install 6 | 7 | Set up the repository 8 | ``` 9 | conda env create --name comprehensive_EHR_PT -f env.yml 10 | conda activate comprehensive_EHR_PT 11 | ``` 12 | 13 | # Obtaining Data 14 | Copies of pre-processed dataset splits used in the paper can be obtained via Google Cloud. To access them, you 15 | must ensure that you have obtained GCP access via physionet.org for the requisite datasets. See 16 | [https://mimic.physionet.org/gettingstarted/cloud/](https://mimic.physionet.org/gettingstarted/cloud/) for 17 | instructions on obtaining Physionet GCP access. 18 | 19 | 1. MIMIC-III Dataset: [https://console.cloud.google.com/storage/browser/ehr_pretraining_benchmark_mimic](https://console.cloud.google.com/storage/browser/ehr_pretraining_benchmark_mimic) 20 | 2. 
eICU Dataset: [https://console.cloud.google.com/storage/browser/ehr_pretraining_benchmark_eicu](https://console.cloud.google.com/storage/browser/ehr_pretraining_benchmark_eicu)
21 | 
22 | # Usage
23 | ## Args in General
24 | Arguments for all scripts are described in the `latent_patient_trajectories/representation_learner/args.py`
25 | file. This file has some base classes, then argument classes (with specific args requested) for all functions.
26 | It is a good reference to determine what a specific script expects. Note that these classes allow you to (and we
27 | recommend you do) pre-set all args for a script in an (appropriately named) json file in the relevant experiment
28 | directory, then simply pass the directory to the given script (via the appropriate `arg`) and
29 | add `--do_load_from_dir`, at which point the script will load all arguments from the json file
30 | automatically. Note that some args (specifically `regression_task_weight`, which should always be 0, `notes`, which
31 | should always be `no_notes`, and `task_weights_filepath`, which should always be `''`) are holdovers from older versions of the code and can be largely ignored. Similarly, the modeltype-specific args corresponding to the CNN, self-attention, and linear projection models are no longer used. Some sample args for different settings are given in `Sample Args`. Please raise a GitHub issue or contact mmd@mit.edu or bretnestor@cs.toronto.edu with any questions.
32 | 
33 | ## Hyperparameter Tuning
34 | To perform hyperparameter tuning, set up a base experimental directory and add a config file describing your
35 | hyperparameter search in that directory. This file must be named according to the `HYP_CONFIG_FILENAME`
36 | constant in the `latent_patient_trajectories/constants.py` file, which is (as of 7/20/20) set to
37 | `hyperparameter_search_config.json`. A sample config file is shown in
38 | `latent_patient_trajectories/representation_learner/sample_hyperopt_config.json`.
39 | 
40 | Then, run the script `Scripts/hyperopt_model.py` (with appropriate args, as described in the
41 | `args.py` file referenced above under the class `HyperparameterSearchArgs`) to kick off your search.
42 | 
43 | ## Generic Runs
44 | To perform a generic run, training a multi- or single-task model or a masked imputation model, use the `Args`
45 | class in `args.py` and the `run_model.py` script. As with everything else, you will need to specify a base
46 | directory, along with many other args describing the architecture you want to use and training details.
47 | 
48 | ### Evaluation
49 | Evaluating a pre-trained run can be accomplished with the `EvalArgs` class and the `evaluate.py` script. You
50 | will need to specify the model's training directory (e.g., the directory passed to `run_model.py`) so the
51 | script knows which model to reload.
52 | 
53 | To convert evaluation results into a form that is human-readable and aggregated across tasks, use the `get_manuscript_metrics*` functions (e.g., https://github.com/mmcdermott/comprehensive_MTL_EHR/blob/master/latent_patient_trajectories/representation_learner/evaluator.py#L41) on the output dictionaries. These functions just re-process the more granular output of `evaluate.py` into a more readable form.
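
Such an evaluation run can be launched like the other scripts; for example (a hypothetical invocation, assuming `EvalArgs` supports the same `--do_load_from_dir` convention shown in `Sample Args/cmo_mimic_ST/example_command.sh`):
```
python Scripts/evaluate.py --do_load_from_dir --run_dir /path/to/your/run_dir
```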
54 | 
55 | To see an example of where these functions are called, you can look within the hyperparameter tuning code here: https://github.com/mmcdermott/comprehensive_MTL_EHR/blob/master/latent_patient_trajectories/representation_learner/hyperparameter_search.py#L724
56 | and then here: https://github.com/mmcdermott/comprehensive_MTL_EHR/blob/master/latent_patient_trajectories/representation_learner/hyperparameter_search.py#L664
57 | 
58 | That code shows where the output of the evaluator's main function is parsed into the expected input of the `get_manuscript_metrics_via_args` function.
59 | 
60 | ## Task-Generalizability Analyses
61 | These runs consist of pre-training a model either via masked imputation or via multi-task pre-training, in
62 | which the model is pre-trained on all tasks except for one held-out task, then fine-tuning the
63 | model on that held-out task, and evaluating both the pre-trained and fine-tuned models on all tasks. This
64 | could be done manually, through repeated use of the `run_model.py` script with the `--ablate` arg,
65 | but we have a helper script that can manage doing all requisite runs across multiple GPUs in parallel on a
66 | single machine. To use this, you must first create a base directory for this experiment, which will ultimately
67 | hold all runs associated with this experiment (both pre-trained and fine-tuned); see the layout sketch at the end of this README. In this directory, you
68 | will specify the model's base args (which will be duplicated and used in all pre-training and fine-tuning
69 | experiments, with the ablate arg automatically adjusted to perform the appropriate experiments) as a `json`
70 | file which is parseable by the `Args` class (note that when run normally, models will write such a file to
71 | disk in their directory, so you can just copy the file from the model you want to examine), as well
72 | as a configuration describing which GPUs you have available on the system, how many models you want to run
73 | on each GPU at a given time, and how many GPUs each model needs (usually both of the latter are 1). There is a
74 | sample config available in
75 | `latent_patient_trajectories/representation_learner/sample_task_generalizability_config.json`; note that
76 | the config must be renamed to `task_generalizability_config.json` for actual use. You can additionally specify
77 | args according to the `TaskGeneralizabilityArgs` class in the `args.py` file.
78 | 
79 | ## Analysis Notebooks
80 | All our results can be analyzed via the `All Results.ipynb` notebook. Input files for this notebook are
81 | available upon request -- as they aggregate across both MIMIC-III and eICU, we cannot use GCP and must instead
82 | validate your physionet access directly.
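
As an illustration of the task-generalizability setup described above, the experiment directory might be laid out as follows before launching `Scripts/task_generalizability.py` (the directory name here is hypothetical; the file names follow `Sample Args/mimic_MT_FT`):
```
my_task_generalizability_exp/
├── task_generalizability_config.json           # GPU / parallelism config (see the sample config)
├── task_generalizability_exp_args.json         # args per the TaskGeneralizabilityArgs class
└── task_generalizability_model_base_args.json  # base model args, duplicated for each run
```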
83 | 
--------------------------------------------------------------------------------
/latent_patient_trajectories/representation_learner/fts_decoder.py:
--------------------------------------------------------------------------------
1 | from ..pytorch_helpers import *
2 | from .constants import *
3 | 
4 | import torch, torch.nn as nn
5 | 
6 | def mask_and_avg_sequential_loss(loss_tensor, seq_lengths):
7 |     # loss_tensor  = [batch, seq_len]
8 |     # seq_lengths  = [batch,]
9 | 
10 |     max_seq_len = loss_tensor.shape[1]
11 |     seq_indices = torch.arange(max_seq_len).unsqueeze(0).expand_as(loss_tensor)
12 |     seq_lengths = seq_lengths.unsqueeze(1).expand_as(loss_tensor)
13 |     seq_lengths_mask = seq_indices < seq_lengths
14 | 
15 |     loss_masked = torch.where(seq_lengths_mask, loss_tensor, torch.zeros_like(loss_tensor))
16 |     loss_masked_averaged = loss_masked.sum(dim=1) / seq_lengths_mask.float().sum(dim=1)
17 | 
18 |     return loss_masked_averaged
19 | 
20 | def embed_classifier(
21 |     in_dim,
22 |     out_dim,
23 |     embeddings_or_none = None,
24 | ):
25 |     if embeddings_or_none is None: return nn.Linear(in_dim, out_dim)
26 | 
27 |     num_elements, embedding_dim = embeddings_or_none.shape
28 | 
29 |     assert out_dim == embedding_dim, "Embeddings mismatched for output!"
30 | 
31 |     classifier = nn.Linear(embedding_dim, num_elements)
32 |     classifier.weight = embeddings_or_none
33 | 
34 |     if in_dim == embedding_dim: return classifier
35 | 
36 |     projection = nn.Linear(in_dim, embedding_dim)
37 |     return nn.Sequential(projection, classifier)
38 | 
39 | ### Predictors
40 | # We try two predictors--one which uses the fact that the treatments sub-part of this task is actually
41 | # multi-task. The other just embeds all combinations and outputs over all.
42 | 
43 | # TODO: make actually work.
44 | # class MultiTaskPredictor(nn.Module):
45 | #     def __init__(
46 | #         self,
47 | #         in_dim,
48 | #         treatments, # Assumes all are binary for now.
49 | #         mort_pp_dim = len(MortalityPPLabels),
50 | #         treatments_embed_weights = None,
51 | #         mort_pp_embed_weights = None,
52 | #     ):
53 | #         super().__init__()
54 | #
55 | #         self.log_softmax = nn.LogSoftmax() # TODO(mmd): dim
56 | #         self.log_sigmoid = nn.LogSigmoid()
57 | #         self.nll_loss = nn.NLLLoss(reduction='none')
58 | #
59 | #         self.treatments_classifier = embed_classifier(in_dim, len(treatments), treatments_embed_weights)
60 | #         self.mort_pp_classifier = embed_classifier(in_dim, mort_pp_dim, mort_pp_embed_weights)
61 | #         self.is_at_end_classifier = nn.Linear(in_dim, 1)
62 | #
63 | #     def forward(self, X, mort_pp_label, treatment_label):
64 | #         is_at_end_logit = self.is_at_end_classifier(X)
65 | #         mort_pp_logits = self.mort_pp_classifier(X)
66 | #         treatments_logits = self.treatments_classifier(X)
67 | #
68 | #         # generative process -> choose if end, if so, mort_pp, else, treatment_label.
69 | #
70 | #         mort_pp_log_probs = self.log_softmax(mort_pp_logits) + self.log_sigmoid(is_at_end_logit)
71 | #         treatments_log_probs = self.log_sigmoid(treatments_logits) + self.log_sigmoid(-is_at_end_logit)
72 | #         treatments_log_complement_probs = self.log_sigmoid(-treatments_logits) + self.log_sigmoid(-is_at_end_logit)
73 | #
74 | #         loss = torch.where(
75 | #             torch.isna(mort_pp_label),
76 | #             torch.where(
77 | #                 treatment_label == 1,
78 | #                 treatments_log_probs,
79 | #                 treatments_log_complement_probs,
80 | #             ),
81 | #             self.nll_loss(mort_pp_log_probs, mort_pp_label),
82 | #         )
83 | #
84 | #         assert not torch.isnan(loss), "FTS Decoding Loss should not be NaN!"
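# Shape sketch for SingleTaskPredictor below (inferred from the code, stated
# here as an assumption rather than documented fact):
#   X:         (batch, seq_len, in_dim) -> logits of shape (batch, num_classes, seq_len)
#   fts_label: (batch, seq_len) integer class ids, where 0 marks padding and is
#              masked out of the per-step cross-entropy loss.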
85 | 86 | class SingleTaskPredictor(nn.Module): 87 | def __init__( 88 | self, 89 | in_dim, 90 | num_classes, 91 | ): 92 | super().__init__() 93 | self.classifier = nn.Linear(in_dim, num_classes) 94 | self.loss_fct = nn.CrossEntropyLoss(reduction='none') 95 | 96 | def forward(self, X, fts_label=None, mort_pp_labels=None): 97 | logits = self.classifier(X).transpose(1, 2) 98 | if fts_label is None: return logits, None 99 | 100 | mask = (fts_label != 0).float() 101 | 102 | loss = (self.loss_fct(logits, fts_label) * mask).sum(dim=1) 103 | num_labels_per_el = mask.sum(dim=1) 104 | num_labels_per_el_or_one = torch.where( 105 | num_labels_per_el > 0, num_labels_per_el, torch.ones_like(num_labels_per_el) 106 | ) #loss will be zero where num_labels_per_el is zero and we want to avoid nans. 107 | loss = (loss / num_labels_per_el_or_one).mean(dim=0) # TODO: validate loss scale. 108 | return logits, loss 109 | 110 | ### Decoders 111 | # We experiment with several decoders 112 | def fts_labels_to_embeddings(fts_labels, treatment_embeddings): 113 | return torch.matmul(fts_labels, treatment_embeddings) 114 | 115 | class LSTMDecoder(nn.Module): 116 | HIDDEN_SIZE = 'hidden_size' 117 | 118 | def __init__( 119 | self, in_dim, treatment_embeddings, 120 | lstm_kwargs={'dropout': 0, } 121 | ): 122 | super().__init__() 123 | 124 | num_treatments, embedding_dim = treatment_embeddings.weight.shape 125 | self.treatment_embeddings = treatment_embeddings 126 | 127 | if self.HIDDEN_SIZE not in lstm_kwargs: lstm_kwargs[self.HIDDEN_SIZE] = embedding_dim 128 | lstm_kwargs['batch_first'] = True 129 | lstm_kwargs['bidirectional'] = False 130 | lstm_kwargs['input_size'] = lstm_kwargs[self.HIDDEN_SIZE] 131 | 132 | self.hidden_size = lstm_kwargs[self.HIDDEN_SIZE] 133 | self.sequence_in_dim = embedding_dim 134 | 135 | self.C_proj = nn.Linear(in_dim, self.hidden_size) 136 | self.H_proj = nn.Linear(in_dim, self.hidden_size) 137 | self.X_proj = nn.Linear(in_dim, embedding_dim) 138 | self.LSTM = nn.LSTM(**lstm_kwargs) 139 | 140 | def forward(self, decoded_state, fts_labels, use_teacher_forcing=True): 141 | init_c = self.C_proj(decoded_state).unsqueeze(0) 142 | init_h = self.H_proj(decoded_state).unsqueeze(0) # I don't know why the unsqueeze is needed... 143 | init_x = self.X_proj(decoded_state).unsqueeze(1) 144 | 145 | assert use_teacher_forcing, "Doesn't support non-forcing yet." 146 | 147 | input_sequence = torch.cat((init_x, self.treatment_embeddings(fts_labels[:, :-1])), dim=1) 148 | # fts_labels_to_embeddings(fts_labels, self.treatment_embeddings) 149 | 150 | return self.LSTM(input_sequence, (init_h, init_c))[0] 151 | 152 | ### Overall Module 153 | # By pushing complexity up, we keep this module very simple. 154 | # TODO(mmd): Make actually work. 155 | class FutureTreatmentSequenceDecoder(nn.Module): 156 | def __init__( 157 | self, 158 | decoder_module, 159 | predictor_module, 160 | ): 161 | super().__init__() 162 | 163 | self.decoder, self.predictor = decoder_module, predictor_module 164 | 165 | def forward(self, X, labels=None, mort_pp_labels=None): 166 | # TODO(mmd): Figure this all out. 
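        # (Sketch of the flow, inferred from the modules above: the decoder,
        #  e.g. LSTMDecoder, unrolls a hidden-state sequence conditioned on the
        #  encoder summary X, and the predictor, e.g. SingleTaskPredictor,
        #  scores each step and computes the padding-masked loss.)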
167 | 168 | decoded_state = self.decoder(X, labels) 169 | predictions, loss = self.predictor(decoded_state, labels, mort_pp_labels) 170 | return predictions, loss 171 | -------------------------------------------------------------------------------- /latent_patient_trajectories/BERT/load_and_yield_embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. 2 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import pdb 16 | 17 | import os, random, tables, torch, warnings, numpy as np, pandas as pd 18 | from torch.utils.data import DataLoader, SequentialSampler, TensorDataset 19 | from tqdm import tqdm_notebook 20 | 21 | from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE 22 | from pytorch_pretrained_bert.modeling import BertModel 23 | from pytorch_pretrained_bert.tokenization import BertTokenizer 24 | 25 | from .data_processor import * 26 | 27 | def load_and_get_ids( 28 | df, 29 | bert_model, 30 | processor, 31 | processor_args, 32 | do_lower_case, 33 | max_seq_length 34 | ): 35 | tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case) 36 | examples = processor.get_examples(df, **processor_args) 37 | 38 | features = convert_examples_to_features( 39 | examples, None, max_seq_length, tokenizer 40 | ) 41 | 42 | all_input_ids = [f.input_ids for f in features] 43 | all_input_mask = [f.input_mask for f in features] 44 | all_segment_ids = [f.segment_ids for f in features] 45 | return (all_input_ids, all_input_mask, all_segment_ids) 46 | 47 | 48 | def load_and_get_outputs( 49 | df, 50 | bert_model, 51 | processor, 52 | model_class = BertModel, # TODO(mmd): Better init 53 | model_kwargs = {}, 54 | processor_args = {}, 55 | use_gpu = False, 56 | seed = 42, 57 | do_lower_case = False, 58 | max_seq_length = 128, 59 | batch_size = 8, 60 | learning_rate = 5e-5, 61 | num_train_epochs = 3, 62 | cache_dir = None, 63 | tqdm = tqdm_notebook, 64 | save_to_filepath = None, 65 | chunksize = 0, 66 | start_at = 0, 67 | categories = None 68 | ): 69 | 70 | print('Loading pretrained model from %s' % bert_model) 71 | if chunksize > 0: assert save_to_filepath is not None, "Must save if chunking!" 72 | if start_at > 0: 73 | print('Filtering df down to rows %d - %d' % (start_at, len(df))) 74 | df = df.iloc[start_at:] 75 | 76 | if use_gpu and torch.cuda.is_available(): 77 | device = torch.device("cuda") 78 | n_gpu = torch.cuda.device_count() 79 | else: 80 | device = torch.device("cpu") 81 | n_gpu = 0 82 | 83 | random.seed(seed) 84 | np.random.seed(seed) 85 | torch.manual_seed(seed) 86 | if n_gpu > 0: 87 | torch.cuda.manual_seed_all(seed) 88 | 89 | # Prepare model 90 | ### TODO(mmd): this is where to reload the model properly. 
91 |     cache_dir = cache_dir if cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_-1')
92 |     # Load a trained model and config that you have fine-tuned
93 |     model = model_class.from_pretrained(
94 |         bert_model,
95 |         cache_dir=cache_dir,
96 |         **model_kwargs,
97 |     )
98 | 
99 |     model.to(device)
100 |     if n_gpu > 1: model = torch.nn.DataParallel(model)
101 | 
102 |     all_input_ids, all_input_mask, all_segment_ids = load_and_get_ids(
103 |         df, bert_model, processor, processor_args, do_lower_case, max_seq_length
104 |     )
105 | 
106 |     all_input_ids = torch.tensor(all_input_ids, dtype=torch.long)
107 |     all_input_mask = torch.tensor(all_input_mask, dtype=torch.long)
108 |     all_segment_ids = torch.tensor(all_segment_ids, dtype=torch.long)
109 | 
110 |     # Run prediction for full data
111 |     all_outputs, start, end = generate_input_embeddings(
112 |         model, all_input_ids, all_input_mask, all_segment_ids, device,
113 |         batch_size = batch_size, chunksize=chunksize, tqdm=tqdm, start_at=start_at, df=df, save_to_filepath=save_to_filepath
114 |     )
115 | 
116 |     output_df = pd.DataFrame(all_outputs, index=df.index[start:end])
117 |     if save_to_filepath is not None:
118 |         with warnings.catch_warnings():
119 |             warnings.simplefilter('ignore', tables.NaturalNameWarning)
120 |             if categories:
121 |                 output_df.to_hdf(save_to_filepath, categories)
122 |             else:
123 |                 output_df.to_hdf(save_to_filepath, '[%d-%d)' % (start_at + start, start_at + end))
124 |     else: return output_df
125 | 
126 | 
127 | def generate_input_embeddings(
128 |     model, all_input_ids, all_input_mask, all_segment_ids, device,
129 |     batch_size = 8, chunksize=0, tqdm=tqdm_notebook, start_at=0, disable_tqdm=False, df=None, save_to_filepath=None # df and save_to_filepath are required when chunksize > 0; the chunked-save branch below reads both.
130 | ):
131 | 
132 |     data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
133 |     sampler = SequentialSampler(data)
134 |     dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size)
135 | 
136 |     model.eval()
137 | 
138 |     all_outputs, start = [], 0
139 |     tqdm_loader = tqdm(dataloader, desc="Saved @ %d" % start_at, disable=disable_tqdm)
140 |     for i, (input_ids, input_mask, segment_ids) in enumerate(tqdm_loader):
141 |         input_ids = input_ids.to(device)
142 |         input_mask = input_mask.to(device)
143 |         segment_ids = segment_ids.to(device)
144 | 
145 |         with torch.no_grad(): _, output = model(input_ids, segment_ids, input_mask)
146 |         all_outputs.extend(output.detach().cpu().numpy())
147 | 
148 |         if chunksize > 0 and i > 0 and i % chunksize == 0:
149 |             end = start + len(all_outputs)
150 |             output_df = pd.DataFrame(all_outputs, index=df.index[start:end])
151 |             with warnings.catch_warnings():
152 |                 warnings.simplefilter('ignore', tables.NaturalNameWarning)
153 |                 output_df.to_hdf(save_to_filepath, '[%d-%d)' % (start_at + start, start_at + end))
154 |             tqdm_loader.set_description("Saved @ %d" % (start_at + end))
155 | 
156 |             start = end
157 |             all_outputs = []
158 | 
159 |     end = start + len(all_outputs)
160 | 
161 |     return (all_outputs, start, end)
162 | 
163 | def generate_input_embeddings_test(
164 |     model, all_input_ids, all_input_mask, all_segment_ids, device,
165 |     batch_size = 4, tqdm=tqdm_notebook
166 | ):
167 |     data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
168 |     sampler = SequentialSampler(data)
169 |     dataloader = DataLoader(data, sampler=sampler, batch_size=batch_size)
170 | 
171 |     model.eval()
172 |     all_outputs = []
173 |     for i, (input_ids, input_mask, segment_ids) in enumerate(dataloader):
174 |         with torch.no_grad(): _, output = model(input_ids, segment_ids, input_mask)
175 |         all_outputs.append(output)
176 |     return all_outputs
177 | 
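# (Editorial sketch) A hedged usage example for load_and_get_outputs above; `notes_df`
# and the processor name are hypothetical -- any processor from .data_processor that
# exposes get_examples(df, ...) should work here:
#     processor = ClinicalNotesProcessor()  # hypothetical processor
#     embeddings_df = load_and_get_outputs(
#         notes_df, 'bert-base-uncased', processor,
#         use_gpu=True, max_seq_length=128, batch_size=8,
#     )
#     # With save_to_filepath/chunksize unset this returns a DataFrame of pooled BERT
#     # outputs indexed like notes_df; with them set, chunks are written to HDF5 instead.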
-------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/fine_tune.py: -------------------------------------------------------------------------------- 1 | """
2 | fine_tune.py
3 | Fine-tunes a pre-trained model on a specific (single) task
4 | """
5 | 
6 | import torch.optim
7 | from torch.autograd import set_detect_anomaly
8 | from torch.utils.data import DataLoader, RandomSampler, SubsetRandomSampler
9 | 
10 | import json, os, pickle, random
11 | from copy import deepcopy
12 | from tqdm import tqdm
13 | import glob
14 | 
15 | 
16 | # TODO: check these imports.
17 | from ..utils import *
18 | from ..constants import *
19 | from ..data_utils import *
20 | from ..BERT.model import *
21 | from ..BERT.constants import *
22 | 
23 | from .fts_decoder import *
24 | from .evaluator import *
25 | from .meta_model import *
26 | from .run_model import setup_datasets_and_dataloaders, args_run_setup, train_meta_model
27 | 
28 | def fine_tune_model(
29 |     fine_tune_args, meta_model_args, sample_datum, binary_multilabel_keys, train_dataloaders_by_data_frac,
30 |     tqdm=None, meta_model=None, tuning_dataloader=None
31 | ):
32 |     print('in fine tune model')
33 |     reloaded = (meta_model is not None)
34 | 
35 |     verbose = False
36 |     if hasattr(fine_tune_args, 'verbose'):
37 |         verbose = fine_tune_args.verbose
38 | 
39 |     if meta_model is None:
40 |         meta_model = MetaModel(
41 |             meta_model_args, sample_datum,
42 |             class_names = {'tasks_binary_multilabel': binary_multilabel_keys},
43 |             verbose = verbose,
44 |         )
45 | 
46 |     if not reloaded:
47 |         reloaded, epoch = meta_model.load()
48 | 
49 |         if fine_tune_args.do_single_task: assert not reloaded, "Shouldn't be reloading a ST fine-tuning run!"
50 |         else: assert reloaded, "Can't fine-tune a not-yet-trained model!"
51 | 
52 |     epoch = 0 # Fine-tuning always restarts training from epoch 0, even when base weights were just reloaded.
53 |     reloaded = False
54 | 
55 |     # For fine-tuning, we want to ablate away everything *but* the target task.
56 |     ablate = [k for k in ABLATION_GROUPS.keys() if k != fine_tune_args.fine_tune_task]
57 |     meta_model.ablate(ablate, post_init=True)
58 | 
59 |     outputs = [meta_model]
60 |     for data_frac, train_dataloader in train_dataloaders_by_data_frac.items():
61 |         fine_tune_dir_name = fine_tune_args.fine_tune_task
62 |         if data_frac != 1: fine_tune_dir_name += f"_{str(data_frac).replace('.', '-')}"
63 | 
64 |         fine_tune_run_dir = os.path.join(fine_tune_args.run_dir, fine_tune_dir_name)
65 |         assert os.path.isdir(fine_tune_run_dir), f"{fine_tune_run_dir} must exist!"
66 | 
67 |         if fine_tune_args.do_frozen_representation:
68 |             meta_model_FTD = deepcopy(meta_model)
69 |             meta_model_FTD_args = deepcopy(meta_model_args)
70 | 
71 |             meta_model_FTD.run_dir = os.path.join(fine_tune_run_dir, "FTD")
72 |             meta_model_FTD.freeze_representation()
73 |             meta_model_FTD_args.run_dir = meta_model_FTD.run_dir
74 | 
75 |             if not os.path.isdir(meta_model_FTD.run_dir): os.makedirs(meta_model_FTD.run_dir)
76 | 
77 |             # Train it from scratch with the representation frozen and task weights appropriately set.
78 |             train_meta_model(
79 |                 meta_model_FTD, train_dataloader, meta_model_FTD_args, reloaded=reloaded, epoch=epoch,
80 |                 tuning_dataloader=tuning_dataloader,
81 |                 train_embedding_after=fine_tune_args.train_embedding_after
82 |             )
83 |             outputs.append(meta_model_FTD)
84 |         if fine_tune_args.do_free_representation:
85 |             meta_model_FTF = deepcopy(meta_model)
86 |             meta_model_FTF_args = deepcopy(meta_model_args)
87 | 
88 |             meta_model_FTF.run_dir = os.path.join(fine_tune_run_dir, "FTF")
89 |             meta_model_FTF_args.run_dir = meta_model_FTF.run_dir
90 | 
91 |             if not os.path.isdir(meta_model_FTF.run_dir): os.makedirs(meta_model_FTF.run_dir)
92 | 
93 |             # Train it from scratch with the representation free (not frozen) and task weights appropriately set.
94 |             train_meta_model(
95 |                 meta_model_FTF, train_dataloader, meta_model_FTF_args, reloaded=reloaded, epoch=epoch,
96 |                 tuning_dataloader=tuning_dataloader,
97 |                 train_embedding_after=fine_tune_args.train_embedding_after
98 |             )
99 |             outputs.append(meta_model_FTF)
100 |     return outputs
101 | 
102 | def main(fine_tune_args, tqdm):
103 |     assert os.path.isdir(fine_tune_args.run_dir), "Run dir must exist!"
104 |     assert (
105 |         fine_tune_args.do_frozen_representation or
106 |         fine_tune_args.do_free_representation
107 |     ), "Need to do either FTF or FTD!"
108 | 
109 |     fine_tune_args.to_json_file(os.path.join(fine_tune_args.run_dir, FINE_TUNE_ARGS_FILENAME))
110 | 
111 |     ablation_groups = EICU_ABLATION_GROUPS if fine_tune_args.do_eicu else ABLATION_GROUPS
112 | 
113 |     assert fine_tune_args.fine_tune_task in ablation_groups,\
114 |         f"Invalid fine tune task: {fine_tune_args.fine_tune_task}"
115 |     assert ((fine_tune_args.frac_fine_tune_data > 0) and (fine_tune_args.frac_fine_tune_data <= 1)),\
116 |         "frac_fine_tune_data must be in the range (0, 1]"
117 | 
118 |     if fine_tune_args.do_masked_imputation_PT:
119 |         meta_model_dir = os.path.dirname(fine_tune_args.run_dir)
120 |         meta_model_args = Args.from_json_file(os.path.join(meta_model_dir, ARGS_FILENAME))
121 |         assert meta_model_args.do_masked_imputation, "Expected PT to do masked imputation!"
122 |         assert meta_model_args.imputation_mask_rate > 0, "Expected PT to do masked imputation!"
123 | 
124 |         meta_model_args.do_fake_masked_imputation_shape = True
125 |         meta_model_args.do_masked_imputation = False
126 |         meta_model_args.imputation_mask_rate = 0
127 |     else:
128 |         meta_model_args = Args.from_json_file(os.path.join(fine_tune_args.run_dir, ARGS_FILENAME))
129 |     meta_model_args.set_to_eval_mode = ""
130 | 
131 |     set_to_eval_mode = None
132 |     if fine_tune_args.do_match_train_windows:
133 |         meta_model_args.set_to_eval_mode = EVAL_MODES_BY_ABLATION_GROUPS[fine_tune_args.fine_tune_task]
134 | 
135 |     ablate = [k for k in ablation_groups if k != fine_tune_args.fine_tune_task]
136 | 
137 |     assert fine_tune_args.frac_fine_tune_data == 1, "frac_fine_tune_data is deprecated!"
138 |     assert fine_tune_args.frac_female == 1, "frac_female is deprecated!"
139 |     assert fine_tune_args.frac_black == 1, "frac_black is deprecated!"
140 | 
141 |     data_fracs = [1]
142 |     if fine_tune_args.do_small_data: data_fracs.extend(SMALL_DATA_FRACS)
143 | 
144 |     datasets, train_dataloader = setup_datasets_and_dataloaders(meta_model_args)
145 | 
146 |     orig_len = len(datasets['train'])
147 | 
148 |     # Just to be safe, here, we'll take copies of everything.
149 | orig_subjects = deepcopy(datasets['train'].orig_subjects) 150 | orig_max_hours = deepcopy(datasets['train'].orig_max_hours) 151 | orig_index = deepcopy(datasets['train'].index) 152 | subjects_hours = deepcopy(list(zip(orig_subjects, orig_max_hours))) 153 | 154 | assert len(set(subjects_hours))==len(subjects_hours) 155 | 156 | assert datasets['train'].max_seq_len == meta_model_args.max_seq_len 157 | assert train_dataloader.dataset.max_seq_len == meta_model_args.max_seq_len 158 | 159 | sample_datum = datasets['train'][0] 160 | binary_multilabel_keys = datasets['train'].get_binary_multilabel_keys() 161 | train_dataloaders_by_data_frac = {1: train_dataloader} 162 | 163 | random.seed(fine_tune_args.frac_fine_tune_data_seed) 164 | 165 | for frac in data_fracs: 166 | fine_tune_dir_name = fine_tune_args.fine_tune_task 167 | if frac != 1: fine_tune_dir_name += f"_{str(frac).replace('.', '-')}" 168 | 169 | fine_tune_run_dir = os.path.join(fine_tune_args.run_dir, fine_tune_dir_name) 170 | 171 | if not os.path.exists(fine_tune_run_dir): os.makedirs(fine_tune_run_dir) 172 | 173 | data_frac_seed = random.randint(0, int(1e10)) 174 | with open(os.path.join(fine_tune_run_dir, 'data_frac_seed.txt'), mode='w') as f: 175 | f.write(str(data_frac_seed)) 176 | 177 | random.seed(data_frac_seed) 178 | if frac != 1: 179 | frac_subjects_hours = random.choices(subjects_hours, k=int(frac * orig_len)) 180 | frac_subjects, frac_hours = zip(*frac_subjects_hours) 181 | 182 | frac_dataset = deepcopy(datasets['train']) 183 | 184 | frac_dataset.orig_subjects = frac_subjects 185 | frac_dataset.orig_max_hours = frac_hours 186 | 187 | frac_dataset.reset_sequence_len(frac_dataset.sequence_len) 188 | 189 | new_index = frac_dataset.index 190 | frac_dataset.item_cache_remap = { 191 | i: next(j for j, ov in enumerate(orig_index) if ov == nv) for i, nv in enumerate(new_index) 192 | } 193 | frac_len = len(frac_dataset) 194 | assert frac_len < orig_len, f"{len(frac_dataset)} !< {orig_len}" 195 | 196 | with open(os.path.join(fine_tune_run_dir, 'item_cache_remap.json'), mode='w') as f: 197 | f.write(json.dumps(frac_dataset.item_cache_remap)) 198 | with open(os.path.join(fine_tune_run_dir, 'frac_dataset.pkl'), mode='wb') as f: 199 | pickle.dump(frac_dataset, f) 200 | with open(os.path.join(fine_tune_run_dir, 'len_stats.json'), mode='w') as f: 201 | f.write(json.dumps({ 202 | 'orig_len': orig_len, 203 | 'frac': frac, 204 | 'frac_len': frac_len, 205 | })) 206 | 207 | sampler = RandomSampler(frac_dataset) 208 | 209 | train_dataloader = DataLoader( 210 | frac_dataset, sampler=sampler, 211 | batch_size=train_dataloader.batch_size, num_workers=train_dataloader.num_workers 212 | ) 213 | 214 | assert train_dataloader.dataset.max_seq_len == meta_model_args.max_seq_len 215 | train_dataloaders_by_data_frac[frac] = train_dataloader 216 | 217 | for (do, suffix) in [ 218 | (fine_tune_args.do_frozen_representation, "FTD"), (fine_tune_args.do_free_representation, "FTF"), 219 | ]: 220 | if not do: continue 221 | 222 | # Really the data subsetting should go in these args, for reproducibility. 
223 | fine_tune_meta_model_args = deepcopy(meta_model_args) 224 | fine_tune_meta_model_args.run_dir = os.path.join(fine_tune_run_dir, suffix) 225 | 226 | fine_tune_meta_model_args.ablate = ablate 227 | 228 | args_run_setup(fine_tune_meta_model_args) 229 | 230 | return fine_tune_model( 231 | fine_tune_args, meta_model_args, sample_datum, binary_multilabel_keys, train_dataloaders_by_data_frac, 232 | tqdm=tqdm, tuning_dataloader=None, 233 | ) 234 | -------------------------------------------------------------------------------- /latent_patient_trajectories/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | PROJECT_NAME = 'latent_patient_trajectories' 4 | PROJECT_DIR = os.path.join(os.environ['PROJECTS_BASE'], PROJECT_NAME) 5 | RUNS_DIR = os.path.join(PROJECT_DIR, 'runs') 6 | HYPERPARAMETER_SEARCH_DIR = os.path.join(RUNS_DIR, 'hyperparameter_search') 7 | TASK_GENERALIZABILITY_DIR = os.path.join(RUNS_DIR, 'task_generalizability') 8 | DATA_DIR = os.path.join(PROJECT_DIR, 'processed_data') 9 | # DATA_DIR = '/scratch/gobi2/bnestor/mimic_extraction_results' 10 | ROTATIONS_DIR = os.path.join(PROJECT_DIR, 'dataset', 'rotations') 11 | EICU_ROTATIONS_DIR = os.path.join(PROJECT_DIR, 'dataset_eicu', 'rotations') 12 | DATA_FILENAME = 'all_hourly_data.h5' 13 | EICU_DATA_FILENAME = 'eicu_extract.hdf' 14 | FOLDS_FILENAME = 'subject_ids_per_fold.pkl' 15 | EICU_FOLDS_FILENAME = 'eicu_subject_ids_per_fold.pkl' 16 | FTS_FILENAME = 'treatment_sequence.h5' 17 | EICU_FTS_FILENAME = 'eicu_treatment_set.h5' 18 | NOTES_FILENAME = 'notes.hdf' 19 | ARGS_FILENAME = 'args.json' 20 | HYPERPARAMETER_SEARCH_ARGS_FILENAME = 'hyperopt_args.json' 21 | PARAMS_FILENAME = 'raw_params.json' 22 | FINE_TUNE_ARGS_FILENAME = 'fine_tune_args.json' 23 | EVAL_ARGS_FILENAME = 'eval_args.json' 24 | CLUSTERING_ARGS_FILENAME = 'clustering_args.json' 25 | CLUSTERING_CONFIG_FILENAME = 'clustering_config.json' 26 | CONFIG_FILENAME = 'bert_config.json' 27 | HYP_CONFIG_FILENAME = 'hyperparameter_search_config.json' 28 | TASK_GEN_CFG_FILENAME = 'task_generalizability_config.json' 29 | HYP_REEVAL_CFG_FILENAME = 'hyperparameter_search_re_eval_config.json' 30 | TASK_GEN_BASE_ARGS_FILENAME = 'task_generalizability_model_base_args.json' 31 | TASK_GEN_EXP_ARGS_FILENAME = 'task_generalizability_exp_args.json' 32 | GET_ALL_FLAT_REPR_ARGS_FILENAME = 'get_all_flat_repr_args.json' 33 | GET_PCA_ARGS_FILENAME = 'get_pca_args.json' 34 | 35 | FLAT_DATA_FILENAME_TEMPLATE = '{split}_{type}_flat_data.h5' 36 | 37 | TEST_DATA_FILENAME = 'all_hourly_data_test.h5' 38 | TEST_NOTES_FILENAME = 'notes_test.hdf' 39 | TEST_FTS_FILENAME = 'treatment_sequence_test.h5' 40 | 41 | STATICS = 'patients' 42 | NUMERICS = 'vitals_labs' 43 | CODES = 'codes' 44 | TREATMENTS = 'interventions' 45 | 46 | ICUSTAY_ID = 'icustay_id' 47 | SUBJECT_ID = 'subject_id' 48 | HADM_ID = 'hadm_id' 49 | 50 | ID_COLS = [ICUSTAY_ID, HADM_ID, SUBJECT_ID] 51 | 52 | EVAL_FILE_TEMPLATES = ('%s_reprs.pkl', '%s_task_info.pkl', '%s_perf_metrics.pkl') 53 | 54 | FOLD_IDX_LVL = 'Fold' 55 | K = 10 # K-fold CV 56 | 57 | EXCLUSION_CRITERIA = { 58 | # TODO(ANYONE WHO CHANGES): FTS were generated with exclusion criteria 1.5, None for LOS. Must be updated 59 | # if changed. 60 | 'los_icu': (1.5, None), 61 | } 62 | 63 | PATIENT_ID_COLS = [SUBJECT_ID, HADM_ID, ICUSTAY_ID] # TODO(mmd): We need a separate one for some joins. Why? 
64 | 65 | ALL_TASKS_EICU = [ 66 | 'disch_24h', 67 | 'disch_48h', 68 | 'Final Acuity Outcome', 69 | 'tasks_binary_multilabel', 70 | 'next_timepoint', 71 | 'next_timepoint_was_measured', 72 | ] 73 | ALL_TASKS = [ 74 | 'rolling_ftseq', 75 | 'disch_24h', 76 | 'disch_48h', 77 | 'Final Acuity Outcome', 78 | 'tasks_binary_multilabel', 79 | 'next_timepoint', 80 | 'next_timepoint_was_measured', 81 | ] 82 | EICU_ABLATION_GROUPS = ['discharge', 'mortality', 'los', 'acuity', 'next_timepoint_info'] 83 | ABLATION_GROUPS = { 84 | 'icd10': [ 85 | 'icd_infection', 'icd_neoplasms', 'icd_endocrine', 'icd_blood', 'icd_mental', 'icd_nervous', 86 | 'icd_circulatory', 'icd_respiratory', 'icd_digestive', 'icd_genitourinary', 'icd_pregnancy', 87 | 'icd_skin', 'icd_musculoskeletal', 'icd_congenital', 'icd_perinatal', 'icd_ill_defined','icd_injury', 88 | 'icd_unknown' 89 | ], 90 | 'discharge': ['disch_24h', 'disch_48h'], 91 | 'mortality': ['mort_24h', 'mort_48h'], 92 | 'los': ['Long LOS'], 93 | 'readmission': ['Readmission 30'], 94 | 'future_treatment_sequence': ['rolling_ftseq'], 95 | 'acuity': ['Final Acuity Outcome'], 96 | 'next_timepoint_info': ['next_timepoint', 'next_timepoint_was_measured'], 97 | 'dnr': ['dnr_24h', 'dnr_48h'], 98 | 'cmo': ['cmo_24h', 'cmo_48h'], 99 | } 100 | 101 | ALL_SPECIFIC_TREATMENTS = [ 102 | 'vent', 103 | 'nivdurations', 104 | 'adenosine', 105 | 'dobutamine', 106 | 'dopamine', 107 | 'epinephrine', 108 | 'isuprel', 109 | 'milrinone', 110 | 'norepinephrine', 111 | 'phenylephrine', 112 | 'vasopressin', 113 | 'colloid_bolus', 114 | 'crystalloid_bolus', 115 | ] 116 | 117 | GENERALIZED_TREATMENTS = { 118 | 'vent': ('vent', 'nivdurations'), 'vaso': ('vaso',), 'bolus': ('colloid_bolus', 'crystalloid_bolus'), 119 | } 120 | 121 | DURATIONS_COL = 'hours_in' 122 | 123 | TRAIN, TUNING, HELD_OUT = 'train', 'tuning', 'held out' 124 | 125 | UNK = 'Unknown' 126 | 127 | ABBREVIATIONS = { 128 | 'Imminent Mortality': 'MOR', 129 | 'Comfort Measures': 'CMO', 130 | 'DNR Ordered': 'DNR', 131 | 'ICD Code Prediction': 'ICD', 132 | 'Long LOS': 'LOS', 133 | '30-day Readmission': 'REA', 134 | 'Imminent Discharge': 'DIS', 135 | 'Final Acuity Outcome': 'ACU', 136 | 'Next Hour Will-be-measured': 'WBM', 137 | 'Future Treatment Sequence (FTS)': 'FTS', 138 | 'Masked Imputation Regression': 'MIR', 139 | 'Masked Imputation Classification': 'MIC', 140 | } 141 | 142 | MASKED_IMPUTATION_BREAKDOWN = { 143 | ABBREVIATIONS['Masked Imputation Regression']: 'masked_imputation_regression', 144 | ABBREVIATIONS['Masked Imputation Classification']: 'masked_imputation_classification', 145 | } 146 | EICU_MANUSCRIPT_BREAKDOWN = { 147 | ABBREVIATIONS['Imminent Mortality']: ('tasks_binary_multilabel','all_time',lambda s: s.startswith('mort')), 148 | ABBREVIATIONS['Long LOS']: ('tasks_binary_multilabel','first_24','Long LOS'), 149 | ABBREVIATIONS['Imminent Discharge']: ['disch_24h', 'disch_48h'], 150 | ABBREVIATIONS['Final Acuity Outcome']: 'Final Acuity Outcome', 151 | ABBREVIATIONS['Next Hour Will-be-measured']: 'next_timepoint_was_measured', 152 | } 153 | MANUSCRIPT_BREAKDOWN = { 154 | ABBREVIATIONS['Imminent Mortality']: ('tasks_binary_multilabel','all_time',lambda s: s.startswith('mort')), 155 | ABBREVIATIONS['Comfort Measures']: ('tasks_binary_multilabel','all_time',lambda s: s.startswith('cmo')), 156 | ABBREVIATIONS['DNR Ordered']: ('tasks_binary_multilabel','all_time',lambda s: s.startswith('dnr')), 157 | ABBREVIATIONS['ICD Code Prediction']: ('tasks_binary_multilabel','first_24',lambda s: s.startswith('icd')), 158 | 
ABBREVIATIONS['Long LOS']: ('tasks_binary_multilabel','first_24','Long LOS'), 159 | ABBREVIATIONS['30-day Readmission']: ('tasks_binary_multilabel','extend_till_discharge','Readmission 30'), 160 | ABBREVIATIONS['Imminent Discharge']: ['disch_24h', 'disch_48h'], 161 | ABBREVIATIONS['Final Acuity Outcome']: 'Final Acuity Outcome', 162 | ABBREVIATIONS['Next Hour Will-be-measured']: 'next_timepoint_was_measured', 163 | ABBREVIATIONS['Future Treatment Sequence (FTS)']: 'rolling_ftseq', 164 | } 165 | EVAL_MODES = ('all_time', 'first_24', 'extend_till_discharge') 166 | EVAL_MODES_BY_ABLATION_GROUPS = { 167 | 'icd10': 'first_24', 168 | 'discharge': 'all_time', 169 | 'mortality': 'all_time', 170 | 'los': 'first_24', 171 | 'readmission': 'extend_till_discharge', 172 | 'future_treatment_sequence': 'all_time', 173 | 'acuity': 'first_24', 174 | 'next_timepoint_info': 'all_time', 175 | 'dnr': 'all_time', 176 | 'cmo': 'all_time', 177 | } 178 | 179 | TASK_BINARY_MULTILABEL_ORDER = ['mort_24h', 'mort_48h', 'dnr_24h', 'dnr_48h', 'cmo_24h', 180 | 'cmo_48h', 'Long LOS', 'icd_infection', 'icd_neoplasms', 181 | 'icd_endocrine', 'icd_blood', 'icd_mental', 'icd_nervous', 182 | 'icd_circulatory', 'icd_respiratory', 'icd_digestive', 183 | 'icd_genitourinary', 'icd_pregnancy', 'icd_skin', 'icd_musculoskeletal', 184 | 'icd_congenital', 'icd_perinatal', 'icd_ill_defined', 185 | 'icd_injury', 'icd_unknown', 'Readmission 30'] 186 | 187 | TASK_HEAD_MAPPING = { 188 | 'acuity': set(['task_heads.Final Acuity Outcome.weight', 189 | 'task_heads.Final Acuity Outcome.bias']), 190 | 'discharge': set(['task_heads.disch_24h.weight', 191 | 'task_heads.disch_24h.bias', 192 | 'task_heads.disch_48h.weight', 193 | 'task_heads.disch_48h.bias']), 194 | 'next_timepoint_info': set(['task_heads.next_timepoint.weight', 195 | 'task_heads.next_timepoint.bias', 196 | 'task_heads.next_timepoint_was_measured.weight', 197 | 'task_heads.next_timepoint_was_measured.bias']), 198 | 'future_treatment_sequence': set(['treatment_embeddings.weight', 199 | 'FTS_decoder.decoder.treatment_embeddings.weight', 200 | 'FTS_decoder.decoder.C_proj.weight', 201 | 'FTS_decoder.decoder.C_proj.bias', 202 | 'FTS_decoder.decoder.H_proj.weight', 203 | 'FTS_decoder.decoder.H_proj.bias', 204 | 'FTS_decoder.decoder.X_proj.weight', 205 | 'FTS_decoder.decoder.X_proj.bias', 206 | 'FTS_decoder.decoder.LSTM.weight_ih_l0', 207 | 'FTS_decoder.decoder.LSTM.weight_hh_l0', 208 | 'FTS_decoder.decoder.LSTM.bias_ih_l0', 209 | 'FTS_decoder.decoder.LSTM.bias_hh_l0', 210 | 'FTS_decoder.predictor.classifier.weight', 211 | 'FTS_decoder.predictor.classifier.bias']), 212 | 'dnr': set(), 213 | 'los': set(), 214 | 'mortality': set(), 215 | 'icd10': set(), 216 | 'readmission': set(), 217 | 'cmo': set(), 218 | } 219 | 220 | # TODO(mmd): The for tasks_binary_multilabel are augmented 221 | TASK_SPECIFIC_WEIGHT_PREFIXES = { 222 | 'acuity': set(['task_heads.Final Acuity Outcome']), 223 | 'discharge': set(['task_heads.disch_24h', 'task_heads.disch_48h']), 224 | 'next_timepoint_info': set(['task_heads.next_timepoint', 'task_heads.next_timepoint_was_measured']), 225 | 'future_treatment_sequence': set(['treatment_embeddings', 'FTS_decoder']), 226 | 'dnr': set(['task_heads.tasks_binary_multilabel.dnr']), 227 | 'los': set(['task_heads.tasks_binary_multilabel.Long LOS']), 228 | 'mortality': set(['task_heads.tasks_binary_multilabel.mort']), 229 | 'icd10': set(['task_heads.tasks_binary_multilabel.icd']), 230 | 'readmission': set(['task_heads.tasks_binary_multilabel.Readmission 30']), 231 | 'cmo': 
set(['task_heads.tasks_binary_multilabel.cmo']),
232 | }
233 | 
234 | ALWAYS_EQ_KEYS = [
235 |     'task_losses.tasks_binary_multilabel.BCE_LL.pos_weight',
236 |     'bert.encoder.layer.0.attention.self.key.bias',
237 | ]
238 | 
239 | ABLATIONS_TO_REPORTING_MAP = {
240 |     'mortality': ABBREVIATIONS['Imminent Mortality'],
241 |     'cmo': ABBREVIATIONS['Comfort Measures'],
242 |     'dnr': ABBREVIATIONS['DNR Ordered'],
243 |     'icd10': ABBREVIATIONS['ICD Code Prediction'],
244 |     'los': ABBREVIATIONS['Long LOS'],
245 |     'readmission': ABBREVIATIONS['30-day Readmission'],
246 |     'discharge': ABBREVIATIONS['Imminent Discharge'],
247 |     'acuity': ABBREVIATIONS['Final Acuity Outcome'],
248 |     'next_timepoint_info': ABBREVIATIONS['Next Hour Will-be-measured'],
249 |     'future_treatment_sequence': ABBREVIATIONS['Future Treatment Sequence (FTS)'],
250 | }
251 | 
252 | SMALL_DATA_FRACS = (
253 |     0.00029, 0.001, 0.001778, 0.003162, 0.005623, 0.01, 0.01778279, 0.03162278, 0.05623413, 0.1, 0.1778,
254 |     0.3162, 0.5623,
255 | )
256 | -------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/single_task.py: -------------------------------------------------------------------------------- 1 | # Generic Imports
2 | import copy, math, itertools, json, os, pickle, queue, subprocess, sys, time
3 | import multiprocessing as mp
4 | 
5 | from tqdm import tqdm
6 | 
7 | # LPT Imports
8 | from . import run_model
9 | 
10 | from ..constants import *
11 | from .args import *
12 | from .evaluator import *
13 | 
14 | PYTHON_EXECUTABLE_PATH = sys.executable
15 | SCRIPTS_DIR = '/%s' % os.path.join(*(os.path.realpath(__file__).split('/')[:-3] + ['Scripts','End to End']))
16 | 
17 | def task_setting_to_str(task_setting):
18 |     return task_setting.replace(' ', '_')
19 | 
20 | def read_config_and_args(exp_dir):
21 |     """
22 |     Reads a json single-task experiment config, e.g.:
23 |     {
24 |         "gpus_per_model": 4,
25 |         "models_per_gpu": 1,
26 |         "gpus": [0, 1, 2, 3]
27 |     }
28 |     """
29 |     config_filepath = os.path.join(exp_dir, ST_CFG_FILENAME)
30 |     with open(config_filepath, mode='r') as f: config = json.loads(f.read())
31 | 
32 |     args_filepath = os.path.join(exp_dir, ST_BASE_ARGS_FILENAME)
33 |     args = Args.from_json_file(args_filepath)
34 | 
35 |     return config, args
36 | 
37 | class Runner():
38 |     # #&> {{log_file}}\
39 |     #COMMAND_TEMPLATE = """\
40 |     #    cd {scripts_dir}
41 |     #    PROJECTS_BASE={projects_base} CUDA_VISIBLE_DEVICES={{gpus}} {python_path} run_v2_model.py \
42 |     #        --run_dir="{{run_dir}}" \
43 |     #        --do_load_from_dir
44 |     #""".format(
45 |     #    scripts_dir=SCRIPTS_DIR,
46 |     #    projects_base=os.environ['PROJECTS_BASE'],
47 |     #    python_path=PYTHON_EXECUTABLE_PATH
48 |     #).strip()
49 |     #ENV =
50 | 
51 |     def __init__(
52 |         self,
53 |         gpus_per_model,
54 |         run_dir,
55 |         # task_setting_fine_tune_dir,
56 |         gpu_queue,
57 |         do_train = True,
58 |         do_eval = True,
59 |         do_fine_tune = True,
60 |         do_fine_tune_eval = True,
61 |         slurm = False,
62 |         partition = 'gpu',
63 |     ):
64 |         assert do_train or do_eval or do_fine_tune or do_fine_tune_eval, \
65 |             "Must do something!"
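# (Editorial note) gpu_queue is a multiprocessing.Queue of GPU ids with one entry per
# model slot (each id is enqueued models_per_gpu times in main() below), so a Runner
# may legitimately draw the same id more than once when models_per_gpu > 1.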
66 | 
67 |         self.gpus_per_model = gpus_per_model
68 |         self.run_dir = run_dir
69 |         # self.task_setting_fine_tune_dir = task_setting_fine_tune_dir
70 |         self.timings_file = os.path.join(run_dir, 'timings.json')
71 |         self.stdout_file = os.path.join(run_dir, 'stdout.txt')
72 |         self.stderr_file = os.path.join(run_dir, 'stderr.txt')
73 |         self.gpu_queue = gpu_queue
74 |         self.do_train = do_train
75 |         self.do_eval = do_eval
76 |         self.do_fine_tune = do_fine_tune
77 |         self.do_fine_tune_eval = do_fine_tune_eval
78 |         self.slurm = slurm
79 |         self.partition = partition
80 | 
81 |     def call_slurm(self):
82 |         env = {'PROJECTS_BASE': os.environ['PROJECTS_BASE']}
83 |         prog = PYTHON_EXECUTABLE_PATH
84 |         train_args = ['run_v2_model.py', '--run_dir=%s' % self.run_dir, '--do_load_from_dir']
85 |         eval_args = ['evaluate.py', '--run_dir=%s' % self.run_dir, '--do_load_from_dir']
86 | 
87 |         fine_tune_args = ['fine_tune_task.py', '--run_dir=%s' % self.run_dir, '--do_load_from_dir']
88 |         # fine_tune_eval_args = [ # disabled: needs self.task_setting_fine_tune_dir, which is commented out in __init__ and would raise an AttributeError here.
89 |         #     'evaluate.py', '--run_dir=%s' % self.task_setting_fine_tune_dir, '--do_load_from_dir'
90 |         # ]
91 | 
92 |         path_root = os.path.abspath(os.path.join(os.getcwd(), '../../'))
93 | 
94 |         # just print the bash script
95 |         bash_script=f"""#!/bin/bash
96 | #SBATCH -p {self.partition}
97 | #SBATCH --gres=gpu:1
98 | #SBATCH -c 12
99 | #SBATCH --mem=48G
100 | #SBATCH --output {os.path.join(self.run_dir, "train_%j.log")}"""
101 |         bash_script_name = os.path.join(self.run_dir, 'train_task_gen.sh')
102 | 
103 |         with open(bash_script_name, 'w') as f:
104 |             f.writelines(bash_script)
105 |         bash_script = f"""
106 | SEARCH_DIR={self.run_dir}
107 | cd {os.path.join(path_root,'Scripts/End to End')}"""
108 | 
109 |         if self.do_train:
110 |             # Append (do not overwrite), so the SEARCH_DIR/cd preamble above is kept.
111 |             bash_script += f"""
112 | python run_v2_model.py --run_dir $SEARCH_DIR --do_load_from_dir"""
113 |             with open(bash_script_name, 'a') as f:
114 |                 f.writelines(bash_script)
115 |             bash_script = "" # Reset, so the eval block below does not re-append this chunk.
116 | 
117 |         if self.do_eval:
118 |             bash_script += f"""
119 | python -u evaluate.py --run_dir $SEARCH_DIR --do_load_from_dir"""
120 |             with open(bash_script_name, 'a') as f:
121 |                 f.writelines(bash_script)
122 |             bash_script = ""
123 | 
124 | 
125 |         print(f"sbatch {bash_script_name};\n")
126 |         try:
127 | 
128 |             with open(self.stdout_file, mode='w') as stdout_h, open(self.stderr_file, mode='w') as stderr_h:
129 |                 subprocess.run(f"sbatch -W {bash_script_name}", shell=True, env=os.environ.copy(), cwd=SCRIPTS_DIR, stdout=stdout_h, stderr=stderr_h)
130 |         except Exception as e:
131 |             print("run dir %s failed! Exception: %s" % (self.run_dir, e))
132 |             result = e
133 | 
134 | 
135 |         return ('submitted',)
136 | 
137 |     def call_no_slurm(self):
138 |         gpus = set()
139 |         while len(gpus) < self.gpus_per_model:
140 |             try:
141 |                 new_gpu = self.gpu_queue.get(block=True, timeout=30)
142 |                 if new_gpu in gpus: self.gpu_queue.put(new_gpu)
143 |                 else: gpus.update([new_gpu])
144 |             except queue.Empty as e:
145 |                 if gpus:
146 |                     for gpu in gpus: self.gpu_queue.put(gpu)
147 |                     gpus = set() # Release held slots (and forget them) so concurrent runners cannot deadlock.
148 |                 time.sleep(90)
149 | 
150 |         gpus = [str(g) for g in gpus]
151 | 
152 |         env = {'PROJECTS_BASE': os.environ['PROJECTS_BASE'], 'CUDA_VISIBLE_DEVICES': ','.join(gpus)}
153 |         prog = PYTHON_EXECUTABLE_PATH
154 |         train_args = ['run_v2_model.py', '--run_dir=%s' % self.run_dir, '--do_load_from_dir']
155 |         eval_args = ['evaluate.py', '--run_dir=%s' % self.run_dir, '--do_load_from_dir']
156 | 
157 |         fine_tune_args = ['fine_tune_task.py', '--run_dir=%s' % self.run_dir, '--do_load_from_dir']
158 |         # fine_tune_eval_args = [ # still disabled: needs self.task_setting_fine_tune_dir (commented out in __init__)
159 |         #     'evaluate.py', '--run_dir=%s' % self.task_setting_fine_tune_dir, '--do_load_from_dir'
160 |         # ]
161 | 
162 | 
163 | 
164 |         print("Running for task %s with gpus %s" % (self.run_dir, ', '.join(gpus)))
165 | 
166 | 
167 | 
168 |         try:
169 |             st = time.time()
170 |             results = []
171 |             timings = {}
172 |             with open(self.stdout_file, mode='w') as stdout_h, open(self.stderr_file, mode='w') as stderr_h:
173 |                 if self.do_train:
174 |                     tr_st = time.time()
175 |                     results.append(subprocess.run(
176 |                         [prog] + train_args, env=env, cwd=SCRIPTS_DIR, stdout=stdout_h, stderr=stderr_h
177 |                     ))
178 |                     timings['train'] = time.time() - tr_st
179 |                 if self.do_eval:
180 |                     ev_st = time.time()
181 |                     results.append(subprocess.run(
182 |                         [prog] + eval_args, env=env, cwd=SCRIPTS_DIR, stdout=stdout_h, stderr=stderr_h
183 |                     ))
184 |                     timings['eval'] = time.time() - ev_st
185 |                 if self.do_fine_tune:
186 |                     ft_st = time.time()
187 |                     results.append(subprocess.run(
188 |                         [prog] + fine_tune_args, env=env, cwd=SCRIPTS_DIR, stdout=stdout_h, stderr=stderr_h
189 |                     ))
190 |                     timings['fine_tune'] = time.time() - ft_st
191 |                 if self.do_fine_tune_eval: # NB: requires fine_tune_eval_args above, which is currently disabled.
192 |                     fte_st = time.time()
193 |                     results.append(subprocess.run(
194 |                         [prog] + fine_tune_eval_args, env=env, cwd=SCRIPTS_DIR, stdout=stdout_h,
195 |                         stderr=stderr_h,
196 |                     ))
197 |                     timings['fine_tune_eval'] = time.time() - fte_st
198 |             timings['total'] = time.time() - st
199 | 
200 |             with open(self.timings_file, mode='w') as f: f.write(json.dumps(timings))
201 |         except Exception as e:
202 |             print("run dir %s failed! Exception: %s" % (self.run_dir, e))
203 |             result = e
204 |         finally:
205 |             for gpu in gpus: self.gpu_queue.put(gpu)
206 | 
207 |         return tuple(results)
208 | 
209 |     def __call__(self):
210 |         if self.slurm:
211 |             return self.call_slurm()
212 |         else:
213 |             return self.call_no_slurm()
214 | 
215 | def main(single_task_args, tqdm=tqdm):
216 |     exp_dir = single_task_args.exp_dir
217 |     single_task_args.to_json_file(os.path.join(exp_dir, ST_EXP_ARGS_FILENAME))
218 |     print(single_task_args, ST_EXP_ARGS_FILENAME)
219 | 
220 |     config, base_model_args = read_config_and_args(exp_dir)
221 | 
222 |     assert len(config['gpus']) >= config['gpus_per_model'], "Invalid config!"
223 |     if config['gpus_per_model'] > 1: assert config['models_per_gpu'] == 1, "Not yet supported."
224 | 225 | rotation = single_task_args.rotation 226 | base_dir = os.path.join(exp_dir, str(rotation)) 227 | if not os.path.isdir(base_dir): os.makedirs(base_dir) 228 | 229 | base_model_args.rotation = rotation 230 | base_model_args.do_overwrite = True 231 | 232 | expected_filenames = [ 233 | 'task_weights.pkl', 234 | #'tuning_task_info.pkl', 235 | #'tuning_perf_metrics.pkl', 236 | #'test_task_info.pkl', 237 | #'test_perf_metrics.pkl', 238 | 'model.epoch-%d' % (base_model_args.epochs - 1), 239 | ] 240 | 241 | #if task_generalizability_args.do_save_all_reprs: 242 | # expected_filenames.extend(['tuning_reprs.pkl', 'test_reprs.pkl']) 243 | 244 | gpus_available = mp.Queue(maxsize=len(config['gpus']) * config['models_per_gpu']) 245 | for gpu in config['gpus']: 246 | for _ in range(config['models_per_gpu']): gpus_available.put_nowait(gpu) 247 | print("Loaded %d gpus into the queue" % gpus_available.qsize()) 248 | 249 | runners = [] 250 | # TODO(mmd): Put in config 251 | # TODO(mmd): Wrong granularity 252 | for ablation_setting in ABLATION_GROUPS.keys(): 253 | task_setting_str = ablation_setting 254 | task_dir = os.path.join(base_dir, task_setting_str) 255 | if not os.path.isdir(task_dir): os.makedirs(task_dir) 256 | 257 | expected_final_filepaths = [os.path.join(task_dir, fn) for fn in expected_filenames] 258 | task_complete = all(os.path.isfile(fp) for fp in expected_final_filepaths) 259 | 260 | if task_complete: continue 261 | 262 | task_setting_args = copy.deepcopy(base_model_args) 263 | task_setting_args.run_dir = task_dir 264 | # everything but current task 265 | task_setting_args.ablate = [k for k in ABLATION_GROUPS.keys() if k != task_setting_str] 266 | 267 | task_setting_args.to_json_file(os.path.join(task_dir, ARGS_FILENAME)) 268 | 269 | # task_setting_fine_tune_args = FineTuneArgs( 270 | # run_dir = task_dir, 271 | # fine_tune_task = task_setting_str, 272 | # num_dataloader_workers = 8, # should be in arg... 
273 |         # do_match_train_windows = single_task_args.do_match_FT_train_windows,
274 |         # )
275 |         # task_setting_fine_tune_args.to_json_file(os.path.join(task_dir, FINE_TUNE_ARGS_FILENAME))
276 | 
277 |         task_setting_eval_args = EvalArgs(
278 |             run_dir = task_dir,
279 |             rotation = rotation,
280 |             do_save_all_reprs = True,
281 |             do_eval_train = True,
282 |             do_eval_tuning = True,
283 |             do_eval_test = True,
284 |             num_dataloader_workers = 8,
285 |         )
286 |         task_setting_eval_args.to_json_file(os.path.join(task_dir, EVAL_ARGS_FILENAME))
287 | 
288 |         # task_setting_fine_tune_dir = os.path.join(task_dir, task_setting_str)
289 |         # if not os.path.exists(task_setting_fine_tune_dir): os.makedirs(task_setting_fine_tune_dir)
290 | 
291 |         # task_setting_fine_tune_eval_args = EvalArgs(
292 |         #     run_dir = task_setting_fine_tune_dir,
293 |         #     rotation = rotation,
294 |         #     do_save_all_reprs = False,
295 |         #     do_eval_train = False,
296 |         #     do_eval_tuning = True,
297 |         #     do_eval_test = True,
298 |         #     num_dataloader_workers = 8,
299 |         # )
300 |         # task_setting_fine_tune_eval_args.to_json_file(os.path.join(task_setting_fine_tune_dir, EVAL_ARGS_FILENAME))
301 | 
302 | 
303 | 
304 | 
305 |         runners.append(Runner(
306 |             gpus_per_model = config['gpus_per_model'],
307 |             run_dir = task_dir,
308 |             # task_setting_fine_tune_dir = task_setting_fine_tune_dir,
309 |             gpu_queue = gpus_available,
310 |             do_train = single_task_args.do_train,
311 |             do_eval = single_task_args.do_eval,
312 |             do_fine_tune = single_task_args.do_fine_tune,
313 |             do_fine_tune_eval = single_task_args.do_fine_tune_eval,
314 |             slurm = single_task_args.slurm,
315 |             partition = single_task_args.partition,
316 |         ))
317 | 
318 | 
319 |     processes = [mp.Process(target=r) for r in runners]
320 |     for process in processes: process.start()
321 | 
322 |     results = [process.join() for process in processes] # NB: Process.join() returns None; per-run artifacts are written to each run_dir.
323 | 
324 |     if single_task_args.slurm:
325 |         return
326 | 
327 |     with open(os.path.join(exp_dir, 'results.pkl'), mode='wb') as f: pickle.dump(results, f)
-------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/model.py: -------------------------------------------------------------------------------- 1 | """
2 | model.py
3 | This contains the source for the pytorch model which learns how to represent patient trajectories.
4 | """
5 | 
6 | import glob, os, random, numpy as np, pandas as pd
7 | import torch, torch.optim, torch.nn as nn, torch.nn.functional as F
8 | from typing import Sequence
9 | from dataclasses import dataclass
10 | from torch.autograd import Variable, set_detect_anomaly
11 | from torch.utils.data import (
12 |     DataLoader, Dataset, RandomSampler, SubsetRandomSampler, Subset, SequentialSampler
13 | )
14 | from torch.utils.data.distributed import DistributedSampler
15 | from tqdm import tqdm, trange
16 | 
17 | idx = pd.IndexSlice
18 | 
19 | from ..utils import *
20 | from ..constants import *
21 | from ..data_utils import *
22 | from ..representation_learner.fts_decoder import *
23 | from ..BERT.model import *
24 | from ..BERT.constants import *
25 | from .utils import *
26 | 
27 | # TODO(mmd): Move to utils.
28 | def fts_decoder_loss(logits, fts_label):
29 |     loss_fct = nn.CrossEntropyLoss(reduction='none')
30 |     mask = (fts_label != 0).float()
31 | 
32 |     loss = (loss_fct(logits, fts_label) * mask).sum(dim=1)
33 |     num_labels_per_el = mask.sum(dim=1)
34 |     num_labels_per_el_or_one = torch.where(
35 |         num_labels_per_el > 0, num_labels_per_el, torch.ones_like(num_labels_per_el)
36 |     ) # loss will be zero where num_labels_per_el is zero and we want to avoid nans.
37 |     loss = (loss / num_labels_per_el_or_one).mean(dim=0) # TODO: validate loss scale.
38 |     return logits, loss
39 | 
40 | def single_label_loss(logits, labels): return F.cross_entropy(logits, labels)
41 | def multi_label_loss(logits, labels): return F.multilabel_soft_margin_loss(logits, labels)
42 | def mse_loss(X, Y): return F.mse_loss(X, Y)
43 | 
44 | class GenericPredictor(nn.Module):
45 |     def __init__(
46 |         self,
47 |         in_dim,
48 |         out_dim,
49 |         loss_fct,
50 |         out_net = nn.Linear
51 |     ):
52 |         super().__init__()
53 |         self.out_net = out_net(in_dim, out_dim)
54 |         self.loss_fct = loss_fct
55 | 
56 |     def forward(self, X, labels=None):
57 |         out = self.out_net(X)
58 |         if labels is None: return out, None
59 | 
60 |         return out, self.loss_fct(out, labels)
61 | 
62 | def get_tasks_dict(config):
63 |     """
64 |     returns {task: (task_weight, task_head)}, where task_head is an nn.Module whose
65 |     forward(X, labels=None) returns (out/logits, loss (if labels)), and X = pooled_output.
66 |     """
67 |     heads = {}
68 |     heads['next_timepoint'] = GenericPredictor(config.hidden_size, config.num_feat, mse_loss)
69 |     heads['next_timepoint_was_measured'] = GenericPredictor(
70 |         config.hidden_size, config.num_feat, multi_label_loss
71 |     )
72 |     heads['static_tasks_continuous'] = GenericPredictor(config.hidden_size, ???, mse_loss)
73 |     heads['rolling_fts'] = FutureTreatmentSequenceDecoder(
74 |         decoder_module = LSTMDecoder(
75 |             in_dim = config.hidden_size,
76 |             treatment_embeddings = self.treatment_embeddings, # NB: there is no `self` here; the owning module's treatment_embeddings still needs to be threaded through.
77 |         ),
78 |         predictor_module = SingleTaskPredictor(
79 |             in_dim = 25, # TODO(mmd): Make params!!
80 |             num_classes = 9,
81 |         ),
82 |     )
83 |     heads['rolling_tasks_continuous'] = GenericPredictor(config.hidden_size, ???, mse_loss)
84 | 
85 |     # TODO(mmd): Re-do extractors such that it isn't so dumb.
86 | 
87 |     return {k: (1, v) for k, v in heads.items()}
88 |     # leftover entries from an earlier dict-literal version of this return:
89 |     #     'rolling_tasks_to_embed': (1.0, FullyConnectedNet),
90 |     #     'static_tasks_to_embed': (1.0, FullyConnectedNet), }
91 | 
92 | class SelfAttentionTimeseries(BertPreTrainedModel):
93 |     """ TODO(this)
94 |     """
95 |     def __init__(
96 |         self, config, use_cuda=False, tasks={},
97 |     ):
98 |         super().__init__(config)
99 |         self.bert = ContinuousBertModel(config)
100 |         # self.cls = ContinuousBertPreTrainingHeads(config) # modify this to get all of the necessary tasks
101 |         self.apply(self.init_bert_weights)
102 |         self.use_cuda = use_cuda
103 |         # tasks = {task: (weight, head)}; weights are split out since nn.ModuleDict may only hold Modules.
104 |         self.task_weights = {task: w for task, (w, _) in tasks.items()}
105 |         self.tasks = nn.ModuleDict({task: head for task, (_, head) in tasks.items()})
106 |         self.ts_continuous_projector = nn.Linear(config.ts_feat_dim, config.hidden_dim)
107 |         self.statics_continuous_projector = nn.Linear(config.statics_feat_dim, config.hidden_dim)
108 | 
109 |         self.embedders = None # TODO
110 | 
111 |         # additional losses
112 |         self.treatment_embeddings = nn.Embedding(
113 |             num_embeddings = 9, # TODO(mmd): Actually set this...
114 |             embedding_dim = 25 # TODO(mmd): Actually set this... Belongs in config...
115 |         )
116 | 
117 |     # forward should be called with a dictionary, via, e.g., model(**batch)
118 |     def forward(
119 |         self,
120 | 
121 |         # Inputs:
122 |         ts_continuous = None, # batch X seq_len X features
123 |         ts_to_embed = None, # batch X seq_len X features
124 |         ts_mask = None, # batch X seq_len X 1
125 |         statics = None, # batch X features
126 | 
127 |         # Tasks:
128 |         **tasks_kwargs,
129 |     ):
130 |         # TODO(mmd): Embedding Features...
131 |         input_sequence = self.ts_continuous_projector(ts_continuous)
132 |         statics_continuous = self.statics_continuous_projector(statics)
133 |         statics_continuous = statics_continuous.unsqueeze(1).expand_as(input_sequence)
134 | 
135 |         input_sequence += statics_continuous
136 | 
137 |         ts_mask = ts_mask.squeeze(-1) # [batch, seq_len]; ContinuousBertModel builds its attention mask from a 2D mask.
138 | 
139 |         _, pooled_output = self.bert(input_sequence, None, ts_mask,
140 |                                      output_all_encoded_layers=False)
141 |         # sequence_output.shape is batch_size, max_seq_length, hidden_dim
142 |         # pooled_output.shape is batch_size, hidden_dim
143 | 
144 |         total_loss, tasks_out = 0, {}
145 |         for task_name in set(self.tasks.keys()).intersection(tasks_kwargs.keys()):
146 |             weight, head = self.task_weights[task_name], self.tasks[task_name]
147 |             task_labels = tasks_kwargs[task_name]
148 |             out, loss = head(pooled_output, task_labels)
149 |             tasks_out[task_name] = (out, task_labels, loss)
150 |             if loss is not None: total_loss += weight * loss
151 | 
152 |         return (
153 |             pooled_output,
154 |             tasks_out,
155 |             total_loss
156 |         )
157 | 
158 | 
159 | 
160 | 
161 | 
162 | class CNN(nn.Module):
163 |     """ TODO(this)
164 |     """
165 |     def __init__(
166 |         self, config, use_cuda=False, tasks={}, conv_layers = [10, 100, 20], filt_size= [7, 5, 5]
167 |     ):
168 |         super().__init__()
169 | 
170 | 
171 |         # conv 2d approach
172 |         self.conv1 = nn.Conv2d(1, conv_layers[0], (filt_size[0],1)) # in channels, out channels, kernel size
173 |         self.conv2 = nn.Conv2d(conv_layers[0], conv_layers[1], (filt_size[1], 1))
174 |         self.conv3 = nn.Conv2d(conv_layers[1], conv_layers[2], (filt_size[2], 1))
175 | 
176 | 
177 |         self.fc1 = nn.Linear(7840, 512)
178 |         self.fc2 = nn.Linear(512, config.hidden_dim)
179 |         self.relu = nn.ReLU()
180 | 
181 |         # self.cls = ContinuousBertPreTrainingHeads(config) # modify this to get all of the necessary tasks
182 |         # self.apply(self.init_bert_weights) # init_bert_weights is a BertPreTrainedModel helper, unavailable on a plain nn.Module.
183 |         self.use_cuda = use_cuda
184 |         self.task_weights = {task: w for task, (w, _) in tasks.items()}
185 |         self.tasks = nn.ModuleDict({task: head for task, (_, head) in tasks.items()}) # tasks = {task: (weight, head)}
186 | 
187 |         self.ts_continuous_projector = nn.Linear(config.ts_feat_dim, config.hidden_dim)
188 |         self.statics_continuous_projector = nn.Linear(config.statics_feat_dim, config.hidden_dim)
189 | 
190 |         self.embedders = None # TODO
191 | 
192 |         # additional losses
193 |         self.treatment_embeddings = nn.Embedding(
194 |             num_embeddings = 9, # TODO(mmd): Actually set this...
195 |             embedding_dim = 25 # TODO(mmd): Actually set this... Belongs in config...
196 |         )
197 | 
198 |     # forward should be called with a dictionary, via, e.g., model(**batch)
199 |     def forward(
200 |         self,
201 | 
202 |         # Inputs:
203 |         ts_continuous = None, # batch X seq_len X features
204 |         ts_to_embed = None, # batch X seq_len X features
205 |         ts_mask = None, # batch X seq_len X 1
206 |         statics = None, # batch X features
207 | 
208 |         # Tasks:
209 |         **tasks_kwargs,
210 |     ):
211 |         # TODO(mmd): Embedding Features...
212 |         input_sequence = self.ts_continuous_projector(ts_continuous)
213 |         statics_continuous = self.statics_continuous_projector(statics)
214 |         statics_continuous = statics_continuous.unsqueeze(1).expand_as(input_sequence)
215 | 
216 |         input_sequence += statics_continuous
217 | 
218 |         ts_mask = ts_mask.expand_as(input_sequence)
219 | 
220 |         # replace with conv layers
221 |         # _, pooled_output = self.bert(input_sequence, None, ts_mask,
222 |         #                              output_all_encoded_layers=False)
223 | 
224 |         # print(input_sequence.shape) # (debugging)
225 |         x = input_sequence.unsqueeze(1) # [batch, 1, seq_len, hidden]: add the channel dim Conv2d expects.
226 |         x = self.relu(self.conv1(x))
227 |         x = self.relu(self.conv2(x))
228 |         x = self.relu(self.conv3(x))
229 |         x = x.view(x.size(0), -1)
230 |         x = self.relu(self.fc1(x))
231 |         pooled_output = self.fc2(x)
232 | 
233 |         # sequence_output.shape is batch_size, max_seq_length, hidden_dim
234 |         # pooled_output.shape is batch_size, hidden_dim
235 | 
236 |         total_loss, tasks_out = 0, {}
237 |         for task_name in set(self.tasks.keys()).intersection(tasks_kwargs.keys()):
238 |             weight, head = self.task_weights[task_name], self.tasks[task_name]
239 |             task_labels = tasks_kwargs[task_name]
240 |             out, loss = head(pooled_output, task_labels)
241 |             tasks_out[task_name] = (out, task_labels, loss)
242 |             if loss is not None: total_loss += weight * loss
243 | 
244 |         return (
245 |             pooled_output,
246 |             tasks_out,
247 |             total_loss
248 |         )
249 | 
250 | 
251 | class GRUModel(nn.Module):
252 |     """ TODO(this)
253 |     """
254 |     def __init__(
255 |         self, config, device, use_cuda=False, tasks={}, hidden_dim=512, num_layers=2, drop_prob=0.2,
256 |     ):
257 |         super().__init__()
258 | 
259 |         # initialise the model and the weights
260 |         self.hidden_dim = hidden_dim
261 |         self.n_layers = num_layers
262 |         self.gru = nn.GRU(input_size=config.input_dim, hidden_size=hidden_dim, num_layers=num_layers, batch_first=True, dropout=drop_prob, bidirectional=False)
263 | 
264 |         # self.fc = nn.Linear(hidden_dim*2, output_dim) # for bidirectional
265 |         self.fc = nn.Linear(hidden_dim*config.input_dim, config.hidden_dim) # not bidirectional; output dim assumed to be config.hidden_dim, the pooled-output size the task heads consume.
266 |         self.relu = nn.ReLU()
267 |         # self.h_0 = torch.zeros(2*num_layers, 1, hidden_dim).float().to('cuda:0') # for bidirectional
268 |         self.h_0 = torch.zeros(num_layers, 1, hidden_dim).float().to('cuda:0' if use_cuda else 'cpu') # not bidirectional
269 |         self.hidden_dim = hidden_dim
270 | 
271 | 
272 | 
273 | 
274 | 
275 | 
276 | 
277 | 
278 |         # self.cls = ContinuousBertPreTrainingHeads(config) # modify this to get all of the necessary tasks
279 |         # self.apply(self.init_bert_weights) # init_bert_weights is a BertPreTrainedModel helper, unavailable on a plain nn.Module.
280 |         self.use_cuda = use_cuda
281 |         self.task_weights = {task: w for task, (w, _) in tasks.items()}
282 |         self.tasks = nn.ModuleDict({task: head for task, (_, head) in tasks.items()}) # tasks = {task: (weight, head)}
283 | 
284 |         self.ts_continuous_projector = nn.Linear(config.ts_feat_dim, config.hidden_dim)
285 |         self.statics_continuous_projector = nn.Linear(config.statics_feat_dim, config.hidden_dim)
286 | 
287 |         self.embedders = None # TODO
288 | 
289 |         # additional losses
290 |         self.treatment_embeddings = nn.Embedding(
291 |             num_embeddings = 9, # TODO(mmd): Actually set this...
292 |             embedding_dim = 25 # TODO(mmd): Actually set this... Belongs in config...
293 |         )
294 | 
295 |     # forward should be called with a dictionary, via, e.g., model(**batch)
296 |     def forward(
297 |         self,
298 | 
299 |         # Inputs:
300 |         ts_continuous = None, # batch X seq_len X features
301 |         ts_to_embed = None, # batch X seq_len X features
302 |         ts_mask = None, # batch X seq_len X 1
303 |         statics = None, # batch X features
304 |         h_0 = None,
305 | 
306 |         # Tasks:
307 |         **tasks_kwargs,
308 |     ):
309 |         # TODO(mmd): Embedding Features...
310 |         input_sequence = self.ts_continuous_projector(ts_continuous)
311 |         statics_continuous = self.statics_continuous_projector(statics)
312 |         statics_continuous = statics_continuous.unsqueeze(1).expand_as(input_sequence)
313 | 
314 |         input_sequence += statics_continuous
315 | 
316 |         ts_mask = ts_mask.expand_as(input_sequence)
317 | 
318 | 
319 |         if h_0 is None:
320 |             h_0 = self.h_0
321 | 
322 |         bs = input_sequence.shape[0]
323 |         if bs != 1:
324 |             h_0 = h_0.expand(-1, bs, -1).contiguous()
325 | 
326 |         # out, (h, c) = self.gru(x, (h_0, c_0)) # for lstm
327 |         out, h = self.gru(input_sequence, h_0) # for gru
328 | 
329 |         # out = out.view(-1, 2 * self.hidden_dim) # num directions is 2 for forward backward rnn
330 |         out = out.contiguous().view(bs, -1) # num directions is 1 for forward rnn
331 |         pooled_output = self.fc(self.relu(out))
332 | 
333 |         # sequence_output.shape is batch_size, max_seq_length, hidden_dim
334 |         # pooled_output.shape is batch_size, hidden_dim
335 | 
336 |         total_loss, tasks_out = 0, {}
337 |         for task_name in set(self.tasks.keys()).intersection(tasks_kwargs.keys()):
338 |             weight, head = self.task_weights[task_name], self.tasks[task_name]
339 |             task_labels = tasks_kwargs[task_name]
340 |             out, loss = head(pooled_output, task_labels)
341 |             tasks_out[task_name] = (out, task_labels, loss)
342 |             if loss is not None: total_loss += weight * loss
343 | 
344 |         return (
345 |             pooled_output,
346 |             tasks_out,
347 |             total_loss
348 |         )
349 | -------------------------------------------------------------------------------- /latent_patient_trajectories/BERT/model.py: -------------------------------------------------------------------------------- 1 | # Adapted from
2 | # https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/pytorch_pretrained_bert
3 | # at commit e6cf62d49945e6277b5e4dc855f9186b3f789e35
4 | from ..pytorch_helpers import *
5 | 
6 | import numpy as np, torch
7 | 
8 | from pytorch_pretrained_bert.modeling import (
9 |     BertEncoder, BertPooler, BertPredictionHeadTransform, BertPreTrainedModel, BertModel, BertLayerNorm,
10 |     BertConfig
11 | )
12 | from torch import nn
13 | 
14 | class ContinuousBertConfig(BertConfig):
15 |     """ An extension of the BERT Config to store continuous params as well.
16 |     """
17 |     def __init__(self, *args, in_dim=-1, **kwargs):
18 |         super().__init__(0, *args, **kwargs)
19 |         self.in_dim = in_dim
20 | 
21 | class ContinuousBertEmbeddings(nn.Module):
22 |     """Construct the embeddings from the continuous input sequence, position and token_type embeddings.
23 |     """
24 |     def __init__(self, config):
25 |         super().__init__()
26 |         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
27 |         # should eliminate the below two and fold into meta_model.
28 |         self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
29 | 
30 |         # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
31 |         # any TensorFlow checkpoint file
32 |         self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
33 |         self.dropout = nn.Dropout(config.hidden_dropout_prob)
34 | 
35 |     def forward(self, sequence, token_type_ids=None):
36 |         # `sequence` should be of shape (batch_size, time_steps, feature_size)
37 |         seq_length = sequence.size(1)
38 |         position_ids = torch.arange(seq_length, dtype=torch.long, device=sequence.device)
39 |         position_ids = position_ids.unsqueeze(0).expand_as(sequence[:, :, 0])
40 |         if token_type_ids is None: token_type_ids = torch.zeros_like(sequence[:, :, 0]).long()
41 | 
42 |         sequence_embeddings = sequence
43 |         position_embeddings = self.position_embeddings(position_ids)
44 |         token_type_embeddings = self.token_type_embeddings(token_type_ids)
45 | 
46 |         embeddings = sequence_embeddings + position_embeddings + token_type_embeddings
47 |         embeddings = self.LayerNorm(embeddings)
48 |         embeddings = self.dropout(embeddings)
49 |         return embeddings
50 | 
51 | class ContinuousBertModel(BertPreTrainedModel):
52 |     """Continuous BERT model ("Bidirectional Embedding Representations from a Transformer").
53 |     Just like BERT, but with a different kind of embedding layer.
54 | 
55 |     Params:
56 |         config: a BertConfig class instance with the configuration to build a new model
57 | 
58 |     Inputs:
59 |         `sequence`: a torch.FloatTensor of shape [batch_size, sequence_length, in_dim]
60 |             with the continuous per-timestep features (embedded directly by ContinuousBertEmbeddings,
61 |             rather than looked up as word token indices in a vocabulary)
62 |         `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
63 |             types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
64 |             a `sentence B` token (see BERT paper for more details).
65 |         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
66 |             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
67 |             input sequence length in the current batch. It's the mask that we typically use for attention when
68 |             a batch has varying length sentences.
69 |         `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
70 | 
71 |     Outputs: Tuple of (encoded_layers, pooled_output)
72 |         `encoded_layers`: controlled by `output_all_encoded_layers` argument:
73 |             - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
74 |               of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
75 |               encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
76 |             - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
77 |               to the last attention block of shape [batch_size, sequence_length, hidden_size],
78 |         `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
79 |             classifier pretrained on top of the hidden state associated to the first character of the
80 |             input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
81 | 82 | Example usage: 83 | ```python 84 | # TODO(): Update 85 | sequence = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 86 | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 87 | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 88 | 89 | config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 90 | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 91 | 92 | model = modeling.BertModel(config=config) 93 | all_encoder_layers, pooled_output = model(sequence, token_type_ids, input_mask) 94 | ``` 95 | """ 96 | def __init__(self, config): 97 | super().__init__(config) 98 | self.embedder = ContinuousBertEmbeddings(config) 99 | self.encoder = BertEncoder(config) 100 | self.pooler = BertPooler(config) 101 | self.apply(self.init_bert_weights) 102 | 103 | def forward(self, sequence, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): 104 | if attention_mask is None: attention_mask = torch.ones_like(sequence[:, :, 0]).long() 105 | if token_type_ids is None: token_type_ids = torch.zeros_like(sequence[:, :, 0]).long() 106 | 107 | # We create a 3D attention mask from a 2D tensor mask. 108 | # Sizes are [batch_size, 1, 1, to_seq_length] 109 | # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] 110 | # this attention mask is more simple than the triangular masking of causal attention 111 | # used in OpenAI GPT, we just need to prepare the broadcast dimension here. 112 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 113 | 114 | # Since attention_mask is 1.0 for positions we want to attend and 0.0 for 115 | # masked positions, this operation will create a tensor which is 0.0 for 116 | # positions we want to attend and -10000.0 for masked positions. 117 | # Since we are adding it to the raw scores before the softmax, this is 118 | # effectively the same as removing these entirely. 119 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 120 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 121 | 122 | embedded_sequence = self.embedder(sequence, token_type_ids) 123 | encoded_layers = self.encoder( 124 | embedded_sequence, extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers 125 | ) 126 | sequence_output = encoded_layers[-1] 127 | pooled_output = self.pooler(sequence_output) 128 | if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] 129 | return encoded_layers, pooled_output 130 | 131 | #class BertReconstructionPredictionHead(nn.Module): 132 | # def __init__(self, config): 133 | # super().__init__() 134 | # self.transform = BertPredictionHeadTransform(config) # May not need the transform here. 
135 | # self.decoder = nn.Linear(config.hidden_size, config.in_dim, bias=True) 136 | # 137 | # def forward(self, hidden_states): return self.decoder(self.transform(hidden_states)) 138 | # 139 | #class ContinuousBertPreTrainingHeads(nn.Module): 140 | # def __init__(self, config, task_dims={}): 141 | # super().__init__() 142 | # self.predictions = BertReconstructionPredictionHead(config) 143 | # self.seq_predictions = MultitaskHead(config, task_dims) 144 | # 145 | # def forward(self, sequence_output, pooled_output): 146 | # prediction_scores = self.predictions(sequence_output) 147 | # seq_scores = self.seq_predictions(pooled_output) 148 | # return prediction_scores, seq_scores 149 | # 150 | #class ContinuousBertForPreTraining(BertPreTrainedModel): 151 | # """BERT model with continuous reconstruction pre-training heads. 152 | # Also supports additional, user-specified auxiliary losses. 153 | # 154 | # This module comprises the BERT model followed by the two pre-training heads: 155 | # - the masked continuous reconstruction modeling head, and 156 | # - the next sequence classification head. 157 | # as well as any user specified auxiliary loss prediction heads. 158 | # 159 | # Params: 160 | # config: a BertConfig class instance with the configuration to build a new model. 161 | # 162 | # Inputs: 163 | # `sequence`: a torch.LongTensor of shape [batch_size, sequence_length, hidden_size] 164 | # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token 165 | # types indices selected in [0, 1]. Type 0 corresponds to a `sequence A` and type 1 corresponds to 166 | # a `sequence B` token (see BERT paper for more details). 167 | # `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices 168 | # selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max 169 | # input sequence length in the current batch. It's the mask that we typically use for attention when 170 | # a batch has varying length sequences. 171 | # `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] 172 | # with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss 173 | # is only computed for the labels set in [0, ..., vocab_size] 174 | # `next_sequence_label`: optional next sequence classification loss: torch.LongTensor of shape [batch_size] 175 | # with indices selected in [0, 1]. 176 | # 0 => next sequence is the continuation, 1 => next sequence is a random sequence. 177 | # 178 | # Outputs: 179 | # if `masked_lm_labels` and `next_sequence_label` are not `None`: 180 | # Outputs the total_loss which is the sum of the masked language modeling loss and the next 181 | # sequence classification loss. 182 | # if `masked_lm_labels` or `next_sequence_label` is `None`: 183 | # Outputs a tuple comprising 184 | # - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and 185 | # - the next sequence classification logits of shape [batch_size, 2]. 
186 | # 187 | # Example usage: 188 | # ```python 189 | # # Already been converted into WordPiece token ids 190 | # sequence = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) 191 | # input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) 192 | # token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) 193 | # 194 | # config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, 195 | # num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) 196 | # 197 | # model = BertForPreTraining(config) 198 | # masked_lm_logits_scores, seq_relationship_logits = model(sequence, token_type_ids, input_mask) 199 | # ``` 200 | # """ 201 | # def __init__(self, config, seq_task_dims, lambda_seq_tasks=1): 202 | # super().__init__(config) 203 | # self.bert = ContinuousBertModel(config) 204 | # self.cls = ContinuousBertPreTrainingHeads(config, seq_task_dims) 205 | # self.lambda_seq_tasks = lambda_seq_tasks 206 | # self.apply(self.init_bert_weights) 207 | # 208 | # def forward( 209 | # self, 210 | # masked_sequence_targets, input_sequence, attention_mask, token_type_ids, el_was_masked, 211 | # whole_sequence_labels, whole_sequence_labels_present, 212 | # ): 213 | # # TODO(mmd): Auxiliary Losses. 214 | # sequence_output, pooled_output = self.bert(input_sequence, token_type_ids, attention_mask, 215 | # output_all_encoded_layers=False) 216 | # reconstructed_sequence, seq_scores = self.cls(sequence_output, pooled_output) 217 | # 218 | # total_loss = 0 219 | # if masked_sequence_targets is not None: 220 | # reconstruction_loss_fct = nn.MSELoss(reduction="none") 221 | # masked_reconstruction_loss = reconstruction_loss_fct( 222 | # reconstructed_sequence, masked_sequence_targets 223 | # ) 224 | # masked_reconstruction_loss *= el_was_masked.unsqueeze(2).expand_as(masked_reconstruction_loss) 225 | # masked_reconstruction_loss = (masked_reconstruction_loss.sum())/(el_was_masked.sum()) 226 | # total_loss = masked_reconstruction_loss + total_loss 227 | # 228 | # for task, label in whole_sequence_labels.items(): 229 | # labels_present = whole_sequence_labels_present[task] 230 | # labels_present_sum = labels_present.sum() 231 | # 232 | # scores = seq_scores[task] 233 | # 234 | # # TODO(mmd): Generalize 235 | # whole_sequence_loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduction='none') 236 | # 237 | # whole_sequence_loss = whole_sequence_loss_fct(scores, label.view(-1)) 238 | # whole_sequence_loss = torch.where( 239 | # labels_present == 1, whole_sequence_loss, torch.zeros_like(whole_sequence_loss) 240 | # ).sum() 241 | # 242 | # whole_sequence_loss = torch.where( 243 | # labels_present_sum > 0, 244 | # whole_sequence_loss/labels_present_sum, 245 | # torch.zeros_like(whole_sequence_loss) 246 | # ) 247 | # 248 | # total_loss = total_loss + self.lambda_seq_tasks * whole_sequence_loss 249 | # 250 | # return pooled_output, total_loss, reconstructed_sequence, seq_scores 251 | -------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/task_generalizability.py: -------------------------------------------------------------------------------- 1 | # Generic Imports 2 | import copy, math, itertools, json, os, queue, subprocess, sys, time 3 | import multiprocessing as mp 4 | 5 | from tqdm import tqdm 6 | 7 | # LPT Imports 8 | from . 
import run_model 9 | 10 | from ..constants import * 11 | from .args import * 12 | from .evaluator import * 13 | 14 | PYTHON_EXECUTABLE_PATH = sys.executable 15 | SCRIPTS_DIR = '/%s' % os.path.join(*(os.path.realpath(__file__).split('/')[:-3] + ['Scripts','End to End'])) 16 | 17 | def task_setting_to_str(task_setting): 18 | return task_setting.replace(' ', '_') 19 | 20 | 21 | def read_config_and_args(exp_dir): 22 | """ 23 | Reads a json task generalizability config, e.g.: 24 | { 25 | "gpus_per_model": 4, 26 | "models_per_gpu": 1, 27 | "gpus": [0, 1, 2, 3] 28 | } 29 | """ 30 | config_filepath = os.path.join(exp_dir, TASK_GEN_CFG_FILENAME) 31 | with open(config_filepath, mode='r') as f: config = json.loads(f.read()) 32 | 33 | args_filepath = os.path.join(exp_dir, TASK_GEN_BASE_ARGS_FILENAME) 34 | args = Args.from_json_file(args_filepath) 35 | 36 | return config, args 37 | 38 | class Runner(): 39 | def __init__( 40 | self, 41 | gpus_per_model, 42 | run_dir, 43 | task_setting_fine_tune_dir, 44 | gpu_queue, 45 | do_train = True, 46 | do_eval = True, 47 | do_fine_tune = True, 48 | do_fine_tune_eval = True, 49 | slurm =False, 50 | partition='gpu', 51 | slurm_args=None, 52 | do_small_data=False, 53 | do_imbalanced_sex_data=False, 54 | do_imbalanced_race_data=False, 55 | do_frozen_representation=True, 56 | do_free_representation=False, 57 | do_single_task=False, 58 | do_masked_imputation_PT = False, 59 | do_copy_masked_imputation_PT = False, 60 | ): 61 | assert do_train or do_eval or do_fine_tune or do_fine_tune_eval or do_small_data or do_imbalanced_sex_data or do_imbalanced_race_data, "Must do something!" 62 | 63 | self.gpus_per_model = gpus_per_model 64 | self.run_dir = run_dir 65 | self.task_setting_fine_tune_dir = task_setting_fine_tune_dir 66 | self.timings_file = os.path.join(run_dir, 'timings.json') 67 | self.stdout_file = os.path.join(run_dir, '{key}_stdout.txt') 68 | self.stderr_file = os.path.join(run_dir, '{key}_stderr.txt') 69 | self.gpu_queue = gpu_queue 70 | self.do_train = do_train 71 | self.do_eval = do_eval 72 | self.do_fine_tune = do_fine_tune 73 | self.do_fine_tune_eval = do_fine_tune_eval 74 | self.slurm = slurm 75 | self.partition = partition 76 | self.slurm_args = slurm_args 77 | self.do_small_data = do_small_data 78 | self.do_imbalanced_sex_data = do_imbalanced_sex_data 79 | self.do_imbalanced_race_data = do_imbalanced_race_data 80 | self.do_frozen_representation = do_frozen_representation 81 | self.do_free_representation = do_free_representation 82 | self.do_single_task = do_single_task 83 | self.do_masked_imputation_PT = do_masked_imputation_PT 84 | self.do_copy_masked_imputation_PT = do_copy_masked_imputation_PT 85 | 86 | if self.do_single_task or (self.do_masked_imputation_PT and self.do_copy_masked_imputation_PT): 87 | assert not self.do_train, "Shouldn't pre-train a single-task/MI model!" 88 | assert not self.do_eval, "Shouldn't eval a (non-existent) PT model in single-task/MI mode!" 
89 | 90 | def run(self, args, env, key, timings, results): 91 | st = time.time() 92 | stdout_file = self.stdout_file.format(key=key) 93 | stderr_file = self.stderr_file.format(key=key) 94 | with open(stdout_file, mode='w') as stdout_h, open(stderr_file, mode='w') as stderr_h: 95 | ev_st = time.time() 96 | results.append(subprocess.run( 97 | [PYTHON_EXECUTABLE_PATH] + args, env=env, cwd=SCRIPTS_DIR, stdout=stdout_h, stderr=stderr_h, 98 | check = True 99 | )) 100 | timings[key] = time.time() - ev_st 101 | 102 | def call_no_slurm(self): 103 | gpus=set() 104 | while len(gpus) < self.gpus_per_model: 105 | try: 106 | new_gpu = self.gpu_queue.get(block=True, timeout=30) 107 | if new_gpu in gpus: self.gpu_queue.put(new_gpu) 108 | else: gpus.update([new_gpu]) 109 | except queue.Empty as e: 110 | if gpus: 111 | for gpu in gpus: self.gpu_queue.put(gpu) 112 | time.sleep(90) 113 | pass 114 | 115 | gpus = [str(g) for g in gpus] 116 | 117 | env = {'PROJECTS_BASE': os.environ['PROJECTS_BASE'], 'CUDA_VISIBLE_DEVICES': ','.join(gpus)} 118 | prog = PYTHON_EXECUTABLE_PATH 119 | 120 | pretrain_args = ['run_v2_model.py', '--run_dir=%s' % self.run_dir, '--do_load_from_dir'] 121 | eval_args = ['evaluate.py', '--run_dir=%s' % self.run_dir, '--do_load_from_dir'] 122 | fine_tune_args = ['fine_tune_task.py', '--run_dir=%s' % self.run_dir, '--do_load_from_dir'] 123 | 124 | print("Running for task %s with gpus %s" % (self.run_dir, ', '.join(gpus))) 125 | 126 | data_fracs = [1] 127 | if self.do_small_data: data_fracs.extend(SMALL_DATA_FRACS) 128 | 129 | try: 130 | st = time.time() 131 | results = [] 132 | timings = {} 133 | if self.do_train: self.run(pretrain_args, env, 'train', timings, results) 134 | if self.do_eval: self.run(eval_args, env, 'eval', timings, results) 135 | if self.do_fine_tune: self.run(fine_tune_args, env, 'fine_tune', timings, results) 136 | if self.do_fine_tune_eval: 137 | for frac in data_fracs: 138 | for do, suffix in [ 139 | (self.do_frozen_representation, "FTD"), (self.do_free_representation, "FTF"), 140 | ]: 141 | if not do: continue 142 | 143 | fine_tune_dir_name = self.task_setting_fine_tune_dir 144 | if frac != 1: fine_tune_dir_name += f"_{str(frac).replace('.', '-')}" 145 | 146 | fine_tune_dir = os.path.join(fine_tune_dir_name, suffix) 147 | assert os.path.isdir(fine_tune_dir) 148 | 149 | fine_tune_eval_args = [ 150 | 'evaluate.py', f"--run_dir={fine_tune_dir}", '--do_load_from_dir' 151 | ] 152 | self.run( 153 | fine_tune_eval_args, env, f"{100*frac}%_data_{suffix}_eval", timings, results 154 | ) 155 | 156 | with open(self.timings_file, mode='w') as f: f.write(json.dumps(timings)) 157 | except Exception as e: 158 | print("run dir %s failed! 
Exception: %s" % (self.run_dir, e)) 159 | result = e 160 | finally: 161 | for gpu in gpus: self.gpu_queue.put(gpu) 162 | 163 | return tuple(results) 164 | 165 | def __call__(self): 166 | if self.slurm: 167 | raise NotImplementedError("Slurm support has been deprecated.") 168 | else: 169 | return self.call_no_slurm() 170 | 171 | def main(task_generalizability_args, tqdm=tqdm): 172 | exp_dir = task_generalizability_args.exp_dir 173 | task_generalizability_args.to_json_file(os.path.join(exp_dir, TASK_GEN_EXP_ARGS_FILENAME)) 174 | print(task_generalizability_args, TASK_GEN_EXP_ARGS_FILENAME) 175 | 176 | config, base_model_args = read_config_and_args(exp_dir) 177 | assert os.path.exists(base_model_args.dataset_dir), f'{base_model_args.dataset_dir} does not exist' 178 | 179 | assert ( 180 | task_generalizability_args.do_frozen_representation or 181 | task_generalizability_args.do_free_representation 182 | ), "Need to do either FTF or FTD!" 183 | 184 | assert task_generalizability_args.do_eicu == base_model_args.do_eicu 185 | 186 | if task_generalizability_args.do_single_task: 187 | assert not task_generalizability_args.do_train, "Can't pre-train a single-task model!" 188 | assert not task_generalizability_args.do_eval, "Can't eval a pre-trained model in single-task mode!" 189 | assert not task_generalizability_args.do_frozen_representation, \ 190 | "FTD doesn't make sense in single-task mode!" 191 | assert not task_generalizability_args.do_masked_imputation_PT, \ 192 | "Can't do both single-task and Masked Imputation!" 193 | elif task_generalizability_args.do_masked_imputation_PT: 194 | some_PT = ( 195 | task_generalizability_args.do_train or task_generalizability_args.do_copy_masked_imputation_PT 196 | ) 197 | assert some_PT, "It isn't masked imputation PT with PT!" 198 | if task_generalizability_args.do_copy_masked_imputation_PT: 199 | assert not task_generalizability_args.do_train, "Shouldn't copy and PT!" 200 | assert not task_generalizability_args.do_eval, "Shouldn't copy and eval!" 201 | 202 | assert len(config['gpus']) >= config['gpus_per_model'], "Invalid config!" 203 | if config['gpus_per_model'] > 1: assert config['models_per_gpu'] == 1, "Not yet supported." 
204 | 205 | rotation = task_generalizability_args.rotation 206 | base_dir = os.path.join(exp_dir, str(rotation)) 207 | if not os.path.isdir(base_dir): os.makedirs(base_dir) 208 | 209 | do_masked_imputation_PT = task_generalizability_args.do_masked_imputation_PT 210 | do_copy_masked_imputation_PT = task_generalizability_args.do_copy_masked_imputation_PT 211 | 212 | base_model_args.rotation = rotation 213 | if do_masked_imputation_PT and do_copy_masked_imputation_PT: base_model_args.do_overwrite = False 214 | else: base_model_args.do_overwrite = True 215 | 216 | gpus_available = mp.Queue(maxsize=len(config['gpus']) * config['models_per_gpu']) 217 | for gpu in config['gpus']: 218 | for _ in range(config['models_per_gpu']): gpus_available.put_nowait(gpu) 219 | print("Loaded %d gpus into the queue" % gpus_available.qsize()) 220 | 221 | do_single_task = task_generalizability_args.do_single_task 222 | single_task = task_generalizability_args.single_task 223 | 224 | runners = [] 225 | # TODO(mmd): Put in config 226 | # TODO(mmd): Wrong granularity 227 | ablation_settings = EICU_ABLATION_GROUPS if task_generalizability_args.do_eicu else ABLATION_GROUPS.keys() 228 | print(f"Running on {', '.join(ablation_settings)}") 229 | 230 | if do_masked_imputation_PT: 231 | assert do_copy_masked_imputation_PT, \ 232 | "Currently doesn't support masked imputation PT, so must be copied." 233 | 234 | if do_copy_masked_imputation_PT: 235 | tuning_eval_file = 'tuning_perf_metrics.pkl' 236 | test_eval_file = 'test_perf_metrics.pkl' 237 | last_model_file = f"model.epoch-{base_model_args.epochs-1}" 238 | 239 | for fn in (tuning_eval_file, test_eval_file, last_model_file): 240 | assert os.path.isfile(os.path.join(base_dir, fn)), "Missing required file for copied PT!" 241 | else: 242 | raise NotImplementedError("Not yet supported.") 243 | # TODO(mmd): Here is where we'd re-do PT training. 244 | 245 | for ablation_setting in ablation_settings: 246 | if do_single_task and (ablation_setting != single_task): continue 247 | print(f"Setting up Runner for {ablation_setting}") 248 | 249 | task_setting_str = ablation_setting 250 | if do_single_task: 251 | task_dir = base_dir 252 | else: 253 | task_dir = os.path.join(base_dir, task_setting_str) 254 | if not os.path.isdir(task_dir): os.makedirs(task_dir) 255 | 256 | if not do_masked_imputation_PT: 257 | task_setting_args = copy.deepcopy(base_model_args) 258 | task_setting_args.run_dir = task_dir 259 | 260 | if do_single_task: task_setting_args.ablate = [t for t in ablation_settings if t != task_setting_str] 261 | else: task_setting_args.ablate = task_setting_str 262 | 263 | task_setting_args.to_json_file(os.path.join(task_dir, ARGS_FILENAME)) 264 | 265 | if (not do_single_task) and task_generalizability_args.do_eval: 266 | task_setting_eval_args = EvalArgs( 267 | run_dir = task_dir, 268 | rotation = rotation, 269 | do_save_all_reprs = False, 270 | do_eval_train = False, 271 | do_eval_tuning = True, 272 | do_eval_test = True, 273 | do_eicu = task_generalizability_args.do_eicu, 274 | num_dataloader_workers = 8, 275 | ) 276 | task_setting_eval_args.to_json_file(os.path.join(task_dir, EVAL_ARGS_FILENAME)) 277 | 278 | task_setting_fine_tune_args = FineTuneArgs( 279 | run_dir = task_dir, 280 | fine_tune_task = task_setting_str, 281 | num_dataloader_workers = 8, # should be in arg... 
282 | do_match_train_windows = task_generalizability_args.do_match_FT_train_windows, #todo add early stopping 283 | train_embedding_after = task_generalizability_args.train_embedding_after, 284 | balanced_race = task_generalizability_args.do_imbalanced_race_data, 285 | do_eicu = task_generalizability_args.do_eicu, 286 | do_frozen_representation = task_generalizability_args.do_frozen_representation, 287 | do_free_representation = task_generalizability_args.do_free_representation, 288 | do_single_task = task_generalizability_args.do_single_task, 289 | do_small_data = task_generalizability_args.do_small_data, 290 | do_masked_imputation_PT = do_masked_imputation_PT, 291 | ) 292 | task_setting_fine_tune_args.to_json_file(os.path.join(task_dir, FINE_TUNE_ARGS_FILENAME)) 293 | 294 | task_setting_fine_tune_dir = os.path.join(task_dir, task_setting_str) 295 | 296 | # TODO(mmd): This is terrible. This should be standardized, and it should only be needed in one place. 297 | # We should just move the eval stuff to the fine_tune.py code. 298 | data_fracs = [1] 299 | if task_generalizability_args.do_small_data: data_fracs.extend(SMALL_DATA_FRACS) 300 | 301 | for frac in data_fracs: 302 | for do, suffix in [ 303 | (task_generalizability_args.do_frozen_representation, "FTD"), 304 | (task_generalizability_args.do_free_representation, "FTF"), 305 | ]: 306 | if not do: continue 307 | 308 | fine_tune_dir_name = task_setting_str 309 | if frac != 1: fine_tune_dir_name += f"_{str(frac).replace('.', '-')}" 310 | 311 | fine_tune_dir = os.path.join(task_dir, fine_tune_dir_name, suffix) 312 | if not os.path.isdir(fine_tune_dir): os.makedirs(fine_tune_dir) 313 | 314 | task_setting_fine_tune_eval_args = EvalArgs( 315 | run_dir = fine_tune_dir, 316 | rotation = rotation, 317 | do_save_all_reprs = False, 318 | do_eval_train = False, 319 | do_eval_tuning = True, 320 | do_eval_test = True, 321 | do_eicu = task_generalizability_args.do_eicu, 322 | num_dataloader_workers = 8, 323 | ) 324 | task_setting_fine_tune_eval_args.to_json_file(os.path.join(fine_tune_dir, EVAL_ARGS_FILENAME)) 325 | 326 | runners.append(Runner( 327 | gpus_per_model = config['gpus_per_model'], 328 | run_dir = task_dir, 329 | task_setting_fine_tune_dir = task_setting_fine_tune_dir, 330 | gpu_queue = gpus_available, 331 | do_train = task_generalizability_args.do_train, 332 | do_eval = task_generalizability_args.do_eval, 333 | do_fine_tune = task_generalizability_args.do_fine_tune, 334 | do_fine_tune_eval = task_generalizability_args.do_fine_tune_eval, 335 | slurm = task_generalizability_args.slurm, 336 | partition = task_generalizability_args.partition, 337 | slurm_args = task_generalizability_args.slurm_args, 338 | do_small_data=task_generalizability_args.do_small_data, 339 | do_imbalanced_sex_data = task_generalizability_args.do_imbalanced_sex_data, 340 | do_imbalanced_race_data = task_generalizability_args.do_imbalanced_race_data, 341 | do_frozen_representation = task_generalizability_args.do_frozen_representation, 342 | do_free_representation = task_generalizability_args.do_free_representation, 343 | do_single_task = task_generalizability_args.do_single_task, 344 | do_masked_imputation_PT = do_masked_imputation_PT, 345 | do_copy_masked_imputation_PT = not do_copy_masked_imputation_PT, 346 | )) 347 | 348 | 349 | processes = [mp.Process(target=r) for r in runners] 350 | for process in processes: process.start() 351 | 352 | results = [process.join() for process in processes] 353 | 354 | if task_generalizability_args.slurm: 355 | return 356 | 357 | with 
open(os.path.join(exp_dir, 'results.pkl'), mode='wb') as f: pickle.dump(results, f) 358 | -------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/validate_weights.py: -------------------------------------------------------------------------------- 1 | from ..constants import * 2 | import torch, os, pickle, io, numpy as np 3 | 4 | # some constants for this script 5 | 6 | class CPU_Unpickler(pickle.Unpickler): 7 | def find_class(self, module, name): 8 | if module == 'torch.storage' and name == '_load_from_bytes': 9 | return lambda b: torch.load(io.BytesIO(b), map_location='cpu') 10 | else: return super().find_class(module, name) 11 | 12 | def cpu_unpickle(fp): 13 | with open(fp, mode='rb') as f: return CPU_Unpickler(f).load() 14 | 15 | def get_epoch(task, epoch, base): 16 | return torch.load(os.path.join(base, task, 'model.epoch-%d' % epoch), map_location='cpu') 17 | 18 | def get_epoch_direct(d, epoch): 19 | return torch.load(os.path.join(d, 'model.epoch-%d' % epoch), map_location='cpu') 20 | 21 | def validate_multilabel_weights( 22 | difference_w, multilabel_indices, multilabel_order=TASK_BINARY_MULTILABEL_ORDER 23 | ): 24 | # detecting if bias 25 | if difference_w.shape == (len(TASK_BINARY_MULTILABEL_ORDER),): 26 | col_w = difference_w 27 | else: 28 | col_w = difference_w.max(axis=1) 29 | 30 | for i in range(len(col_w)): 31 | if i in multilabel_indices: 32 | assert col_w[i] != 0, f"{multilabel_order[i]} weight doesn't change but should!" 33 | else: 34 | assert col_w[i] == 0, f"{multilabel_order[i]} weight isn't 0. It should be 0" 35 | 36 | def validate_multilabel_weights_task_gen(difference_w, multilabel_indices): 37 | # detecting if bias 38 | if difference_w.shape == (len(TASK_BINARY_MULTILABEL_ORDER),): 39 | col_w = difference_w 40 | else: 41 | col_w = difference_w.max(axis=1) 42 | 43 | for i in range(len(col_w)): 44 | if i in multilabel_indices: 45 | assert col_w[i] == 0, "%s weight changes! It isn't expected to change" % TASK_BINARY_MULTILABEL_ORDER[i] 46 | else: 47 | assert col_w[i] != 0, "%s weight is 0. 
It shouldnt be 0" % TASK_BINARY_MULTILABEL_ORDER[i] 48 | 49 | 50 | 51 | def pickle_weights_validate(base, task, current_ablation, multilabel_indices): 52 | task_weights = cpu_unpickle(os.path.join(base, task, 'joint_weights.pkl')) 53 | for k in task_weights['task_weights']: 54 | if k in current_ablation or k == 'tasks_binary_multilabel': 55 | assert task_weights['task_weights'][k] == 1, "%s should be set to 1" % k 56 | else: 57 | assert task_weights['task_weights'][k] == 0, "%s should be set to 0" % k 58 | 59 | def pickle_weights_validate_task_gen(base, task, current_ablation, multilabel_indices): 60 | task_weights = cpu_unpickle(os.path.join(base, task, 'joint_weights.pkl')) 61 | for k in task_weights['task_weights']: 62 | if k not in current_ablation or k == 'tasks_binary_multilabel': 63 | assert task_weights['task_weights'][k] == 1, "%s should be set to 1" % k 64 | else: 65 | assert task_weights['task_weights'][k] == 0, "%s should be set to 0" % k 66 | 67 | # multilabel matrix ablation 68 | multilabel_task_weights = task_weights['task_class_weights']['tasks_binary_multilabel'] 69 | for i in range(len(TASK_BINARY_MULTILABEL_ORDER)): 70 | if i not in multilabel_indices: 71 | assert multilabel_task_weights[i] == 1, "%s should be ablated" % TASK_BINARY_MULTILABEL_ORDER[i] 72 | else: 73 | assert multilabel_task_weights[i] == 0, "%s should not be ablated" % TASK_BINARY_MULTILABEL_ORDER[i] 74 | 75 | def validate_singleton_weights_gen( 76 | base, ablate=None, encoder_should_change=True, args=None, do_assert=False, 77 | tasks_binary_multilabel_order=TASK_BINARY_MULTILABEL_ORDER, 78 | ): 79 | assert args is not None 80 | assert ablate == args.ablate 81 | 82 | uses_weight_decay = args.do_weight_decay 83 | do_masked_imputation = args.do_masked_imputation 84 | do_eicu = args.do_eicu 85 | 86 | if uses_weight_decay: 87 | # Can't account for weight_decay here. 88 | return False, None, (None, None, None, None) 89 | 90 | valid_weights_should_change = [] 91 | valid_weights_should_not_change = [] 92 | invalid_weights_should_change_but_do_not = [] 93 | invalid_weights_should_not_change_but_do = [] 94 | 95 | def _validate_multilabel_weights_fn(difference_w, static_multilabel_indices, overall_key): 96 | # detecting if bias 97 | if difference_w.shape == (len(tasks_binary_multilabel_order),): delta_w = difference_w 98 | else: delta_w = difference_w.max(axis=1) 99 | 100 | for i in range(len(delta_w)): 101 | key = tasks_binary_multilabel_order[i] 102 | if i in static_multilabel_indices: 103 | if delta_w[i] != 0: 104 | if do_assert: raise AssertionError(f"{key} weight shouldn't change and does!") 105 | else: invalid_weights_should_not_change_but_do.append((overall_key, key, delta_w[i])) 106 | else: valid_weights_should_not_change.append((overall_key, key, delta_w[i])) 107 | else: 108 | if delta_w[i] == 0: 109 | if do_assert: raise AssertionError(f"{key} weights should change and doesn't!") 110 | else: invalid_weights_should_change_but_do_not.append((overall_key, key)) 111 | else: valid_weights_should_change.append((overall_key, key)) 112 | 113 | def weights_changed(ep_1, ep_2, key=''): 114 | difference_w = ep_1.detach().cpu().numpy() - ep_2.detach().cpu().numpy() 115 | return np.abs(difference_w).max() > 0, (difference_w, np.abs(difference_w).max()) 116 | 117 | def assert_diff(ep_1, ep_2, should_change=True, key=''): 118 | changed, delta = weights_changed(ep_1, ep_2) 119 | 120 | if do_assert: 121 | if should_change: assert changed, f"{key} should change!" 
+ str(delta) 122 | else: assert not changed, f"{key} should not change!" + str(delta) 123 | else: 124 | if should_change and changed: valid_weights_should_change.append((key, delta)) 125 | elif should_change and not changed: invalid_weights_should_change_but_do_not.append(key) 126 | elif not should_change and changed: invalid_weights_should_not_change_but_do.append((key, delta)) 127 | elif not should_change and not changed: valid_weights_should_not_change.append(key) 128 | 129 | if type(ablate) is str: ablate = [ablate] 130 | assert type(ablate) is list or type(ablate) is tuple, f"Ablation {ablate} is the wrong type!" 131 | 132 | # At this point, everything is ablating 'next_timepoint' 133 | ablations = ['next_timepoint'] 134 | ablated_task_heads = [f"task_heads.next_timepoint.{e}" for e in ('weight', 'bias')] 135 | 136 | if not do_masked_imputation: 137 | ablations.append('masked_imputation') 138 | ablated_task_heads.extend([f"task_heads.masked_imputation.{e}" for e in ('weight', 'bias')]) 139 | 140 | if do_eicu: 141 | ablations.append('FTS') 142 | ablated_task_heads.extend([ 143 | 'FTS_decoder.decoder.C_proj.bias', 144 | 'FTS_decoder.decoder.C_proj.weight', 145 | 'FTS_decoder.decoder.H_proj.bias', 146 | 'FTS_decoder.decoder.H_proj.weight', 147 | 'FTS_decoder.decoder.LSTM.bias_hh_l0', 148 | 'FTS_decoder.decoder.LSTM.bias_ih_l0', 149 | 'FTS_decoder.decoder.LSTM.weight_hh_l0', 150 | 'FTS_decoder.decoder.LSTM.weight_ih_l0', 151 | 'FTS_decoder.decoder.X_proj.bias', 152 | 'FTS_decoder.decoder.X_proj.weight', 153 | 'FTS_decoder.decoder.treatment_embeddings.weight', 154 | 'FTS_decoder.predictor.classifier.bias', 155 | 'FTS_decoder.predictor.classifier.weight', 156 | 'treatment_embeddings.weight', 157 | ]) 158 | 159 | for task in ablate: 160 | if task in ABLATION_GROUPS: 161 | ablations.extend(ABLATION_GROUPS[task]) 162 | ablated_task_heads.extend(TASK_HEAD_MAPPING[task]) 163 | elif task in ('next_timepoint_was_measured', 'next_timepoint_info'): 164 | ablations.append('next_timepoint_info') 165 | ablated_task_heads.extend( 166 | [f"task_heads.next_timepoint_was_measured.{e}" for e in ('weight', 'bias')] 167 | ) 168 | elif task == 'next_timepoint': 169 | ablations.append('next_timepoint') 170 | ablated_task_heads.extend( 171 | [f"task_heads.next_timepoint.{e}" for e in ('weight', 'bias')] 172 | ) 173 | elif task == 'masked_imputation': 174 | ablated_task_heads.extend([f"task_heads.masked_imputation.{e}" for e in ('weight', 'bias')]) 175 | else: raise AssertionError( 176 | f"{task} invalid. Should be in {ABLATION_GROUPS.keys()} or next_timepoint_info..." 
177 | ) 178 | 179 | # set for fast lookup 180 | # col indices to change in binary_multilabel matrix 181 | static_multilabel_indices = set([ 182 | i for i, t in enumerate(tasks_binary_multilabel_order) if t in ablations 183 | ]) 184 | 185 | ep_1_w = get_epoch_direct(base, 1) 186 | ep_2_w = get_epoch_direct(base, 4) 187 | 188 | for projector in ('ts_projector', 'statics_projector'): 189 | assert ep_1_w[projector].keys() == ep_2_w[projector].keys() 190 | for k in ep_1_w[projector]: 191 | assert_diff( 192 | ep_1_w[projector][k], ep_2_w[projector][k], encoder_should_change, f"{projector}[{k}]" 193 | ) 194 | 195 | assert ep_1_w['model'].keys() == ep_2_w['model'].keys() 196 | for k in ep_1_w['model']: 197 | ep_1, ep_2 = ep_1_w['model'][k], ep_2_w['model'][k] 198 | if k.startswith('bert') or k.startswith('gru') or k.startswith('fc_stack'): 199 | assert_diff(ep_1, ep_2, encoder_should_change, k) 200 | elif k == 'task_losses.tasks_binary_multilabel.BCE_LL.pos_weight': 201 | assert_diff(ep_1, ep_2, False, k) 202 | elif k in ( 203 | 'task_heads.tasks_binary_multilabel.weight', 204 | 'task_heads.tasks_binary_multilabel.bias' 205 | ): 206 | _validate_multilabel_weights_fn( 207 | np.abs(ep_1.detach().cpu().numpy() - ep_2.detach().cpu().numpy()), static_multilabel_indices, 208 | k 209 | ) 210 | else: 211 | assert_diff(ep_1, ep_2, k not in ablated_task_heads, k) 212 | 213 | any_invalid = ( 214 | (len(invalid_weights_should_change_but_do_not) > 0) 215 | or (len(invalid_weights_should_not_change_but_do) > 0) 216 | ) 217 | return True, any_invalid, ( 218 | valid_weights_should_change, 219 | valid_weights_should_not_change, 220 | invalid_weights_should_change_but_do_not, 221 | invalid_weights_should_not_change_but_do, 222 | ) 223 | 224 | def validate_singleton_weights(base, task=None, encoder_should_change=True, args=None): 225 | # TODO(mmd): Update to support masked_imputation validation 226 | all_tasks_should_change = (task is None) 227 | 228 | uses_weight_decay = args is not None and args.do_weight_decay 229 | assert not uses_weight_decay, "Can't validate model that is using weight decay!" 230 | 231 | if uses_weight_decay: print("Accounting for weight decay.") 232 | 233 | def weights_changed(ep_1, ep_2, key=''): 234 | difference_w = ep_1.detach().cpu().numpy() - ep_2.detach().cpu().numpy() 235 | if uses_weight_decay: 236 | mag_delta = (ep_1.detach().cpu().numpy() ** 2).mean() - (ep_2.detach().cpu().numpy() ** 2).mean() 237 | return mag_delta <= 0, (difference_w, difference_w.min(), mag_delta) 238 | #return difference_w.min() <= 0, (difference_w, difference_w.min(), mag_delta) 239 | else: 240 | return np.abs(difference_w).max() > 0, (difference_w, np.abs(difference_w).max()) 241 | 242 | def assert_diff(ep_1, ep_2, should_change=True, key=''): 243 | changed, delta = weights_changed(ep_1, ep_2) 244 | if should_change: assert changed, f"{key} should change!" + str(delta) 245 | else: assert not changed, f"{key} should not change!" + str(delta) 246 | 247 | if task is not None: 248 | if task in ABLATION_GROUPS: 249 | current_ablation = ABLATION_GROUPS[task] 250 | current_task_heads = TASK_HEAD_MAPPING[task] 251 | elif task == 'next_timepoint_was_measured': 252 | current_ablation = 'next_timepoint_info' 253 | current_task_heads = {f"task_heads.next_timepoint_was_measured.{e}" for e in ('weight', 'bias')} 254 | elif task == 'masked_imputation': 255 | current_ablation = '' 256 | current_task_heads = {f"task_heads.masked_imputation.{e}" for e in ('weight', 'bias')} 257 | else: raise AssertionError(f"{task} invalid. 
Should be in {ABLATION_GROUPS} or next_timepoint_was...") 258 | 259 | # set for fast lookup 260 | # col indices to change in binary_multilabel matrix 261 | multilabel_indices = set([ 262 | i for i, t in enumerate(TASK_BINARY_MULTILABEL_ORDER) if t in current_ablation 263 | ]) 264 | 265 | ep_1_w = get_epoch_direct(base, 1) 266 | ep_2_w = get_epoch_direct(base, 4) 267 | 268 | for projector in ('ts_projector', 'statics_projector'): 269 | for k in ep_1_w[projector]: 270 | assert_diff( 271 | ep_1_w[projector][k], ep_2_w[projector][k], encoder_should_change, f"{projector}[{k}]" 272 | ) 273 | 274 | for k in ep_1_w['model']: 275 | ep_1, ep_2 = ep_1_w['model'][k], ep_2_w['model'][k] 276 | if k.startswith('bert') or k.startswith('gru') or k.startswith('fc_stack'): 277 | assert_diff(ep_1, ep_2, encoder_should_change, k) 278 | elif all_tasks_should_change: 279 | assert_diff(ep_1, ep_2, True, k) 280 | elif k in ( 281 | 'task_heads.tasks_binary_multilabel.weight', 282 | 'task_heads.tasks_binary_multilabel.bias' 283 | ): 284 | validate_multilabel_weights( 285 | np.abs(ep_1.detach().cpu().numpy() - ep_2.detach().cpu().numpy()), multilabel_indices 286 | ) 287 | else: 288 | if k in current_task_heads: print("validating should change") 289 | assert_diff(ep_1, ep_2, k in current_task_heads, k) 290 | 291 | def validate_weights(task, base, encoder_should_change=True): 292 | assert task in ABLATION_GROUPS, "task doesn't exist : not in ABLATION_GROUPS constant" 293 | current_ablation = ABLATION_GROUPS[task] 294 | current_task_heads = TASK_HEAD_MAPPING[task] 295 | # col indices to change in binary_multilabel matrix 296 | # set for fast lookup 297 | multilabel_indices = set([i for i, t in enumerate(TASK_BINARY_MULTILABEL_ORDER) if t in current_ablation]) 298 | 299 | ep_1_w = get_epoch(task, 1, base) 300 | ep_22_w = get_epoch(task, 16, base) 301 | 302 | for projector in ('ts_projector', 'statics_projector'): 303 | for k in ep_1_w[projector]: 304 | difference_w = np.abs( 305 | ep_1_w[projector][k].detach().cpu().numpy() 306 | - ep_22_w[projector][k].detach().cpu().numpy() 307 | ) 308 | if encoder_should_change: 309 | assert difference_w.max() != 0, f"{projector}[{k}] should change for {base}/{task}" 310 | else: 311 | assert difference_w.max() == 0, f"{projector}[{k}] shouldn't change for {base}/{task}" 312 | 313 | for k in ep_1_w['model']: 314 | difference_w = np.abs( 315 | ep_1_w['model'][k].detach().cpu().numpy() 316 | - ep_22_w['model'][k].detach().cpu().numpy() 317 | ) 318 | if ( 319 | k.startswith('bert') or k.startswith('gru') or k.startswith('fc_stack') 320 | ): 321 | if encoder_should_change: 322 | assert difference_w.max() != 0, f"{k}'s weights should change for {base}/{task}" 323 | else: 324 | assert difference_w.max() == 0, f"{k}'s weights shouldn't change for {base}/{task}" 325 | elif k in ( 326 | 'task_heads.tasks_binary_multilabel.weight', 327 | 'task_heads.tasks_binary_multilabel.bias' 328 | ): 329 | validate_multilabel_weights(difference_w, multilabel_indices) 330 | elif k in current_task_heads: 331 | assert difference_w.max() != 0, f"{k} difference is 0 for {base}/{task}" 332 | else: 333 | assert difference_w.max() == 0, f"{k} difference is not 0 for {base}/{task}" 334 | 335 | # pickle weights validation 336 | pickle_weights_validate(base, task, current_ablation, multilabel_indices) 337 | 338 | def validate_weights_task_gen(task, base): 339 | assert task in ABLATION_GROUPS, "task doesn't exist : not in ABLATION_GROUPS constant" 340 | current_ablation = ABLATION_GROUPS[task] 341 | current_task_heads = 
TASK_HEAD_MAPPING[task] 342 | # col indices to change in binary_multilabel matrix 343 | # set for fast lookup 344 | multilabel_indices = set([i for i, t in enumerate(TASK_BINARY_MULTILABEL_ORDER) if t in current_ablation]) 345 | 346 | ep_1_w = get_epoch(task, 1, base) 347 | ep_22_w = get_epoch(task, 16, base) 348 | 349 | for projector in ('ts_projector', 'statics_projector'): 350 | for k in ep_1_w[projector]: 351 | difference_w = np.abs( 352 | ep_1_w[projector][k].detach().cpu().numpy() 353 | - ep_22_w[projector][k].detach().cpu().numpy() 354 | ) 355 | assert difference_w.max() != 0, f"{projector}[{k}] should change for {base}/{task}" 356 | 357 | for k in ep_1_w['model']: 358 | difference_w = np.abs( 359 | ep_1_w['model'][k].detach().cpu().numpy() 360 | - ep_22_w['model'][k].detach().cpu().numpy() 361 | ) 362 | if k.startswith('bert') or k.startswith('gru') or k.startswith('fc_stack'): 363 | assert difference_w.max() != 0, f"{k}'s weights should change" 364 | elif k == 'task_heads.tasks_binary_multilabel.weight' or k == 'task_heads.tasks_binary_multilabel.bias': 365 | validate_multilabel_weights_task_gen(difference_w, multilabel_indices) 366 | elif k in current_task_heads or k in ALWAYS_EQ_KEYS: 367 | assert difference_w.max() == 0, '%s difference is not 0' % k 368 | else: 369 | assert difference_w.max() != 0, '%s difference is 0' % k 370 | 371 | # pickle weights validation 372 | pickle_weights_validate_task_gen(base, task, current_ablation, multilabel_indices) 373 | 374 | def validate_all(base, single_task=False, ft_encoder=False, tasks=None): 375 | if tasks is None: tasks = list(ABLATION_GROUPS.keys()) 376 | assert os.path.isdir(base) 377 | tasks = set(tasks).intersection(os.listdir(base)) 378 | print(f"validating for path: {base} and tasks {', '.join(tasks)}") 379 | for t in tasks: 380 | print(f"Validating {t}") 381 | if single_task: 382 | validate_weights(t, os.path.join(base, t), encoder_should_change=True) 383 | else: 384 | if not ft_encoder: validate_weights_task_gen(t, base) 385 | validate_weights(t, os.path.join(base, t), encoder_should_change = ft_encoder) 386 | -------------------------------------------------------------------------------- /latent_patient_trajectories/representation_learner/adapted_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | SelfAttentionEncoder.py 3 | """ 4 | 5 | import torch, torch.optim, torch.nn as nn, torch.nn.functional as F, torch.nn.init as init 6 | from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertModel, BertConfig 7 | from torch.autograd import Variable, set_detect_anomaly 8 | from torch.utils.data import ( 9 | DataLoader, Dataset, RandomSampler, SubsetRandomSampler, Subset, SequentialSampler 10 | ) 11 | from torch.utils.data.distributed import DistributedSampler 12 | 13 | from math import floor 14 | 15 | from ..utils import * 16 | from ..constants import * 17 | from ..data_utils import * 18 | from ..representation_learner.fts_decoder import * 19 | from ..BERT.model import * 20 | from ..BERT.constants import * 21 | 22 | from copy import deepcopy 23 | 24 | class TaskBinaryMultilabelLoss(nn.Module): 25 | def __init__(self, binary_multilabel_loss_weight=None): 26 | super().__init__() 27 | self.weights = binary_multilabel_loss_weight 28 | self.weights.requires_grad_(False) 29 | params = {'pos_weight': binary_multilabel_loss_weight, 'reduction': 'none'} 30 | self.BCE_LL = nn.BCEWithLogitsLoss(**params) 31 | 32 | def forward(self, logits, labels): 33 | new_weights = 
self.weights.unsqueeze(0).expand_as(logits) 34 | out = self.BCE_LL(logits, labels) 35 | out = out * new_weights 36 | return out 37 | 38 | def get_task_losses(task_class_weights): 39 | task_losses = {} 40 | for t in ('disch_24h', 'disch_48h'): 41 | # May have missingness. 42 | params = {'ignore_index': -1, 'reduction': 'none'} 43 | if t in task_class_weights: params['weight'] = task_class_weights[t] 44 | task_losses[t] = nn.CrossEntropyLoss(**params) 45 | for t in ('Final Acuity Outcome',): 46 | params = {'ignore_index': -1} 47 | if t in task_class_weights: params['weight'] = task_class_weights[t] 48 | task_losses[t] = nn.CrossEntropyLoss(**params) 49 | for t in ('tasks_binary_multilabel', ): 50 | # May have missingness. 51 | # See: 52 | # https://discuss.pytorch.org/t/what-is-the-difference-between-bcewithlogitsloss-and-multilabelsoftmarginloss/14944/13 53 | # params = {'reduction': 'none'} 54 | # if t in task_class_weights: params['pos_weight'] = task_class_weights[t] 55 | # task_losses[t] = nn.BCEWithLogitsLoss(**params) 56 | if t in task_class_weights: 57 | task_losses[t] = TaskBinaryMultilabelLoss(task_class_weights[t]) 58 | else: 59 | task_losses[t] = TaskBinaryMultilabelLoss() 60 | for t in ('next_timepoint_was_measured',): 61 | params = {} 62 | if t in task_class_weights: params['weight'] = task_class_weights[t] 63 | task_losses[t] = nn.MultiLabelSoftMarginLoss(**params) 64 | 65 | return nn.ModuleDict(task_losses) 66 | 67 | POOLING_METHODS = ('max', 'avg', 'last')#, 'attention') 68 | class GRUModel(nn.Module): 69 | """ TODO(this) 70 | """ 71 | def __init__( 72 | self, config, data_shape=[48, 128], use_cuda=False, n_gpu = 0, task_class_weights=None, 73 | task_weights = None, hidden_dim=512, num_layers=2, bidirectional=False, 74 | pooling_method = 'last', fc_layer_sizes = [], verbose=False, 75 | do_eicu = False, 76 | ): 77 | super().__init__() 78 | 79 | self.verbose=verbose 80 | 81 | # TODO: need to activation masked_imputation task somehow... 82 | assert pooling_method in POOLING_METHODS, "Don't know how to do %s pooling" % pooling_method 83 | 84 | if task_class_weights is None: task_class_weights = {} 85 | 86 | # initialise the model and the weights 87 | self.hidden_dim = hidden_dim 88 | self.num_layers = num_layers 89 | self.bidirectional = bidirectional 90 | self.gru = nn.GRU( 91 | input_size=data_shape[-1], hidden_size=hidden_dim, num_layers=num_layers, batch_first=True, 92 | dropout=config.hidden_dropout_prob, bidirectional=bidirectional, 93 | ) 94 | out_dim = hidden_dim * 2 if bidirectional else hidden_dim 95 | 96 | fc_stack = [] 97 | for fc_layer_size in fc_layer_sizes: 98 | fc_stack.append(nn.Linear(out_dim, fc_layer_size)) 99 | fc_stack.append(nn.ReLU()) 100 | out_dim = fc_layer_size 101 | 102 | fc_stack.append(nn.Linear(out_dim, config.hidden_size)) 103 | 104 | self.fc_stack = nn.Sequential(*fc_stack) 105 | 106 | if self.bidirectional: 107 | self.h_0 = torch.zeros(2*num_layers, 1, hidden_dim).float().to('cuda' if use_cuda else 'cpu') 108 | else: 109 | self.h_0 = torch.zeros(num_layers, 1, hidden_dim).float().to('cuda' if use_cuda else 'cpu') 110 | 111 | self.hidden_dim=hidden_dim 112 | 113 | self.pooling_method = pooling_method 114 | #elif pooling_method == 'attention': 115 | # self.attention 116 | 117 | # self.cls = ContinuousBertPreTrainingHeads(config) # modify this to get all of the necessary tasks 118 | self.use_cuda = use_cuda 119 | self.n_gpu = n_gpu 120 | 121 | # TODO: API! 
122 | self.task_class_weights = task_class_weights 123 | if do_eicu: 124 | self.task_dims = { 125 | 'disch_24h': 10, 126 | 'disch_48h': 10, 127 | 'Final Acuity Outcome': 12, 128 | 'tasks_binary_multilabel': 3, 129 | 'next_timepoint': 15, # TODO put in config 130 | 'next_timepoint_was_measured': 15, 131 | 'masked_imputation': 15*2, 132 | } 133 | else: 134 | self.task_dims = { 135 | 'disch_24h': 20, 136 | 'disch_48h': 20, 137 | 'Final Acuity Outcome': 20, 138 | 'tasks_binary_multilabel': 26, #7, # ICD10, los, readmission 139 | 'next_timepoint': 56, # TODO put in config 140 | 'next_timepoint_was_measured': 56, 141 | 'masked_imputation': 56*2, 142 | } 143 | 144 | self.masked_imputation_loss = nn.BCEWithLogitsLoss(reduction='none') 145 | 146 | # TODO(mmd): API? 147 | self.task_heads = nn.ModuleDict( 148 | {t: nn.Linear(config.hidden_size, d) for t, d in self.task_dims.items()} 149 | ) 150 | if task_weights is None: 151 | self.task_weights = {t: 1 for t in self.task_heads.keys()} 152 | self.task_weights['rolling_ftseq'] = 1 153 | self.task_weights['masked_imputation'] = 0 154 | # In the case that we're doing all tasks as normal, presume we're doing no masking to mimic 155 | # prior behavior. 156 | else: self.task_weights = task_weights 157 | 158 | if 'masked_imputation' not in self.task_weights: self.task_weights['masked_imputation'] = 0 159 | # Setting these here enables to(device) / .cuda() to naturally affect them. 160 | 161 | self.task_permits_missingness = {'disch_24h', 'disch_48h', 'tasks_binary_multilabel'} 162 | self.task_losses = get_task_losses(self.task_class_weights) 163 | self.next_timepoint_reconstruction_loss = nn.MSELoss(reduction='none') 164 | 165 | # We do Rolling FTS separately 166 | self.treatment_embeddings = nn.Embedding( 167 | num_embeddings = 9, # TODO(mmd): Actually set this... 168 | embedding_dim = 25 # TODO(mmd): Actually set this... Belongs in config... 169 | ) 170 | self.FTS_decoder = FutureTreatmentSequenceDecoder( 171 | decoder_module = LSTMDecoder( 172 | in_dim = config.hidden_size, 173 | treatment_embeddings = self.treatment_embeddings, 174 | ), 175 | predictor_module = SingleTaskPredictor( 176 | in_dim = 25, 177 | num_classes = 9, 178 | ), 179 | ) 180 | 181 | def freeze_representation(self): 182 | for p in self.gru.parameters(): p.requires_grad = False 183 | for p in self.fc_stack.parameters(): p.requires_grad = False 184 | 185 | def unfreeze_representation(self): 186 | for p in self.gru.parameters(): p.requires_grad = True 187 | for p in self.fc_stack.parameters(): p.requires_grad = True 188 | 189 | # forward should be called with a dictionary, via, e.g., model(**batch) 190 | def forward( 191 | self, 192 | dfs, # Should be a dict... 193 | h_0=None 194 | ): 195 | # # TODO(mmd): Embedding Features... 196 | # input_sequence = self.ts_continuous_projector(ts_continuous) 197 | # statics_continuous = self.statics_projector(statics) 198 | # statics_continuous = statics_continuous.unsqueeze(1).expand_as(input_sequence) 199 | 200 | # input_sequence += statics_continuous 201 | 202 | # TODO(mmd): Put type conversions in dataset. 
203 | for k in ( 204 | 'ts', 'statics', 'next_timepoint_was_measured', 'next_timepoint', 205 | 'tasks_binary_multilabel', 206 | 'ts_vals', 'ts_is_measured', 'ts_mask', 207 | ): 208 | if k in dfs: dfs[k] = dfs[k].float() 209 | for k in ('disch_24h', 'disch_48h', 'Final Acuity Outcome', 'rolling_ftseq'): 210 | if k in dfs: dfs[k] = dfs[k].squeeze().long() 211 | 212 | #input_sequence = torch.cat((ts, statics), dim=2) 213 | input_sequence = dfs['input_sequence'] 214 | batch_size, seq_len, feat_dim = list(input_sequence.shape) 215 | 216 | if h_0 is None: 217 | h_0 = self.h_0 218 | 219 | if batch_size != 1: 220 | h_0 = h_0.expand(-1, batch_size, -1).contiguous() 221 | 222 | out_unpooled, h = self.gru(input_sequence, h_0) # for gru 223 | 224 | if self.pooling_method == 'last': out = out_unpooled[:, -1, :] 225 | elif self.pooling_method == 'max': out = out_unpooled.max(dim=1)[0] 226 | elif self.pooling_method == 'avg': out = out_unpooled.mean(dim=1) 227 | 228 | out = out.contiguous().view(batch_size, -1) # num directions is 1 for forward-only rnn 229 | 230 | pooled_output = self.fc_stack(out) 231 | unpooled_output = self.fc_stack(out_unpooled) 232 | 233 | # sequence_output.shape is batch_size, max_seq_length, hidden_dim 234 | # pooled_output is batch_size, hidden_dim 235 | 236 | # insert all the prediction tasks here 237 | task_labels = { 238 | k: df for k, df in dfs.items() if \ 239 | k not in ('statics', 'ts', 'rolling_ftseq', 'ts_mask') and df is not None 240 | } 241 | tasks = list(set(task_labels.keys()).intersection(self.task_heads.keys())) 242 | assert 'rolling_ftseq' in dfs or tasks, "Must have some tasks!" 243 | assert 'masked_imputation' not in tasks, "This task must be handled separately." 244 | 245 | task_logits = {t: self.task_heads[t](pooled_output) for t in tasks} 246 | 247 | #for t in tasks: 248 | # if t not in self.task_losses: continue 249 | # print(t, 'labels', dfs[t].dtype, dfs[t].shape, 'logits', task_logits[t].dtype, task_logits[t].shape) 250 | 251 | task_losses = {} 252 | weights_sum = 0 253 | for task, loss_fn, logits, labels, weight in zip_dicts( 254 | self.task_losses, task_logits, dfs, self.task_weights 255 | ): 256 | weights_sum += weight # We do it like this so that we only track tasks that are actually used. 
257 | if task in self.task_permits_missingness: 258 | isnan = torch.isnan(labels) 259 | labels_smoothed = torch.where(isnan, torch.zeros_like(labels), labels) 260 | try: 261 | if task in ('disch_24h', 'disch_48h'): 262 | loss = torch.where( 263 | isnan, torch.zeros_like(labels, dtype=torch.float32), 264 | loss_fn(logits, labels_smoothed) 265 | ) 266 | elif task in ('tasks_binary_multilabel',): 267 | loss = (self.task_class_weights[task]!=0).float() * torch.where(isnan, torch.zeros_like(logits), loss_fn(logits, labels_smoothed)) 268 | else: 269 | raise 270 | except: 271 | print(dfs) 272 | print(task, isnan.shape, labels.shape, logits.shape, labels_smoothed.shape) 273 | raise 274 | loss = loss.mean() 275 | loss = weight * loss 276 | else: loss = weight * loss_fn(logits, labels) 277 | 278 | task_losses[task] = loss 279 | 280 | if 'rolling_ftseq' in dfs: 281 | weights_sum += self.task_weights['rolling_ftseq'] 282 | fts_labels = dfs['rolling_ftseq'] 283 | fts_logits, fts_loss = self.FTS_decoder(pooled_output, labels = fts_labels) 284 | task_logits['rolling_ftseq'] = fts_logits 285 | task_losses['rolling_ftseq'] = self.task_weights['rolling_ftseq'] * fts_loss 286 | tasks.append('rolling_ftseq') 287 | 288 | # We need to handle next timepoint separately to deal with masking. 289 | if 'next_timepoint' and 'next_timepoint_was_measured' in dfs: 290 | weights_sum += self.task_weights['next_timepoint'] 291 | recst_loss = self.next_timepoint_reconstruction_loss( 292 | dfs['next_timepoint'], task_logits['next_timepoint'] 293 | ) 294 | recst_loss *= dfs['next_timepoint_was_measured'] # Mask out those not obs. 295 | 296 | # Accounting for the completely unmeasured rows. 297 | # TODO(mmd): we'll need to do the same for some of the classification tasks... 298 | num_measured_per_patient = dfs['next_timepoint_was_measured'].sum(dim=1) 299 | num_measured_per_patient = torch.where( 300 | num_measured_per_patient == 0, torch.ones_like(num_measured_per_patient), 301 | num_measured_per_patient 302 | ) 303 | recst_loss = recst_loss.sum(dim=1) / num_measured_per_patient 304 | 305 | recst_loss = recst_loss.sum(dim=0) 306 | #task_losses['next_timepoint'] = self.task_weights['next_timepoint'] * recst_loss 307 | # Setting this to 0 here as we always want this ablated. This is a poor solution, but doing this 308 | # to avoid any issues. 309 | task_losses['next_timepoint'] = 0 * recst_loss 310 | 311 | # We need to handle next timepoint separately to deal with masking. 312 | if 'masked_imputation' in self.task_weights and self.task_weights['masked_imputation'] > 0: 313 | assert 'ts_vals' in dfs and 'ts_is_measured' in dfs and 'ts_mask' in dfs, \ 314 | 'Expected masked items in the dataset' 315 | 316 | weights_sum += self.task_weights['masked_imputation'] 317 | 318 | # unpooled_output is of shape: batch_size X seq_len X feat_dim 319 | imputation_scores = self.task_heads['masked_imputation'](unpooled_output) 320 | # imputation_scores is of shape: batch_size X seq_len X task_dims['masked_imputation'] 321 | # TODO(mmd): This should be associated to the task dimensionality. 322 | 323 | N = self.task_dims['masked_imputation'] // 2 324 | wbm_logits = imputation_scores[:, :, :N] 325 | imp_preds = imputation_scores[:, :, N:] 326 | 327 | # per-timepoint: 328 | # 1) cross entropy loss on wbm logits (multi-label binary) 329 | # 2) euclidean loss on imp_preds (continuous value) <-- TODO should this be also probabilistic? 
330 | # masked to only be applied _when the value is observed_ 331 | 332 | per_feature_wbm_loss = self.masked_imputation_loss(wbm_logits, dfs['ts_is_measured']) 333 | per_timepoint_wbm_loss = per_feature_wbm_loss.mean(dim=2) 334 | 335 | per_feature_imp_loss = (imp_preds - dfs['ts_vals'])**2 336 | 337 | num_measured_per_timepoint_real = dfs['ts_is_measured'].sum(dim=2) 338 | num_measured_per_timepoint_smoothed = torch.where( 339 | num_measured_per_timepoint_real == 0, torch.ones_like(num_measured_per_timepoint_real), 340 | num_measured_per_timepoint_real 341 | ) 342 | 343 | # Should possibly use some probabilistic loss. 344 | per_timepoint_imp_loss = ( 345 | (per_feature_imp_loss * dfs['ts_is_measured']).sum(dim=2) 346 | / 347 | num_measured_per_timepoint_smoothed 348 | ) 349 | # the loss is uniformly zero on timepoints where nothing was measured, and otherwise is the RMSE. 350 | # Use the MSE loss to avoid sqrt giving nans (unknown why sqrt is giving nans) 351 | # per_timepoint_imp_loss = torch.where( 352 | # num_measured_per_timepoint_real == 0, torch.zeros_like(per_timepoint_imp_loss), 353 | # torch.sqrt(per_timepoint_imp_loss) 354 | # ) 355 | 356 | # Next, sum these two losses to obtain the per-timepoint loss. 357 | per_timepoint_loss = per_timepoint_wbm_loss + per_timepoint_imp_loss 358 | # per_timepoint_loss is of shape batch_size X seq_len 359 | 360 | # Then, this entire loss per time-point is masked according to which time-points were masked. 361 | # dfs['mask_indicators'] is binary and of shape batch_size X seq_len 362 | # TODO(mmd): Make the `.squeeze()` not necessary. 363 | per_timepoint_loss = per_timepoint_loss * dfs['ts_mask'].squeeze() 364 | per_seq_loss = per_timepoint_loss.sum(dim=1) 365 | 366 | # Finally, we want to scale by the # of masks per sequence. We set this scaling factor to `1` when 367 | # it is actually 0 to avoid a later divide by zero error, noting that by 368 | # `per_timepoint_loss = per_timepoint_loss * dfs['mask_indicators']` 369 | # above the numerator in that case is guaranteed to be zero. 
370 |             num_masks_per_seq = dfs['ts_mask'].squeeze().sum(dim=1)
371 |             num_masks_per_seq = torch.where(
372 |                 num_masks_per_seq == 0, torch.ones_like(num_masks_per_seq),
373 |                 num_masks_per_seq
374 |             )
375 |             per_seq_loss = per_seq_loss / num_masks_per_seq
376 | 
377 |             per_batch_loss = per_seq_loss.sum(dim=0)
378 |             task_losses['masked_imputation'] = self.task_weights['masked_imputation'] * per_batch_loss
379 | 
380 |         try:
381 |             total_loss = None
382 |             for l in task_losses.values():
383 |                 total_loss = l if total_loss is None else (total_loss + l)
384 |             total_loss /= weights_sum
385 |         except Exception:
386 |             print(task_losses)
387 |             print(weights_sum)
388 |             raise
389 | 
390 |         # formerly returned: {t: (task_logits[t], dfs[t], task_losses[t]) for t in tasks}, total_loss.unsqueeze(0) if self.n_gpu > 1 else total_loss
391 | 
392 |         out_data = {t: (task_logits[t], dfs[t], task_losses[t]) for t in tasks}
393 |         if 'masked_imputation' in self.task_weights and self.task_weights['masked_imputation'] > 0:
394 |             try:
395 |                 out_data['masked_imputation'] = (
396 |                     (wbm_logits, imp_preds),
397 |                     (dfs['ts_is_measured'], dfs['ts_vals'], dfs['ts_mask']),
398 |                     (
399 |                         per_timepoint_wbm_loss, per_timepoint_imp_loss, task_losses['masked_imputation'],
400 |                         per_timepoint_loss, per_seq_loss, per_batch_loss
401 |                     )
402 |                 )
403 |             except Exception:
404 |                 print(type(out_data), out_data)
405 |                 for e in (
406 |                     wbm_logits, imp_preds, dfs['ts_is_measured'], dfs['ts_vals'], dfs['ts_mask'],
407 |                     per_timepoint_wbm_loss, per_timepoint_imp_loss, task_losses['masked_imputation']
408 |                 ):
409 |                     print(type(e))
410 |                     try: print(e.shape)
411 |                     except AttributeError: print("not a tensor", e)
412 |                 raise
413 | 
414 |         return (
415 |             None,
416 |             pooled_output,
417 |             out_data,
418 |             total_loss.unsqueeze(0) if self.n_gpu > 1 else total_loss
419 |         )
420 | 
--------------------------------------------------------------------------------
/latent_patient_trajectories/representation_learner/run_model.py:
--------------------------------------------------------------------------------
1 | """
2 | run_model.py
3 | """
4 | 
5 | import torch.optim
6 | from torch.autograd import set_detect_anomaly
7 | from torch.utils.data import DataLoader, RandomSampler, SubsetRandomSampler
8 | 
9 | import json, os, pickle
10 | from tqdm import tqdm
11 | 
12 | # TODO: check these imports.
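# NOTE (assumption based on usage, not an explicit import list): the star imports
# below are relied on to supply several names used later in this module, e.g.
# `time`, `np`, `depickle`, and the pandas `idx = pd.IndexSlice` used in `main`;
# if any of these are missing from the helpers, import them here directly.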
13 | from ..utils import * 14 | from ..constants import * 15 | from ..data_utils import * 16 | from ..representation_learner.fts_decoder import * 17 | from ..representation_learner.evaluator import * 18 | from ..representation_learner.meta_model import * 19 | from ..BERT.model import * 20 | from ..BERT.constants import * 21 | from .args import Args, EvalArgs 22 | 23 | def train_meta_model( 24 | meta_model, train_dataloader, args, reloaded=False, epoch=0, 25 | tuning_dataloader=None, train_embedding_after=-1, tqdm=tqdm, just_gen_data=False, 26 | ): 27 | all_train_perfs, all_dev_perfs = [], [] 28 | if just_gen_data: 29 | optimizer, scheduler = None, None 30 | else: 31 | optimizer = torch.optim.Adam( 32 | meta_model.parameters, 33 | lr=args.learning_rate, 34 | weight_decay=args.weight_decay if args.do_weight_decay else 0, 35 | ) 36 | scheduler = torch.optim.lr_scheduler.StepLR( 37 | optimizer, args.learning_rate_step, 38 | args.learning_rate_decay if args.do_learning_rate_decay else 1, 39 | ) 40 | 41 | early_stop_count=0 42 | prev_err=10e9 43 | 44 | epoch_rng = range(epoch+1 if reloaded else 0, args.epochs) 45 | if tqdm is not None: epoch_rng = tqdm(epoch_rng, desc='Epoch: N/A', leave=False) 46 | 47 | for epoch in epoch_rng: 48 | if not just_gen_data: 49 | scheduler.step() # This goes before any of the train/validation stuff b/c torch version is 1.0.1 50 | 51 | if train_embedding_after >= epoch: 52 | # to ensure it is unfrozen after reloading 53 | meta_model.unfreeze_representation() 54 | 55 | meta_model.train() 56 | optimizer.zero_grad() 57 | 58 | train_dataloader.dataset.set_epoch(epoch) 59 | 60 | dataloader_rng = train_dataloader 61 | if tqdm is not None: 62 | dataloader_rng = tqdm( 63 | dataloader_rng, desc='Batch: N/A', total=len(train_dataloader), leave=False) 64 | 65 | for i, batch in enumerate(dataloader_rng): 66 | if just_gen_data: 67 | total_loss = torch.tensor(0) 68 | continue 69 | if args.do_detect_anomaly: set_detect_anomaly(True) 70 | 71 | if batch['ts'].shape[0] == 1: 72 | print("Skipping singleton batch.") 73 | continue 74 | 75 | hidden_states, pooled_output, all_outputs, total_loss = meta_model.forward(batch) 76 | 77 | try: 78 | total_loss.backward() 79 | except: 80 | print(total_loss.shape, total_loss) 81 | raise 82 | if i % args.batches_per_gradient == 0: 83 | optimizer.step() 84 | optimizer.zero_grad() 85 | if args.do_detect_anomaly: set_detect_anomaly(False) 86 | 87 | if tqdm is not None: dataloader_rng.set_description('Batch: %.2e' % total_loss) 88 | 89 | if just_gen_data: continue 90 | 91 | if tqdm is None: print("Epoch %d: %.2f" % (epoch, total_loss.item())) 92 | elif (tqdm is not None) and (tuning_dataloader is not None): pass 93 | else: epoch_rng.set_description("Epoch %d: %.2f" % (epoch, total_loss.item())) 94 | 95 | if epoch % args.train_save_every == 0: 96 | if tuning_dataloader is None: 97 | meta_model.save(epoch) 98 | else: 99 | print("Doing early stop") 100 | # do eval to see if this is the best score 101 | tuning_dataloader.dataset.epoch=epoch 102 | tuning_dataloader.dataset.save_place=args.dataset_dir 103 | 104 | dataloader_rng = tuning_dataloader 105 | if tqdm is not None: 106 | dataloader_rng = tqdm( 107 | dataloader_rng, desc='Batch: N/A', total=len(tuning_dataloader), leave=False) 108 | meta_model.eval() 109 | tuning_losses=[] 110 | for i, batch in enumerate(dataloader_rng): 111 | hidden_states, pooled_output, all_outputs, total_loss = meta_model.forward(batch) 112 | tuning_losses.append(total_loss.cpu().data.numpy().ravel()) 113 | meta_model.train() 114 
| total_err = np.mean(np.concatenate(tuning_losses))
115 | 
116 |                 if total_err < prev_err:
117 |                     # this is the best model so far; checkpoint it
118 |                     meta_model.save(epoch)
119 |                     # best_meta_model=meta_model.copy().cpu()
120 |                     prev_err = total_err
121 |                     early_stop_count = 0
122 |                     if tqdm is None: print("Epoch %d: %.2f" % (epoch, total_loss.item()))
123 |                     else: epoch_rng.set_description("Epoch %d: %.2f" % (epoch, total_loss.item()))
124 |                 else:
125 |                     early_stop_count += 1
126 |                     if early_stop_count == 10:
127 |                         print(f"Early stopping at epoch {epoch}; the best checkpoint (tuning loss {prev_err}) was already saved via meta_model.save().")
128 |                         break
129 | 
130 | 
131 |     if just_gen_data: return None
132 | 
133 |     if tuning_dataloader is None:
134 |         meta_model.save(epoch)
135 |         return meta_model
136 |     else: return meta_model # best weights were checkpointed via meta_model.save() above; the original `return best_meta_model` referenced a never-assigned name.
137 | 
138 | def run_model(
139 |     args, datasets, train_dataloader, tqdm=None, meta_model=None, tuning_dataloader=None, just_gen_data=False
140 | ):
141 |     if just_gen_data:
142 |         meta_model = None
143 |         reloaded, epoch = None, 0
144 |     else:
145 |         if meta_model is None: meta_model = MetaModel(
146 |             args, datasets['train'][0],
147 |             class_names = {'tasks_binary_multilabel': datasets['train'].get_binary_multilabel_keys()}
148 |         )
149 |         reloaded, epoch = meta_model.load()
150 |         if reloaded: print("Resuming from epoch %d" % (epoch+1))
151 | 
152 |     if args.do_train:
153 |         print('training')
154 |         train_meta_model(
155 |             meta_model, train_dataloader, args, reloaded, epoch, tuning_dataloader, tqdm=tqdm,
156 |             just_gen_data=just_gen_data
157 |         )
158 | 
159 |     #for n, do, dataloader in (
160 |     #    ('train', args.record_train_perf, train_dataloader), ('dev', args.val, dev_dataloader),
161 |     #    ('test', args.test, test_dataloader)
162 |     #):
163 |     #    if not do: continue
164 |     #    perf_dict = eval_loop(
165 |     #        model, device, dataloader, tqdm=tqdm, run_dir=args.run_dir, n_gpu=n_gpu, ablations=args.ablate,
166 |     #        integrate_note_bert=args.integrate_note_bert, note_embedding_model=note_embedding_model, notes_projector=notes_projector,
167 |     #        batch_size=args.batch_size, only_notes=args.only_notes, train_note_bert=args.train_note_bert,
168 |     #        using_notes_embeddings=using_notes_embeddings
169 |     #    )
170 |     #    with open(os.path.join(args.run_dir, '{}.eval'.format(n)), mode='wb') as f:
171 |     #        pickle.dump(perf_dict, f)
172 |     return meta_model
173 | 
174 | def load_datasets(
175 |     args, just_gen_data=False, use_stored_epochs=False, use_dataset_shells=True, make_train_dataloader=True
176 | ):
177 |     do_splits_dict = {}
178 |     if type(args) is Args:
179 |         if not hasattr(args, 'dataset_dir') or not args.dataset_dir:
180 |             rotations_dir = EICU_ROTATIONS_DIR if args.do_eicu else ROTATIONS_DIR
181 |             args.dataset_dir = os.path.join(rotations_dir, args.notes, str(args.rotation))
182 | 
183 |         do_splits_dict['train'] = args.do_train or args.do_eval_train
184 |         do_splits_dict['tuning'] = args.do_eval_tuning
185 |         do_splits_dict['test'] = args.do_eval_test
186 |         max_seq_len = args.max_seq_len
187 |         set_to_eval_mode = args.set_to_eval_mode
188 |     elif type(args) is EvalArgs:
189 |         training_args = Args.from_json_file(os.path.join(args.run_dir, ARGS_FILENAME))
190 |         for arg in ('notes', 'rotation', 'dataset_dir', 'imputation_mask_rate', 'do_masked_imputation'):
191 |             if hasattr(args, arg) and getattr(args, arg) not in (None, ''):
192 |                 assert hasattr(training_args, arg)
193 |                 assert getattr(training_args, arg) == getattr(args, arg), \
194 |                     f"Dataset parameters disagree ({arg})!"
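        # If the EvalArgs did not pin a dataset_dir explicitly, fall back to the
        # directory recorded in the training-time Args, or reconstruct it from the
        # training notes/rotation, as below.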
195 | if not hasattr(args, 'dataset_dir') or not args.dataset_dir: 196 | if hasattr(training_args, 'dataset_dir'): args.dataset_dir = training_args.dataset_dir 197 | else: 198 | rotations_dir = EICU_ROTATIONS_DIR if args.do_eicu else ROTATIONS_DIR 199 | args.dataset_dir = os.path.join( 200 | rotations_dir, training_args.notes, str(training_args.rotation) 201 | ) 202 | 203 | do_splits_dict['train'] = args.do_eval_train 204 | do_splits_dict['tuning'] = args.do_eval_tuning 205 | do_splits_dict['test'] = args.do_eval_test 206 | max_seq_len = training_args.max_seq_len 207 | set_to_eval_mode = EVAL_MODES[1] # "first_24" 208 | else: raise AssertionError(f"Args must be of a recognized type! Is {type(args)}.") 209 | 210 | datasets = {} 211 | for split, do in do_splits_dict.items(): 212 | if not do: 213 | datasets[split] = None 214 | continue 215 | 216 | load_start = time.time() 217 | dataset_shell_path = os.path.join(args.dataset_dir, f"{split}_dataset_shell.pkl") 218 | if use_dataset_shells and os.path.isfile(dataset_shell_path): load_path = dataset_shell_path 219 | else: load_path = os.path.join(args.dataset_dir, f"{split}_dataset.pkl") 220 | datasets[split] = depickle(load_path) 221 | 222 | if hasattr(datasets[split], 'skip_cache') and datasets[split].skip_cache: 223 | print(f"{load_path} has skip_cache true.") 224 | 225 | datasets[split].skip_cache = False 226 | if just_gen_data: datasets[split].save_data_only = True 227 | else: datasets[split].save_data_only = False 228 | print('loading %s data from disk took %.2f minutes' % (split, (time.time() - load_start)/60)) 229 | 230 | datasets[split].reload_self_dir = args.dataset_dir 231 | datasets[split].train_tune_test = split 232 | datasets[split].save_place = os.path.join( 233 | args.dataset_dir, "stored_epochs" if use_stored_epochs else "stored_items" 234 | ) 235 | if split == 'train': 236 | datasets[split].max_seq_len = max_seq_len 237 | 238 | if set_to_eval_mode: datasets[split].set_to_eval_mode(set_to_eval_mode) 239 | elif split != 'train': datasets[split].set_to_eval_mode(EVAL_MODES[1]) 240 | 241 | datasets[split].set_epoch(0) 242 | if not os.path.isdir(datasets[split].save_place): os.makedirs(datasets[split].save_place) 243 | if args.do_masked_imputation: 244 | assert args.imputation_mask_rate > 0, "Can't do imputation masking if we mask nothing!" 245 | datasets[split].imputation_mask_rate = args.imputation_mask_rate 246 | else: 247 | assert args.imputation_mask_rate == 0, "Can't mask if imputation masking is not enabled." 248 | assert datasets[split].imputation_mask_rate == 0, "Shouldn't mask!" 249 | 250 | return datasets 251 | 252 | def args_run_setup(args): 253 | # Make run_dir if it doesn't exist. 254 | if not os.path.exists(args.run_dir): os.mkdir(os.path.abspath(args.run_dir)) 255 | elif not args.do_overwrite: 256 | raise ValueError("Save dir %s exists and overwrite is not enabled!" 
% args.run_dir) 257 | 258 | if not args.dataset_dir: 259 | rotations_dir = EICU_ROTATIONS_DIR if args.do_eicu else ROTATIONS_DIR 260 | args.dataset_dir = os.path.join(rotations_dir, args.notes, str(args.rotation)) 261 | 262 | args.to_json_file(os.path.join(args.run_dir, ARGS_FILENAME)) 263 | return 264 | 265 | def setup_datasets_and_dataloaders(args, just_gen_data=False, use_stored_epochs=False): 266 | datasets = load_datasets( 267 | args, just_gen_data=just_gen_data, use_stored_epochs=use_stored_epochs 268 | ) 269 | 270 | if not args.do_train: return datasets 271 | 272 | sampler = SubsetRandomSampler(list(range(50))) if args.do_test_run else RandomSampler(datasets['train']) 273 | if just_gen_data: 274 | # In this case we override the typical collate_fn to avoid any errors with partially read and 275 | # partially constructed keys. 276 | train_dataloader = DataLoader( 277 | datasets['train'], sampler=sampler, batch_size=args.batch_size, 278 | num_workers=args.num_dataloader_workers, collate_fn=lambda xs: dict(), 279 | ) 280 | else: 281 | train_dataloader = DataLoader( 282 | datasets['train'], sampler=sampler, batch_size=args.batch_size, 283 | num_workers=args.num_dataloader_workers 284 | ) 285 | 286 | return datasets, train_dataloader 287 | 288 | def setup_for_run(args, just_gen_data=False, use_stored_epochs=False): 289 | args_run_setup(args) 290 | return setup_datasets_and_dataloaders( 291 | args, just_gen_data=just_gen_data, use_stored_epochs=use_stored_epochs 292 | ) 293 | 294 | def main(args, tqdm): 295 | datasets, train_dataloader = setup_for_run(args) 296 | 297 | # added to restrict the data in the dataset 298 | if hasattr(args, 'frac_data'): 299 | if args.frac_data != 1: 300 | import random 301 | # get index of train_dataset 302 | 303 | if args.frac_data_seed !=0: 304 | random.seed(args.frac_data_seed) 305 | 306 | orig_len=len(datasets['train']) 307 | subjects_hours = list(zip(datasets['train'].orig_subjects, datasets['train'].orig_max_hours)) 308 | assert len(set(subjects_hours))==len(subjects_hours) 309 | random.shuffle(subjects_hours) 310 | subjects_hours=subjects_hours[:int(args.frac_data *len(subjects_hours))] 311 | print(len(subjects_hours)) 312 | subjects, hours =zip(*subjects_hours) 313 | datasets['train'].orig_subjects = subjects 314 | datasets['train'].orig_max_hours = hours 315 | datasets['train'].reset_sequence_len(datasets['train'].sequence_len, reset_index=False) 316 | datasets['train'].reset_index() 317 | assert len(datasets['train']) < orig_len, f"Failed to assert that {len(datasets['train'])} < {orig_len}" 318 | # # reset the dataset to that length 319 | # datasets['train'].index = train_data_index[:int(fine_tune_args.frac_fine_tune_data *len(train_data_index))] 320 | 321 | 322 | sampler = RandomSampler(datasets['train']) 323 | 324 | train_dataloader = DataLoader( 325 | datasets['train'], sampler=sampler, batch_size=train_dataloader.batch_size, 326 | num_workers=train_dataloader.num_workers 327 | ) 328 | 329 | if args.frac_female != 1: 330 | import random 331 | # get index of train_dataset 332 | 333 | if args.frac_data_seed !=0: 334 | random.seed(args.frac_data_seed) 335 | 336 | # get female and male participants. 
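        # (Assumption from usage: `gender_1`/`gender_2` are one-hot-encoded columns
        # of the `statics` dataframe, and `idx` is the usual `pd.IndexSlice` pulled
        # in via the star imports at the top of this file.)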
337 | 
338 | 
339 |         subjects = datasets['train'].orig_subjects
340 |         # subjects_index = [datasets['train'].index[item] for item in subjects]
341 |         males = datasets['train'].dfs['statics'].loc[idx[subjects], 'gender_2'].values
342 |         assert len(set(subjects)) == len(males)
343 |         male_subjects = [item[0] for item in zip(subjects, males) if item[1] == 1]
344 |         #list(subjects[males==1])
345 |         print(male_subjects[:10])
346 |         # print(" ".join([str(datasets['train'].index[item]) for item in male_subjects[:10]]))
347 |         females = datasets['train'].dfs['statics'].loc[idx[subjects], 'gender_1'].values
348 |         assert len(set(subjects)) == len(females)
349 |         female_subjects = [item[0] for item in zip(subjects, females) if item[1] == 1] #list(subjects[females==1])
350 |         print('\n\n')
351 |         print(female_subjects[:10])
352 |         # print(" ".join([str(datasets['train'].index[item]) for item in female_subjects[:10]]))
353 |         print(len(male_subjects), len(female_subjects))
354 | 
355 |         # print('check that these are females: ', [datasets['train'].index[item] for item in female_subjects[:3]]) # these subject IDs were checked in the psql db and they are indeed female
356 | 
357 | 
358 |         random.shuffle(female_subjects)
359 |         num_females = min(int(args.frac_female * len(male_subjects)), len(female_subjects))
360 |         female_subjects = female_subjects[:num_females]
361 | 
362 |         print(f'There are {len(female_subjects)} female patients and {len(male_subjects)} male patients')
363 |         subjects = female_subjects + male_subjects
364 | 
365 | 
366 |         orig_len = len(datasets['train'])
367 |         subjects_hours = list(zip(datasets['train'].orig_subjects, datasets['train'].orig_max_hours))
368 |         subjects_hours = [item for item in subjects_hours if item[0] in subjects]
369 |         print(len(subjects_hours))
370 |         subjects, hours = zip(*subjects_hours)
371 |         datasets['train'].orig_subjects = subjects
372 |         datasets['train'].orig_max_hours = hours
373 |         datasets['train'].reset_sequence_len(datasets['train'].sequence_len, reset_index=False)
374 |         datasets['train'].reset_index()
375 |         assert len(datasets['train']) < orig_len, f"Failed to assert that {len(datasets['train'])} < {orig_len}"
376 |         # # reset the dataset to that length
377 |         # datasets['train'].index = train_data_index[:int(fine_tune_args.frac_fine_tune_data *len(train_data_index))]
378 | 
379 | 
380 |         sampler = RandomSampler(datasets['train'])
381 | 
382 |         # print('\n\n', vars(datasets['train'])['max_seq_len'])
383 | 
384 |         train_dataloader = DataLoader(
385 |             datasets['train'], sampler=sampler, batch_size=train_dataloader.batch_size,
386 |             num_workers=train_dataloader.num_workers
387 |         )
388 | 
389 | 
390 | 
391 |     if not hasattr(args, 'frac_black'):
392 |         pass
393 |     elif args.frac_black != 1 or args.balanced_race:
394 |         # TODO(mmd/bnestor): We will need to modify this for eICU runs.
395 |         import random
396 |         # get index of train_dataset
397 | 
398 |         if args.frac_data_seed != 0:
399 |             random.seed(args.frac_data_seed)
400 | 
401 |         # get white and black participants.
402 | 
403 | 
404 |         subjects = datasets['train'].orig_subjects
405 |         # subjects_index = [datasets['train'].index[item] for item in subjects]
406 |         white = datasets['train'].dfs['statics'].loc[idx[subjects], 'ethnicity_4'].values
407 |         assert len(set(subjects)) == len(white)
408 |         white_subjects = [item[0] for item in zip(subjects, white) if item[1] == 1]
409 |         #list(subjects[white==1])
410 |         print(white_subjects[:10])
411 |         # print(" ".join([str(datasets['train'].index[item]) for item in white_subjects[:10]]))
412 |         black = datasets['train'].dfs['statics'].loc[idx[subjects], 'ethnicity_2'].values
413 |         assert len(set(subjects)) == len(black)
414 |         black_subjects = [item[0] for item in zip(subjects, black) if item[1] == 1] #list(subjects[black==1])
415 |         print('\n\n')
416 |         print(black_subjects[:10])
417 |         # print(" ".join([str(datasets['train'].index[item]) for item in black_subjects[:10]]))
418 |         print(len(white_subjects), len(black_subjects))
419 |         if args.balanced_race:
420 |             random.shuffle(white_subjects)
421 |             white_subjects = white_subjects[:len(black_subjects)]
422 | 
423 |         # print('check that these are black patients: ', [datasets['train'].index[item] for item in black_subjects[:3]])
424 | 
425 | 
426 |         random.shuffle(black_subjects)
427 |         num_black = min(int(args.frac_black * len(white_subjects)), len(black_subjects))
428 |         black_subjects = black_subjects[:num_black]
429 | 
430 |         print(f'There are {len(black_subjects)} black patients and {len(white_subjects)} white patients')
431 |         subjects = black_subjects + white_subjects
432 | 
433 | 
434 |         orig_len = len(datasets['train'])
435 |         subjects_hours = list(zip(datasets['train'].orig_subjects, datasets['train'].orig_max_hours))
436 |         subjects_hours = [item for item in subjects_hours if item[0] in subjects]
437 |         print(len(subjects_hours))
438 |         subjects, hours = zip(*subjects_hours)
439 |         datasets['train'].orig_subjects = subjects
440 |         datasets['train'].orig_max_hours = hours
441 |         datasets['train'].reset_sequence_len(datasets['train'].sequence_len, reset_index=False)
442 |         datasets['train'].reset_index()
443 |         assert len(datasets['train']) < orig_len, f"Failed to assert that {len(datasets['train'])} < {orig_len}"
444 |         # # reset the dataset to that length
445 |         # datasets['train'].index = train_data_index[:int(fine_tune_args.frac_fine_tune_data *len(train_data_index))]
446 | 
447 | 
448 |         sampler = RandomSampler(datasets['train'])
449 | 
450 |         # print('\n\n', vars(datasets['train'])['max_seq_len'])
451 | 
452 |         train_dataloader = DataLoader(
453 |             datasets['train'], sampler=sampler, batch_size=train_dataloader.batch_size,
454 |             num_workers=train_dataloader.num_workers
455 |         )
456 | 
457 | 
458 |     tuning_dataloader = None
459 |     if args.epochs == -1:
460 |         # do early stopping, passing in the tuning dataset
461 |         tuning_dataloader = DataLoader(
462 |             datasets['tuning'], sampler=RandomSampler(datasets['tuning']), batch_size=train_dataloader.batch_size,
463 |             num_workers=train_dataloader.num_workers
464 |         )
465 |         args.epochs = 500
466 |     return run_model(args, datasets, train_dataloader, tqdm=tqdm, tuning_dataloader=tuning_dataloader)
467 | 
--------------------------------------------------------------------------------
/latent_patient_trajectories/BERT/continuous_pretraining_data_processor.py:
--------------------------------------------------------------------------------
1 | # TODO: Use Dask--integrate processor and converter (maybe tensor dataset too). Do as Mapreduce proper.
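# The masking scheme implemented in `_mask_seq` below mirrors BERT's masked-LM
# corruption, adapted to continuous feature rows: each maskable timestep is
# selected with probability `mask_select_prob` (default 0.15); of the selected
# steps, a `token_mask_prob` (0.8) share is replaced by the [MASK] control
# vector, a `token_replace_prob` (0.1) share by a random row drawn from the
# corpus, and the remainder is left unchanged.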
2 | import sys
3 | 
4 | import gc, random, numpy as np
5 | from multiprocessing import Pool
6 | from contextlib import closing
7 | 
8 | 
9 | from latent_patient_trajectories.utils import *
10 | from latent_patient_trajectories.data_utils import *
11 | from latent_patient_trajectories.constants import *
12 | 
13 | from latent_patient_trajectories.BERT.constants import *
14 | from latent_patient_trajectories.BERT.data_processor import *
15 | from latent_patient_trajectories.BERT.model import *
16 | 
17 | 
18 | def flatten(arr):
19 |     if type(arr) is np.ndarray: return np.reshape(arr, [len(arr), -1])
20 |     elif type(arr) is list:
21 |         r = []
22 |         for l in arr: r += l
23 |         return r
24 |     raise NotImplementedError
25 | 
26 | def I(x): return x
27 | 
28 | class ContinuousInputFeatures(object):
29 |     """A single set of features of data."""
30 | 
31 |     def __init__(
32 |         self, input_sequence_orig, input_sequence_masked, input_mask, segment_ids, el_was_masked, labels
33 |     ):
34 |         self.input_sequence_orig = input_sequence_orig
35 |         self.input_sequence_masked = input_sequence_masked
36 |         self.input_mask = input_mask
37 |         self.segment_ids = segment_ids
38 |         self.el_was_masked = el_was_masked
39 |         self.whole_sequence_labels = labels
40 | 
41 | def _mask_seq(
42 |     seq, random_token_seq, can_mask, mask_prob, random_token_mask_prob, random_token_replace_prob,
43 | ):
44 |     if seq is None: return None, None, None
45 | 
46 |     mask_ps = np.random.uniform(low=0, high=1, size=(len(seq), ))
47 |     mask_vector = CONTROL_VECTOR_PREFIXES[MASK](np.zeros((1, seq.shape[1])))
48 | 
49 |     was_masked = [0] * len(seq)
50 |     seq_orig = seq.copy()
51 | 
52 |     random_seq_i = 0 # we track this separately because we don't have enough random seqs given ctrl tokens.
53 |     for i, can_mask_i in enumerate(can_mask): # renamed loop var to avoid shadowing the `can_mask` argument
54 |         if not can_mask_i: continue
55 | 
56 |         random_seq_i += 1 # we increment here and subtract below so we don't need to increment in the if.
57 |         p = mask_ps[i]
58 |         if p >= mask_prob: continue
59 | 
60 |         was_masked[i] = 1
61 | 
62 |         p /= mask_prob
63 | 
64 |         if p < random_token_mask_prob: seq[i, :] = mask_vector
65 |         elif p < random_token_mask_prob + random_token_replace_prob:
66 |             seq[i, :] = random_token_seq[random_seq_i - 1, :]
67 | 
68 |     return seq_orig, seq, was_masked
69 | 
70 | def _truncate_seq_pair(seq_a, seq_b, max_length):
71 |     """Truncates a sequence pair to the maximum length. When `max_length` is odd, a coin flip decides
72 |     which sequence keeps the extra element, so neither sequence is systematically favored."""
73 | 
74 |     total_length = len(seq_a) + len(seq_b)
75 |     if total_length <= max_length: return seq_a, seq_b
76 | 
77 |     elif len(seq_b) <= max_length // 2: seq_a = seq_a[:max_length - len(seq_b)]
78 |     elif len(seq_a) <= max_length // 2: seq_b = seq_b[:max_length - len(seq_a)]
79 |     elif max_length % 2 == 0:
80 |         seq_a = seq_a[:max_length // 2]
81 |         seq_b = seq_b[:max_length // 2]
82 |     else:
83 |         if random.random() > 0.5:
84 |             seq_a = seq_a[:(max_length // 2) + 1]
85 |             seq_b = seq_b[:max_length // 2]
86 |         else:
87 |             seq_a = seq_a[:max_length // 2]
88 |             seq_b = seq_b[:(max_length // 2) + 1]
89 |     return seq_a, seq_b
90 | 
91 | def __convert_example(args):
92 |     # TODO(mmd): modify to work with just values.
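    # (`args` arrives as a single tuple rather than as separate parameters so that
    # this helper can be dispatched via `Pool.imap`/`Pool.map`, which pass exactly
    # one picklable argument per work item -- see `convert_examples_to_features`.)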
93 |     example, max_seq_length, random_seq, mask_select_prob, token_mask_prob, token_replace_prob = args
94 |     mask = lambda seq, random_seq, can_mask: _mask_seq(
95 |         seq, random_seq, can_mask, mask_select_prob, token_mask_prob, token_replace_prob
96 |     )
97 | 
98 |     seq_a = example.seq_a
99 |     seq_dim = seq_a.shape[1]
100 | 
101 |     seq_b = None
102 |     if example.seq_b is not None:
103 |         seq_b = example.seq_b
104 |         seq_len = len(seq_a) + len(seq_b)
105 | 
106 |         # Truncates `seq_a` and `seq_b` (returning new views) so that the total
107 |         # length is less than the specified length.
108 |         # Account for [CLS], [SEP], [SEP] with "- 3"
109 |         seq_a, seq_b = _truncate_seq_pair(seq_a, seq_b, max_seq_length - 3)
110 |     else:
111 |         seq_len = len(seq_a)
112 |         # Account for [CLS] and [SEP] with "- 2"
113 |         if len(seq_a) > max_seq_length - 2: seq_a = seq_a[:(max_seq_length - 2)]
114 | 
115 |     # TODO: Update comment.
116 |     # The convention in BERT is:
117 |     # (a) For sequence pairs:
118 |     #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
119 |     #  type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
120 |     # (b) For single sequences:
121 |     #  tokens:   [CLS] the dog is hairy . [SEP]
122 |     #  type_ids: 0     0   0   0  0     0 0
123 |     #
124 |     # Where "type_ids" are used to indicate whether this is the first
125 |     # sequence or the second sequence. The embedding vectors for `type=0` and
126 |     # `type=1` were learned during pre-training and are added to the wordpiece
127 |     # embedding vector (and position vector). This is not *strictly* necessary
128 |     # since the [SEP] token unambiguously separates the sequences, but it makes
129 |     # it easier for the model to learn the concept of sequences.
130 |     #
131 |     # For classification tasks, the first vector (corresponding to [CLS]) is
132 |     # used as the "sentence vector". Note that this only makes sense because
133 |     # the entire model is fine-tuned.
134 | 
135 |     # First, we pad the sequence vectors with zeros so we can introduce our auxiliary tokens. Then, we add
136 |     # the CLS and SEP tokens into the sequence.
137 | 
138 |     augmented_seq_dim = seq_dim + NUM_CONTROL_TOKENS
139 | 
140 |     can_mask = [False] + [True] * len(seq_a) + [False]
141 |     seq_a = np.concatenate((np.zeros((len(seq_a), NUM_CONTROL_TOKENS)), seq_a), axis=1)
142 |     cls_vector = CONTROL_VECTOR_PREFIXES[CLS](np.zeros((1, augmented_seq_dim)))
143 |     sep_vector = CONTROL_VECTOR_PREFIXES[SEP](np.zeros((1, augmented_seq_dim)))
144 | 
145 |     sequence = np.concatenate((cls_vector, seq_a, sep_vector), axis=0)
146 |     segment_ids = [0] * sequence.shape[0]
147 | 
148 |     if seq_b is not None:
149 |         # Adding the other [SEP] token.
150 |         can_mask += [True] * len(seq_b) + [False]
151 |         seq_b = np.concatenate((np.zeros((len(seq_b), NUM_CONTROL_TOKENS)), seq_b), axis=1)
152 |         sequence = np.concatenate((sequence, seq_b, sep_vector), axis=0)
153 |         segment_ids += [1] * (seq_b.shape[0] + 1)
154 | 
155 |     # Masking for masked reconstruction loss.
156 | 
157 |     seq_orig, seq_masked, el_was_masked = mask(sequence, random_seq, can_mask)
158 | 
159 |     # The mask has 1 for real tokens and 0 for padding tokens. Only real
160 |     # tokens are attended to.
161 |     input_mask = [1] * len(sequence)
162 | 
163 |     # Zero-pad up to the sequence length.
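    # (Shape bookkeeping for the padding below: after augmentation the sequence is
    # (len(seq_a) + 2) x augmented_seq_dim -- or + 3 with a `seq_b` -- and the zero
    # rows appended here bring the first axis up to exactly `max_seq_length`,
    # matching the asserts that follow.)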
164 | 
165 |     padding = [0] * (max_seq_length - len(sequence))
166 |     input_mask += padding
167 |     segment_ids += padding
168 |     el_was_masked += padding
169 | 
170 |     seq_orig = np.concatenate((seq_orig, np.zeros((len(padding), augmented_seq_dim))), axis=0)
171 |     seq_masked = np.concatenate((seq_masked, np.zeros((len(padding), augmented_seq_dim))), axis=0)
172 | 
173 |     # TODO(mmd): Continue from here.
174 | 
175 |     assert len(seq_orig) == max_seq_length
176 |     assert len(seq_masked) == max_seq_length
177 |     assert len(input_mask) == max_seq_length
178 |     assert len(segment_ids) == max_seq_length
179 |     assert len(el_was_masked) == max_seq_length
180 | 
181 |     return ContinuousInputFeatures(
182 |         input_sequence_orig=seq_orig,
183 |         input_sequence_masked=seq_masked,
184 |         input_mask=input_mask,
185 |         segment_ids=segment_ids,
186 |         el_was_masked=el_was_masked,
187 |         labels=example.whole_sequence_labels,
188 |     )
189 | 
190 | def convert_examples_to_features(
191 |     examples, max_seq_length, seed=1, mask_select_prob=0.15, token_mask_prob=0.8,
192 |     token_replace_prob=0.1, tqdm=None, multiprocessing_pool_size=1, shuffle=False
193 | ):
194 |     """Converts a list of `ContinuousInputExample`s into a list of `ContinuousInputFeatures`."""
195 |     random.seed(seed)
196 |     np.random.seed(seed)
197 | 
198 |     all_seqs = []
199 |     for example in examples:
200 |         all_seqs.append(example.seq_a)
201 |         if example.seq_b is not None: all_seqs.append(example.seq_b)
202 | 
203 |     all_seqs = np.concatenate(all_seqs, axis=0)
204 |     if shuffle: all_seqs = np.random.permutation(all_seqs)
205 |     all_seqs = np.concatenate((np.zeros((len(all_seqs), NUM_CONTROL_TOKENS)), all_seqs), axis=1)
206 | 
207 |     seen_so_far = 0
208 |     random_seqs = []
209 |     for example in examples:
210 |         total_seq_len = len(example.seq_a) + (0 if example.seq_b is None else len(example.seq_b))
211 |         random_seqs.append(all_seqs[seen_so_far : seen_so_far + total_seq_len])
212 |         seen_so_far += total_seq_len
213 | 
214 |     zipped = zip(examples, random_seqs)
215 |     inputs = (
216 |         (e, max_seq_length, rs, mask_select_prob, token_mask_prob, token_replace_prob) for e, rs in zipped
217 |     )
218 | 
219 |     if multiprocessing_pool_size > 1:
220 |         with Pool(multiprocessing_pool_size) as p:
221 |             if tqdm is not None: features = list(tqdm(p.imap(__convert_example, inputs), total=len(examples)))
222 |             else: features = p.map(__convert_example, inputs)
223 |     else:
224 |         if tqdm is not None: inputs = tqdm(inputs, total=len(examples))
225 |         features = [__convert_example(i) for i in inputs]
226 | 
227 |     return features
228 | 
229 | def to_set(x): return set(x) if type(x) in [list, tuple, set] else set([x])
230 | 
231 | class ContinuousInputExample(object):
232 |     """A single training/test example for simple sequence classification."""
233 | 
234 |     def __init__(self, guid, seq_a, seq_b, labels):
235 |         """Constructs a ContinuousInputExample.
236 | 
237 |         Args:
238 |             guid: Unique id for the example.
239 |             seq_a: np.ndarray. The first sequence.
240 |             seq_b: Optional np.ndarray. The second sequence.
241 |                 Only must be specified for sequence pair tasks.
242 | """ 243 | self.guid = guid 244 | self.seq_a = seq_a 245 | self.seq_b = seq_b 246 | self.whole_sequence_labels = labels 247 | 248 | def any_subsequent_indices_ordered(l): 249 | for i in range(len(l)-1): 250 | if l[i] == (l[i+1] - 1): return True 251 | return False 252 | 253 | class ContinuousBertGetEmbeddingsProcessor(): 254 | SEQUENCE_ID = 'sequence_id' 255 | STREAM_ID = 'sequence_stream_id' 256 | 257 | def __init__( 258 | self, sequence_id_idxs, 259 | tuning_folds=[], held_out_folds=[], 260 | seed=1, tqdm = None, multiprocessing_pool_size=1, 261 | chunksize = None, 262 | ): 263 | self.tuning_folds, self.held_out_folds = to_set(tuning_folds), to_set(held_out_folds) 264 | self.train_folds = set([f for f in range(K) if f not in self.tuning_folds.union(self.held_out_folds)]) 265 | 266 | self.seed = seed 267 | self.tqdm = tqdm 268 | self.multiprocessing_pool_size = multiprocessing_pool_size 269 | self.chunksize = chunksize 270 | 271 | self.sequence_id_idxs = list(to_set(sequence_id_idxs)) 272 | if len(self.sequence_id_idxs) > 1: 273 | self.should_add_id = True 274 | self.sequence_id_idx = self.SEQUENCE_ID 275 | else: 276 | self.should_add_id = False 277 | self.sequence_id_idx = self.sequence_id_idxs[0] 278 | 279 | 280 | def get_train_examples(self, df, save_path=None): 281 | return self.get_examples(df, self.train_folds, save_path) 282 | def get_dev_examples(self, df, save_path=None): 283 | return self.get_examples(df, self.tuning_folds, save_path) 284 | def get_test_examples(self, df, save_path=None): 285 | return self.get_examples(df, self.held_out_folds, save_path) 286 | 287 | def get_examples(self, df, folds=None, save_path=None): 288 | """Creates examples for the training and dev sets.""" 289 | print(df.shape, folds) 290 | 291 | fold_idx = df.index.get_level_values(FOLD_IDX_LVL) 292 | if folds is not None: df = df[fold_idx.isin(to_set(folds))] 293 | 294 | if self.should_add_id: add_id_col(df, self.sequence_id_idxs, self.SEQUENCE_ID) 295 | 296 | random.seed(self.seed) 297 | np.random.seed(self.seed) 298 | 299 | sequence_idx = df.index.get_level_values(self.sequence_id_idx) 300 | sequences = sorted(list(set(sequence_idx))) # Make this determinisitic for crying out loud. 301 | 302 | if self.chunksize is None: return self.process_sequences(df, sequences, save_path=save_path) 303 | 304 | sequences_chunks = np.array_split(list(sequences), len(sequences)//self.chunksize) 305 | 306 | for chunk_num, sequences_chunk in enumerate(sequences_chunks): 307 | if save_path is not None: 308 | save_path_chunk = '%s.chunk_%d' % (save_path, chunk_num) 309 | if os.path.isfile(save_path_chunk): 310 | print("Already finished chunk %d at %s" % (chunk_num, save_path_chunk)) 311 | continue 312 | 313 | print("Processing Chunk %d/%d" % (chunk_num + 1, len(sequences_chunks))) 314 | examples = self.process_sequences( 315 | df[sequence_idx.isin(sequences_chunk)], sequences_chunk, save_path=save_path_chunk 316 | ) 317 | 318 | del examples 319 | gc.collect() 320 | 321 | def process_sequences(self, df, sequences, save_path=None): 322 | sequence_idx = df.index.get_level_values(self.sequence_id_idx) 323 | examples_gen = (ContinuousInputExample( 324 | guid=seq, seq_a=df[sequence_idx == seq].values, seq_b=None, labels={} 325 | ) for seq in sequences) 326 | 327 | N = len(sequences) 328 | assert N > 0, "Must process some sequencs!" 
329 | print("Processing %d sequences" % N) 330 | 331 | tqdm = lambda i: (i if self.tqdm is None else self.tqdm(i, total=N)) 332 | 333 | if self.multiprocessing_pool_size > 1: 334 | with closing(Pool(self.multiprocessing_pool_size)) as p: 335 | if self.tqdm is None: examples = p.map(I, examples_gen) 336 | else: examples = list(tqdm(p.imap(I, examples_gen))) 337 | else: examples = list(tqdm(examples_gen)) 338 | 339 | if save_path is not None: 340 | with open(save_path, mode='wb') as f: pickle.dump(examples, f) 341 | 342 | assert len(examples) == N, "this really ought not be necessary..." 343 | 344 | return examples 345 | 346 | class ContinuousBertPretrainingDataProcessor(): 347 | SEQUENCE_ID = 'sequence_id' 348 | STREAM_ID = 'sequence_stream_id' 349 | 350 | def __init__( 351 | self, sequence_id_idxs, sequence_stream_idxs, stream_order_idxs, 352 | tuning_folds=[], held_out_folds=[], 353 | guid_sep = '-', seed=1, 354 | tqdm = None, multiprocessing_pool_size=1, chunksize=None, 355 | ): 356 | self.tuning_folds, self.held_out_folds = to_set(tuning_folds), to_set(held_out_folds) 357 | self.train_folds = set([f for f in range(K) if f not in self.tuning_folds.union(self.held_out_folds)]) 358 | 359 | self.sequence_id_idxs = list(to_set(sequence_id_idxs)) 360 | self.sequence_stream_idxs = list(to_set(sequence_stream_idxs)) 361 | self.stream_order_idxs = stream_order_idxs 362 | self.guid_sep = guid_sep 363 | self.seed = seed 364 | self.whole_sequence_tasks = {SEQUENCES_ORDERED: len(LABEL_ENUMS[SEQUENCES_ORDERED])} 365 | self.tqdm = tqdm 366 | self.multiprocessing_pool_size = multiprocessing_pool_size 367 | self.chunksize = chunksize 368 | 369 | def get_train_examples(self, df, save_path=None): 370 | return self.get_examples(df, self.train_folds, save_path) 371 | def get_dev_examples(self, df, save_path=None): 372 | return self.get_examples(df, self.tuning_folds, save_path) 373 | def get_test_examples(self, df, save_path=None): 374 | return self.get_examples(df, self.held_out_folds, save_path) 375 | 376 | def process_stream(self, args): 377 | sequence_id_idx = self.SEQUENCE_ID if len(self.sequence_id_idxs) > 1 else self.sequence_id_idxs[0] 378 | 379 | # TODO(mmd): Make this not need whole df. 
380 | s_df, non_stream_df = args 381 | 382 | examples = [] 383 | seq_idx = s_df.index.get_level_values(sequence_id_idx) 384 | seqs = list(set(seq_idx)) 385 | if len(seqs) == 1: 386 | examples.append(ContinuousInputExample( 387 | guid=seqs[0], seq_a=s_df.values, seq_b=None, labels={}, 388 | )) 389 | return examples 390 | 391 | n_examples = len(seqs) - 1 392 | for i in range(n_examples): 393 | seq_a_id, seq_b_id = seqs[i], seqs[i+1] 394 | seq_a = s_df[seq_idx == seq_a_id].values 395 | seq_b = s_df[seq_idx == seq_b_id].values 396 | guid = '%s%s%s' % (seq_a_id, self.guid_sep, seq_b_id) 397 | 398 | examples.append(ContinuousInputExample( 399 | guid=guid, seq_a=seq_a, seq_b=seq_b, 400 | labels={SEQUENCES_ORDERED: SequenceOrderType.ORDERED_AND_FROM_SAME_STREAM}, 401 | )) 402 | 403 | # TODO: Set thresholds 404 | mismatched_subseqs = np.random.permutation(list(range(n_examples+1))) 405 | while any_subsequent_indices_ordered(mismatched_subseqs): 406 | mismatched_subseqs = np.random.permutation(list(range(n_examples+1))) 407 | 408 | for i in range(n_examples): 409 | seq_a_id, seq_b_id = seqs[mismatched_subseqs[i]], seqs[mismatched_subseqs[i+1]] 410 | seq_a = s_df[seq_idx == seq_a_id].values 411 | seq_b = s_df[seq_idx == seq_b_id].values 412 | guid = '%s%s%s' % (seq_a_id, self.guid_sep, seq_b_id) 413 | 414 | examples.append(ContinuousInputExample( 415 | guid=guid, seq_a=seq_a, seq_b=seq_b, 416 | labels={SEQUENCES_ORDERED: SequenceOrderType.NOT_ORDERED_BUT_FROM_SAME_STREAM}, 417 | )) 418 | 419 | # Do non-same-stream. 420 | non_stream_seqs_idx = non_stream_df.index.get_level_values(sequence_id_idx) 421 | non_stream_seqs = np.random.permutation(list(set(non_stream_seqs_idx)))[:n_examples+2] 422 | shuffled_seqs_idx = np.random.permutation(list(range(n_examples+1))) 423 | 424 | for i in range(min(len(non_stream_seqs), n_examples+1)): 425 | seq_a_id, seq_b_id = seqs[shuffled_seqs_idx[i]], non_stream_seqs[i] 426 | seq_a = s_df[seq_idx == seq_a_id].values 427 | seq_b = non_stream_df[non_stream_seqs_idx == seq_b_id].values 428 | guid = '%s%s%s' % (seq_a_id, self.guid_sep, seq_b_id) 429 | examples.append(ContinuousInputExample( 430 | guid=guid, seq_a=seq_a, seq_b=seq_b, 431 | labels={SEQUENCES_ORDERED: SequenceOrderType.NOT_FROM_SAME_STREAM}, 432 | )) 433 | 434 | return examples 435 | 436 | def get_examples(self, df, folds=None, save_path=None): 437 | """Creates examples for the training and dev sets.""" 438 | fold_idx = df.index.get_level_values(FOLD_IDX_LVL) 439 | if folds is not None: df = df[fold_idx.isin(to_set(folds))] 440 | 441 | if len(self.sequence_id_idxs) > 1: 442 | add_id_col(df, self.sequence_id_idxs, self.SEQUENCE_ID) 443 | sequence_id_idx = self.SEQUENCE_ID 444 | else: sequence_id_idx = self.sequence_id_idxs[0] 445 | if len(self.sequence_stream_idxs) > 1: 446 | add_id_col(df, self.sequence_stream_idxs, self.SEQUENCE_ID) 447 | sequence_stream_idx = self.STREAM_ID 448 | else: sequence_stream_idx = self.sequence_stream_idxs[0] 449 | 450 | df = df.sort_index(level=[sequence_stream_idx] + self.stream_order_idxs, axis=0) 451 | 452 | random.seed(self.seed) 453 | np.random.seed(self.seed) 454 | 455 | stream_idx = df.index.get_level_values(sequence_stream_idx) 456 | sequence_idx = df.index.get_level_values(sequence_id_idx) 457 | streams = np.random.permutation(list(set(stream_idx))) 458 | sequences = set(sequence_idx) 459 | 460 | # This processing step is _very_ expensive. Running it with no bells and whistles takes ~70 hours. 
461 |         # So, we have the option to parallelize it via a multiprocessing pool, which requires serializing and
462 |         # sending input data across the cpu pool. We can't do this if we're passing the whole dataframe
463 |         # around, so instead we build a generator below which does the slicing in the main thread before
464 |         # sending to the workers. At the end, the system looks like it will take on the order of 2 - 6.5 hours
465 |         # with a pool of size 256 over 80 cpu nodes.
466 | 
467 |         N = len(streams)
468 |         if self.chunksize is None: streams_chunks = [streams]
469 |         else: streams_chunks = np.array_split(streams, max(1, N//self.chunksize)) # max(1, ...) guards the fewer-than-chunksize case
470 | 
471 |         # TODO(mmd): Make all work...
472 |         all_examples = []
473 |         for chunk_num, streams_chunk in enumerate(streams_chunks):
474 |             if save_path is not None:
475 |                 save_path_chunk = '%s.chunk_%d' % (save_path, chunk_num)
476 |                 if os.path.isfile(save_path_chunk):
477 |                     print("Already finished chunk %d at %s" % (chunk_num, save_path_chunk))
478 |                     continue
479 | 
480 |             print("Processing Chunk %d/%d" % (chunk_num + 1, len(streams_chunks)))
481 | 
482 |             print(chunk_num, 'streams_chunk', sys.getsizeof(streams_chunk))
483 | 
484 |             # NOTE: pairing via chained generators over one shared generator (i.e.
485 |             # `zip(stream_dfs, nonstream_dfs)` where `nonstream_dfs` itself consumed
486 |             # `stream_dfs`) would interleave and skip streams, so each stream's df is
487 |             # paired with its non-stream df in a single pass instead:
488 |             def _paired_inputs():
489 |                 for stream in streams_chunk:
490 |                     s_df = df[stream_idx == stream]
491 |                     stream_seqs = set(s_df.index.get_level_values(sequence_id_idx))
492 |                     non_stream_df = df[sequence_idx.isin(
493 |                         np.random.choice(list(sequences - stream_seqs), len(stream_seqs) + 2)
494 |                     )]
495 |                     yield s_df, non_stream_df
496 | 
497 |             inputs = _paired_inputs()
498 | 
499 |             if self.multiprocessing_pool_size > 1:
500 |                 with closing(Pool(self.multiprocessing_pool_size)) as p:
501 |                     if self.tqdm is not None:
502 |                         examples = list(self.tqdm(
503 |                             p.imap(self.process_stream, inputs), total=len(streams_chunk)
504 |                         ))
505 |                     else: examples = p.map(self.process_stream, inputs)
506 |                 examples = flatten(examples)
507 |             else:
508 |                 examples = []
509 |                 for i in (inputs if self.tqdm is None else self.tqdm(inputs, total=len(streams_chunk))):
510 |                     examples.extend(self.process_stream(i))
511 | 
512 |             if save_path is not None:
513 |                 save_path_chunk = '%s.chunk_%d' % (save_path, chunk_num)
514 |                 print("Saving partial processor output to %s" % save_path_chunk)
515 |                 with open(save_path_chunk, mode='wb') as f: pickle.dump(examples, f)
516 |             else: all_examples.extend(examples)
517 | 
518 |             # (the per-stream dataframes are built lazily inside `_paired_inputs`,
519 |             # so freeing the generator and its results releases what remains)
520 |             del _paired_inputs
521 |             del inputs
522 |             del examples
523 |             gc.collect()
524 | 
525 |         return all_examples
526 | 
--------------------------------------------------------------------------------