├── continuous_variables
│   ├── data
│   │   ├── end_pos.pkl
│   │   ├── begin_pos.pkl
│   │   ├── idx_to_lab.pkl
│   │   ├── channel_to_id.pkl
│   │   ├── id_to_channel.pkl
│   │   ├── possible_values.pkl
│   │   ├── variable_ranges.pkl
│   │   └── is_categorical_channel.pkl
│   ├── discretized_convert.py
│   ├── config.py
│   ├── README.md
│   ├── generate.py
│   ├── train_model.py
│   ├── pickleToCSV.py
│   ├── discretize.py
│   ├── genDatasetContinuous.py
│   └── model.py
├── plot_visit_lengths.py
├── baselines
│   ├── synteg
│   │   ├── config.py
│   │   ├── export_condition.py
│   │   ├── condition_simulation.py
│   │   ├── generate_data.py
│   │   ├── synteg.py
│   │   └── dependency_learning.py
│   ├── eva
│   │   ├── config.py
│   │   ├── train_eva.py
│   │   ├── test_eva.py
│   │   └── eva.py
│   ├── gpt
│   │   ├── config.py
│   │   ├── train_gpt.py
│   │   ├── test_gpt.py
│   │   └── gpt.py
│   ├── lstm
│   │   ├── config.py
│   │   ├── lstm.py
│   │   ├── train_lstm.py
│   │   └── test_lstm.py
│   ├── haloCoarse
│   │   ├── config.py
│   │   ├── train_gpt.py
│   │   └── haloCoarse.py
│   ├── evaluate_privacy_nearest.py
│   ├── evaluate_privacy_attribute.py
│   └── evaluate_privacy_membership.py
├── plot_label_probs.py
├── config.py
├── README.md
├── evaluate_privacy_nearest.py
├── evaluate_privacy_attribute.py
├── train_model.py
├── evaluate_privacy_membership.py
├── hcup_ccs_2015_definitions_benchmark.yaml
├── test_model.py
└── model.py
/continuous_variables/data/end_pos.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btheodorou99/HALO_Inpatient/HEAD/continuous_variables/data/end_pos.pkl
--------------------------------------------------------------------------------
/continuous_variables/data/begin_pos.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btheodorou99/HALO_Inpatient/HEAD/continuous_variables/data/begin_pos.pkl
--------------------------------------------------------------------------------
/continuous_variables/data/idx_to_lab.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btheodorou99/HALO_Inpatient/HEAD/continuous_variables/data/idx_to_lab.pkl
--------------------------------------------------------------------------------
/continuous_variables/data/channel_to_id.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btheodorou99/HALO_Inpatient/HEAD/continuous_variables/data/channel_to_id.pkl
--------------------------------------------------------------------------------
/continuous_variables/data/id_to_channel.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btheodorou99/HALO_Inpatient/HEAD/continuous_variables/data/id_to_channel.pkl
--------------------------------------------------------------------------------
/continuous_variables/data/possible_values.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btheodorou99/HALO_Inpatient/HEAD/continuous_variables/data/possible_values.pkl
--------------------------------------------------------------------------------
/continuous_variables/data/variable_ranges.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/btheodorou99/HALO_Inpatient/HEAD/continuous_variables/data/variable_ranges.pkl
--------------------------------------------------------------------------------
/continuous_variables/data/is_categorical_channel.pkl:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/btheodorou99/HALO_Inpatient/HEAD/continuous_variables/data/is_categorical_channel.pkl -------------------------------------------------------------------------------- /plot_visit_lengths.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import pandas as pd 4 | from tqdm import tqdm 5 | import matplotlib.pyplot as plt 6 | 7 | stats = pickle.load(open('results/shape.pkl', 'rb')) 8 | 9 | l = ['EVA', 'SynTEG', 'LSTM', 'HALO Coarse', 'GPT', 'HALO', 'Train'] 10 | stats['SynTEG']['Visit Lengths'] = np.random.choice(stats['SynTEG']['Visit Lengths'], len(stats['Train']['Visit Lengths']), replace=False) 11 | B = pd.DataFrame({key: pd.Series(stats[key]['Visit Lengths']) for key in l}) 12 | B.plot.kde() 13 | plt.xlim(-5,45) 14 | plt.xlabel('Number of Codes') 15 | plt.title('Inpatient EHR Visit Lengths Probability Density') 16 | plt.savefig("results/plots/visit_lengths.png") -------------------------------------------------------------------------------- /baselines/synteg/config.py: -------------------------------------------------------------------------------- 1 | class SyntegConfig(object): 2 | def __init__(self): 3 | self.embedding_dim = 112 4 | self.word_embedding_dim = 80 5 | self.attention_size = 128 6 | self.ff_dim = 128 7 | self.max_num_visit = 48 8 | self.max_length_visit = 80 9 | self.num_head = 4 10 | self.code_vocab_dim = 6841 11 | self.label_vocab_dim = 25 12 | self.vocab_dim = self.code_vocab_dim + self.label_vocab_dim + 2 # Plus the start and end tokens 13 | self.head_dim = 32 14 | self.lstm_dim = 512 15 | self.n_layer = 3 16 | self.condition_dim = 256 17 | self.dependency_batchsize = 40 18 | self.args = [-1, self.max_num_visit, self.max_length_visit, self.num_head, self.head_dim] 19 | self.z_dim = 128 20 | self.g_dims = [256, 256, 512, 512, 512, 512, self.vocab_dim] 21 | self.d_dims = [256, 256, 256, 128, 128, 128] 22 | self.gan_batchsize = 2500 23 | self.gp_weight = 10 -------------------------------------------------------------------------------- /plot_label_probs.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from sklearn.metrics import r2_score 5 | 6 | stats = pickle.load(open('results/shape.pkl', 'rb')) 7 | 8 | # Some fake data to plot 9 | l1 = 'Train' 10 | list2 = ['SynTEG', 'EVA', 'LSTM', 'GPT', 'HALO Coarse', 'HALO'] 11 | colors = ['brown', 'purple', 'orange', 'green', 'red', 'blue'] 12 | 13 | for label, col in zip(list2, colors): 14 | X = np.expand_dims(np.array(stats[l1]['Labels']), 1) 15 | y = np.expand_dims(np.array(stats[l2]['Labels']), 1) 16 | theta = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y) 17 | y_line = X.dot(theta) 18 | r2 = r2_score(y, X) 19 | plt.scatter(X, y, c=col, label=f"{label} ({r2:.3f})", marker='x' if label == 'HALO' else 'o') 20 | 21 | plt.xlim(0, 0.25) 22 | plt.xlabel('MIMIC Label Probability') 23 | plt.ylim(0, 0.25) 24 | plt.ylabel('Synthetic Dataset Label Probability') 25 | plt.title('Synthetic Dataset vs. 
MIMIC Chronic Condition Label Probabilities') 26 | plt.plot([0,0.25], [0,0.25], 'k-', zorder=0) 27 | plt.legend() 28 | plt.savefig('results/plots/label_probs.png') -------------------------------------------------------------------------------- /baselines/eva/config.py: -------------------------------------------------------------------------------- 1 | class EVAConfig(object): 2 | def __init__( 3 | self, 4 | total_vocab_size=6869, 5 | code_vocab_size=6841, 6 | label_vocab_size=25, 7 | special_vocab_size=3, 8 | n_ctx=57, 9 | n_embd=768, 10 | latent_dim=32, 11 | n_lstm_layer=1, 12 | n_conv1d_layer=3, 13 | n_deconv_layer=4, 14 | dilation_factor=2, 15 | deconv_factor=3, 16 | batch_size=128, 17 | prob_batch_size=4, 18 | epoch=50, 19 | lr=1e-4, 20 | pos_loss_weight=None 21 | ): 22 | self.total_vocab_size = total_vocab_size 23 | self.code_vocab_size = code_vocab_size 24 | self.label_vocab_size = label_vocab_size 25 | self.special_vocab_size = special_vocab_size 26 | self.n_ctx = n_ctx 27 | self.n_embd = n_embd 28 | self.latent_dim = latent_dim 29 | self.n_lstm_layer = n_lstm_layer 30 | self.n_conv1d_layer = n_conv1d_layer 31 | self.n_deconv_layer = n_deconv_layer 32 | self.dilation_factor = dilation_factor 33 | self.deconv_factor = deconv_factor 34 | self.batch_size = batch_size 35 | self.prob_batch_size = prob_batch_size 36 | self.epoch = epoch 37 | self.lr = lr 38 | self.pos_loss_weight = pos_loss_weight -------------------------------------------------------------------------------- /baselines/gpt/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Brandon Theodorou 3 | Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 4 | Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT 5 | GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch 6 | ''' 7 | class GPTConfig(object): 8 | def __init__( 9 | self, 10 | total_vocab_size=6871, 11 | code_vocab_size=6841, 12 | label_vocab_size=25, 13 | special_vocab_size=5, # start, start visits, end visit, end record, pad 14 | n_positions=750, 15 | n_ctx=700, 16 | n_embd=384, 17 | n_layer=3, 18 | n_head=4, 19 | layer_norm_epsilon=1e-5, 20 | initializer_range=0.02, 21 | batch_size=48, 22 | epoch=50, 23 | lr=1e-4, 24 | ): 25 | self.total_vocab_size = total_vocab_size 26 | self.code_vocab_size = code_vocab_size 27 | self.label_vocab_size = label_vocab_size 28 | self.special_vocab_size = special_vocab_size 29 | self.n_positions = n_positions 30 | self.n_ctx = n_ctx 31 | self.n_embd = n_embd 32 | self.n_layer = n_layer 33 | self.n_head = n_head 34 | self.layer_norm_epsilon = layer_norm_epsilon 35 | self.initializer_range = initializer_range 36 | self.batch_size = batch_size 37 | self.epoch = epoch 38 | self.lr = lr -------------------------------------------------------------------------------- /baselines/lstm/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Brandon Theodorou 3 | Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 4 | Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT 5 | GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch 6 | ''' 7 | class LSTMConfig(object): 8 | def __init__( 9 | self, 10 | total_vocab_size=6869, 11 | code_vocab_size=6841, 12 | label_vocab_size=25, 13 | special_vocab_size=3, 14 | n_positions=56, 15 | n_ctx=48, 16 | n_embd=768, 17 | n_layer=12, 18 | 
n_head=12, 19 | layer_norm_epsilon=1e-5, 20 | initializer_range=0.02, 21 | batch_size=128, 22 | epoch=25, 23 | pos_loss_weight=None, 24 | lr=1e-4, 25 | ): 26 | self.total_vocab_size = total_vocab_size 27 | self.code_vocab_size = code_vocab_size 28 | self.label_vocab_size = label_vocab_size 29 | self.special_vocab_size = special_vocab_size 30 | self.n_positions = n_positions 31 | self.n_ctx = n_ctx 32 | self.n_embd = n_embd 33 | self.n_layer = n_layer 34 | self.n_head = n_head 35 | self.layer_norm_epsilon = layer_norm_epsilon 36 | self.initializer_range = initializer_range 37 | self.batch_size = batch_size 38 | self.epoch = epoch 39 | self.pos_loss_weight = pos_loss_weight 40 | self.lr = lr -------------------------------------------------------------------------------- /baselines/haloCoarse/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Brandon Theodorou 3 | Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 4 | Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT 5 | GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch 6 | ''' 7 | class HALOCoarseConfig(object): 8 | def __init__( 9 | self, 10 | total_vocab_size=6869, 11 | code_vocab_size=6841, 12 | label_vocab_size=25, 13 | special_vocab_size=3, 14 | n_positions=56, 15 | n_ctx=48, 16 | n_embd=768, 17 | n_layer=12, 18 | n_head=12, 19 | layer_norm_epsilon=1e-5, 20 | initializer_range=0.02, 21 | batch_size=128, 22 | epoch=25, 23 | pos_loss_weight=None, 24 | lr=1e-4, 25 | ): 26 | self.total_vocab_size = total_vocab_size 27 | self.code_vocab_size = code_vocab_size 28 | self.label_vocab_size = label_vocab_size 29 | self.special_vocab_size = special_vocab_size 30 | self.n_positions = n_positions 31 | self.n_ctx = n_ctx 32 | self.n_embd = n_embd 33 | self.n_layer = n_layer 34 | self.n_head = n_head 35 | self.layer_norm_epsilon = layer_norm_epsilon 36 | self.initializer_range = initializer_range 37 | self.batch_size = batch_size 38 | self.epoch = epoch 39 | self.pos_loss_weight = pos_loss_weight 40 | self.lr = lr -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Brandon Theodorou 3 | Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 4 | Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT 5 | GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch 6 | ''' 7 | class HALOConfig(object): 8 | def __init__( 9 | self, 10 | total_vocab_size=6869, 11 | code_vocab_size=6841, 12 | label_vocab_size=25, 13 | special_vocab_size=3, 14 | n_positions=56, 15 | n_ctx=48, 16 | n_embd=768, 17 | n_layer=12, 18 | n_head=12, 19 | layer_norm_epsilon=1e-5, 20 | initializer_range=0.02, 21 | batch_size=48, 22 | sample_batch_size=256, 23 | epoch=50, 24 | pos_loss_weight=None, 25 | lr=1e-4, 26 | ): 27 | self.total_vocab_size = total_vocab_size 28 | self.code_vocab_size = code_vocab_size 29 | self.label_vocab_size = label_vocab_size 30 | self.special_vocab_size = special_vocab_size 31 | self.n_positions = n_positions 32 | self.n_ctx = n_ctx 33 | self.n_embd = n_embd 34 | self.n_layer = n_layer 35 | self.n_head = n_head 36 | self.layer_norm_epsilon = layer_norm_epsilon 37 | self.initializer_range = initializer_range 38 | self.batch_size = batch_size 39 | self.sample_batch_size = sample_batch_size 40 | self.epoch = epoch 
41 | self.pos_loss_weight = pos_loss_weight 42 | self.lr = lr -------------------------------------------------------------------------------- /baselines/lstm/lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LSTMBaseline(nn.Module): 5 | def __init__(self, config): 6 | super(LSTMBaseline, self).__init__() 7 | self.embedding_matrix = nn.Linear(config.total_vocab_size, config.n_embd, bias=False) 8 | self.lstm = nn.LSTM(input_size=config.n_embd, 9 | hidden_size=config.n_embd, 10 | num_layers=6, 11 | batch_first=True) 12 | self.ehr_head = nn.Linear(config.n_embd, config.total_vocab_size) 13 | 14 | def forward(self, input_visits, ehr_labels=None, ehr_masks=None, pos_loss_weight=None): 15 | embeddings = self.embedding_matrix(input_visits) 16 | hidden_states, _ = self.lstm(embeddings) 17 | code_logits = self.ehr_head(hidden_states) 18 | sig = nn.Sigmoid() 19 | code_probs = sig(code_logits) 20 | if ehr_labels is not None: 21 | shift_probs = code_probs[..., :-1, :].contiguous() 22 | shift_labels = ehr_labels[..., 1:, :].contiguous() 23 | loss_weights = None 24 | if pos_loss_weight is not None: 25 | loss_weights = torch.ones(shift_probs.shape, device=code_probs.device) 26 | loss_weights = loss_weights + (pos_loss_weight-1) * shift_labels 27 | if ehr_masks is not None: 28 | shift_probs = shift_probs * ehr_masks 29 | shift_labels = shift_labels * ehr_masks 30 | if pos_loss_weight is not None: 31 | loss_weights = loss_weights * ehr_masks 32 | 33 | bce = nn.BCELoss(weight=loss_weights) 34 | loss = bce(shift_probs, shift_labels) 35 | return loss, shift_probs, shift_labels 36 | 37 | return code_probs -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HALO 2 | [![DOI](https://zenodo.org/badge/627704812.svg)](https://zenodo.org/badge/latestdoi/627704812) 3 | 4 | This is the source code for reproducing the inpatient dataset experiments found in the paper "Synthesizing Extremely High Dimensional Electronic Health Records." 5 | 6 | ## Generating the Dataset 7 | This code interfaces with the pubilc MIMIC-III ICU stay database. Before using the code, you will need to apply, complete training, and download the ADMISSIONS and DIAGNOSES_ICD tables from . From there, generate an empty directory `data/`, edit the `mimic_dir` variable in the file `build_dataset.py`, and run that file. It will generate all of the relevant data files. 8 | 9 | ## Training a Model 10 | Next, a model can be training by creating an empt `save/` directory and running the `train_model.py` script. 11 | 12 | ## Training Baseline Models 13 | Next, any desired baseline models may be trained by changing your working directory to `baselines/{baseline_model}` and running the corresponding `train_{baseline_model}.py` script 14 | 15 | ## Evaluating the Model(s) 16 | Finally, the trained model and its synthetic data may be evaluated. Before beginning, create the following directory paths: 17 | * `results/datasets` 18 | * `results/dataset_stats/plots` 19 | * `results/testing_stats` 20 | * `results/synthetic_training_stats` 21 | * `results/privacy_evaluations` 22 | 23 | After these directories are created, first run the `test_model.py` script (along with any corresponding `test_{baseline_model}.py` in the directories from the previous section). This will generate perplexity, prediction, and synthetic dataset results. 
From there, you may run any other evaluation scripts (prefixed with evaluate_), making sure any references to unrun baseline models are commented out. All corresponding results will be printed and saved to pickle files. 24 | -------------------------------------------------------------------------------- /continuous_variables/discretized_convert.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | 4 | idToLab = pickle.load(open('discretized_data/idToLab.pkl', 'rb')) 5 | isCategorical = pickle.load(open('discretized_data/isCategorical.pkl', 'rb')) 6 | discretization = pickle.load(open('discretized_data/discretization.pkl', 'rb')) 7 | possibleValues = pickle.load(open('discretized_data/possibleValues.pkl', 'rb')) 8 | discretization = pickle.load(open('discretized_data/discretization.pkl', 'rb')) 9 | formatMap = pickle.load(open('discretized_data/formatMap.pkl', 'rb')) 10 | idToLabel = pickle.load(open('discretized_data/idToLabel.pkl', 'rb')) 11 | indexToCode = pickle.load(open('discretized_data/indexToCode.pkl', 'rb')) 12 | 13 | dataset = pickle.load(open('results/datasets/haloDataset.pkl', 'rb')) 14 | 15 | def formatCont(value, key): 16 | return formatMap[key][1](("{:" + formatMap[key][0] + "}").format(value)) 17 | 18 | for p in dataset: 19 | new_visits = [] 20 | firstVisit = True 21 | for v in p['visits']: 22 | new_labs = [] 23 | new_values = [] 24 | for i in range(len(v[1])): 25 | new_labs.append(idToLab[v[1][i]]) 26 | if isCategorical[idToLab[v[1][i]]]: 27 | new_values.append(possibleValues[idToLab[v[1][i]]][v[2][i]]) 28 | else: 29 | new_values.append(formatCont(random.uniform(discretization[idToLab[v[1][i]]][v[2][i]], discretization[idToLab[v[1][i]]][v[2][i]+1]), idToLab[v[1][i]])) 30 | contType = 'Hours' if new_labs != [] else 'Age' if firstVisit else 'Days' 31 | if contType == 'Age': 32 | firstVisit = False 33 | new_cont = formatCont(random.uniform(discretization[contType][v[4][-1]], discretization[contType][v[4][-1]+1]), contType) 34 | new_visits.append((v[0], new_labs, new_values, [new_cont])) 35 | p['visits'] = new_visits 36 | 37 | pickle.dump(dataset, open('results/datasets/haloDataset_converted.pkl', 'wb')) 38 | -------------------------------------------------------------------------------- /continuous_variables/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Brandon Theodorou 3 | Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 4 | Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT 5 | GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch 6 | ''' 7 | class HALOConfig(object): 8 | def __init__( 9 | self, 10 | total_vocab_size=14487, 11 | code_vocab_size=14167, 12 | lab_vocab_size=237, 13 | continuous_vocab_size=15, 14 | label_vocab_size=65, 15 | special_vocab_size=3, 16 | 17 | categorical_lab_vocab_size=47, 18 | continuous_lab_vocab_size=190, 19 | 20 | phenotype_labels=25, 21 | n_positions=150, 22 | n_ctx=150, 23 | n_embd=1440, 24 | n_layer=12, 25 | n_head=18, 26 | layer_norm_epsilon=1e-5, 27 | initializer_range=0.02, 28 | 29 | batch_size=56, 30 | sample_batch_size=128, 31 | epoch=50, 32 | lr=1e-4, 33 | ): 34 | self.total_vocab_size = total_vocab_size 35 | self.code_vocab_size = code_vocab_size 36 | self.label_vocab_size = label_vocab_size 37 | self.lab_vocab_size = lab_vocab_size 38 | self.categorical_lab_vocab_size = categorical_lab_vocab_size 39 | 
self.continuous_lab_vocab_size = continuous_lab_vocab_size 40 | self.continuous_vocab_size = continuous_vocab_size 41 | self.special_vocab_size = special_vocab_size 42 | self.phenotype_labels = phenotype_labels 43 | self.n_positions = n_positions 44 | self.n_ctx = n_ctx 45 | self.n_embd = n_embd 46 | self.n_layer = n_layer 47 | self.n_head = n_head 48 | self.layer_norm_epsilon = layer_norm_epsilon 49 | self.initializer_range = initializer_range 50 | self.batch_size = batch_size 51 | self.sample_batch_size = sample_batch_size 52 | self.epoch = epoch 53 | self.lr = lr 54 | -------------------------------------------------------------------------------- /evaluate_privacy_nearest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from sklearn import metrics 7 | from config import HALOConfig 8 | 9 | SEED = 4 10 | random.seed(SEED) 11 | np.random.seed(SEED) 12 | torch.manual_seed(SEED) 13 | NUM_SAMPLES = 5000 14 | 15 | config = HALOConfig() 16 | train_ehr_dataset = pickle.load(open('./data/trainDataset.pkl', 'rb')) 17 | train_ehr_dataset = np.random.choice(train_ehr_dataset, NUM_SAMPLES) 18 | test_ehr_dataset = pickle.load(open('./data/testDataset.pkl', 'rb')) 19 | test_ehr_dataset = np.random.choice(test_ehr_dataset, NUM_SAMPLES) 20 | synthetic_ehr_dataset = pickle.load(open('./results/datasets/haloDataset.pkl', 'rb')) 21 | synthetic_ehr_dataset = np.random.choice([p for p in synthetic_ehr_dataset if len(p['visits']) > 0], NUM_SAMPLES) 22 | synthetic_ehr_dataset = [{'labels': p['labels'], 'visits': [set(v) for v in p['visits']]} for p in synthetic_ehr_dataset] 23 | 24 | def find_hamming(ehr, dataset): 25 | min_d = 1e10 26 | visits = ehr['visits'] 27 | labels = ehr['labels'] 28 | for p in dataset: 29 | d = 0 if len(visits) == len(p['visits']) else 1 30 | l = p['labels'] 31 | d += ((labels + l) == 1).sum() 32 | for i in range(len(visits)): 33 | v = visits[i] 34 | if i >= len(p['visits']): 35 | d += len(v) 36 | else: 37 | v2 = p['visits'][i] 38 | d += len(v) + len(v2) - (2 * len(v.intersection(v2))) 39 | 40 | min_d = d if d < min_d and d > 0 else min_d 41 | return min_d 42 | 43 | def calc_nnaar(train, evaluation, synthetic): 44 | val1 = 0 45 | val2 = 0 46 | val3 = 0 47 | val4 = 0 48 | for p in tqdm(evaluation): 49 | des = find_hamming(p, synthetic) 50 | dee = find_hamming(p, evaluation) 51 | if des > dee: 52 | val1 += 1 53 | 54 | for p in tqdm(train): 55 | dts = find_hamming(p, synthetic) 56 | dtt = find_hamming(p, train) 57 | if dts > dtt: 58 | val3 += 1 59 | 60 | for p in tqdm(synthetic): 61 | dse = find_hamming(p, evaluation) 62 | dst = find_hamming(p, train) 63 | dss = find_hamming(p, synthetic) 64 | if dse > dss: 65 | val2 += 1 66 | if dst > dss: 67 | val4 += 1 68 | 69 | val1 = val1 / NUM_SAMPLES 70 | val2 = val2 / NUM_SAMPLES 71 | val3 = val3 / NUM_SAMPLES 72 | val4 = val4 / NUM_SAMPLES 73 | 74 | aaes = (0.5 * val1) + (0.5 * val2) 75 | aaet = (0.5 * val3) + (0.5 * val4) 76 | return aaes - aaet 77 | 78 | nnaar = calc_nnaar(train_ehr_dataset, test_ehr_dataset, synthetic_ehr_dataset) 79 | results = { 80 | "NNAAE": nnaar 81 | } 82 | pickle.dump(results, open("results/privacy_evaluation/nnaar.pkl", "wb")) 83 | print(results) -------------------------------------------------------------------------------- /continuous_variables/README.md: -------------------------------------------------------------------------------- 1 | # Handling Continuous Variables 2 | 3 | 
This directory contains the source code for reproducing the continuous variable experiments found in the paper "Synthesizing Extremely High Dimensional Electronic Health Records." 4 | 5 | ## Generating the Dataset 6 | This code interfaces with the public MIMIC-III ICU stay database. Before using the code, you will need to apply for access, complete the required training, and download the requisite files from the MIMIC-III database. The required files are: 7 | * 'PATIENTS.csv' 8 | * 'ADMISSIONS.csv' 9 | * 'DIAGNOSES_ICD.csv' 10 | * 'PROCEDURES_ICD.csv' 11 | * 'PRESCRIPTIONS.csv' 12 | * 'CHARTEVENTS.csv' 13 | 14 | Next, you need to perform the preprocessing from the mimic3-benchmarks repository. That repository has comprehensive documentation, and it will create a series of .csv files containing lab timeseries information for each ICU stay. You only need to get through the `extract_episodes_from_subjects` step. 15 | 16 | From there, edit the `mimic_dir` and `timeseries_dir` variables in the file `genDatasetContinuous.py`, and run that file. It will generate all of the base data files for these experiments. 17 | 18 | Next, following the paper and the HALO method, the continuous variables (lab values and inter-visit gaps) must be discretized before they can be fed into the model. To do so, create a `discretized_data/` directory and run the file `discretize.py`. 19 | 20 | At this point, the discretized data and corresponding artifacts will be available, and your dataset will be fully processed. 21 | 22 | ## Setting the Config 23 | Depending on any dataset changes, you may need to adjust the `config.py` file according to the dataset you are using. Specifically, you may need to set `code_vocab_size` and `label_vocab_size` based on what is printed at the end of running the `genDatasetContinuous.py` file, and then set `lab_vocab_size` and `continuous_vocab_size` based on what is printed at the end of running the `discretize.py` file. 24 | 25 | ## Training a Model 26 | Next, a model can be trained by creating an empty `save/` directory and running the `train_model.py` script. 27 | 28 | ## Generating Data 29 | With this model, you are ready to generate comprehensive data including continuous variables. Ensure that the path `results/datasets` exists, and run the file `generate.py` followed by the file `discretized_convert.py` to convert the data back to a fully continuous format in the style of the original training data before it was discretized (a small sketch of this discretize-and-recover round trip appears at the end of this README). 30 | 31 | Note: if you want a different amount of data rather than the size of the training dataset, set the `totEHRs` variable on line 93 of `generate.py`. 32 | 33 | ## Evaluating the Model(s) 34 | Finally, the trained model and its synthetic data may be evaluated. Before beginning, make sure the path `results/dataset_stats/plots` exists. Then run the `evaluate.py` script. This will generate a series of plots showcasing a wide variety of both standard and continuous-valued results.
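The sketch below illustrates the discretize-and-recover round trip referenced above: `discretize.py` maps each continuous lab value (or inter-visit gap) to a bin index before modeling, and `discretized_convert.py` later recovers a concrete value by sampling uniformly within the chosen bin. The lab name and bin edges here are made up purely for illustration; the real ones are the artifacts written to `discretized_data/`.

```python
import random

# Hypothetical bin edges for a single continuous lab (real edges come from discretize.py)
discretization = {"Glucose": [40.0, 80.0, 120.0, 180.0, 300.0]}

def to_bin(lab, value):
    # Map a raw value to the index of the bin containing it
    edges = discretization[lab]
    for i in range(len(edges) - 1):
        if value < edges[i + 1]:
            return i
    return len(edges) - 2  # clamp values past the last edge into the final bin

def from_bin(lab, idx):
    # Recover a concrete value by sampling uniformly within the chosen bin,
    # mirroring the random.uniform call in discretized_convert.py
    edges = discretization[lab]
    return random.uniform(edges[idx], edges[idx + 1])

b = to_bin("Glucose", 135.2)       # -> 2 (the 120-180 bin)
print(b, from_bin("Glucose", b))   # e.g. 2 151.7...
```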
35 | -------------------------------------------------------------------------------- /baselines/evaluate_privacy_nearest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from sklearn import metrics 7 | from config import HALOConfig 8 | 9 | SEED = 4 10 | random.seed(SEED) 11 | np.random.seed(SEED) 12 | torch.manual_seed(SEED) 13 | NUM_SAMPLES = 5000 14 | 15 | key = 'haloCoarse' 16 | 17 | config = HALOConfig() 18 | train_ehr_dataset = pickle.load(open('./data/trainDataset.pkl', 'rb')) 19 | train_ehr_dataset = np.random.choice(train_ehr_dataset, NUM_SAMPLES) 20 | train_ehr_dataset = [{'labels': p['labels'], 'visits': [set(v) for v in p['visits']]} for p in train_ehr_dataset] 21 | test_ehr_dataset = pickle.load(open('./data/testDataset.pkl', 'rb')) 22 | test_ehr_dataset = np.random.choice(test_ehr_dataset, NUM_SAMPLES) 23 | test_ehr_dataset = [{'labels': p['labels'], 'visits': [set(v) for v in p['visits']]} for p in test_ehr_dataset] 24 | synthetic_ehr_dataset = pickle.load(open(f'./results/datasets/{key}Dataset.pkl', 'rb')) 25 | synthetic_ehr_dataset = np.random.choice([p for p in synthetic_ehr_dataset if len(p['visits']) > 0], NUM_SAMPLES) 26 | synthetic_ehr_dataset = [{'labels': p['labels'], 'visits': [set(v) for v in p['visits']]} for p in synthetic_ehr_dataset] 27 | 28 | def find_hamming(ehr, dataset): 29 | min_d = 1e10 30 | visits = ehr['visits'] 31 | labels = ehr['labels'] 32 | for p in dataset: 33 | d = 0 if len(visits) == len(p['visits']) else 1 34 | l = p['labels'] 35 | d += ((labels + l) == 1).sum() 36 | for i in range(len(visits)): 37 | v = visits[i] 38 | if i >= len(p['visits']): 39 | d += len(v) 40 | else: 41 | v2 = p['visits'][i] 42 | d += len(v) + len(v2) - (2 * len(v.intersection(v2))) 43 | 44 | min_d = d if d < min_d and d > 0 else min_d 45 | return min_d 46 | 47 | def calc_nnaar(train, evaluation, synthetic): 48 | val1 = 0 49 | val2 = 0 50 | val3 = 0 51 | val4 = 0 52 | for p in tqdm(evaluation): 53 | des = find_hamming(p, synthetic) 54 | dee = find_hamming(p, evaluation) 55 | if des > dee: 56 | val1 += 1 57 | 58 | for p in tqdm(train): 59 | dts = find_hamming(p, synthetic) 60 | dtt = find_hamming(p, train) 61 | if dts > dtt: 62 | val3 += 1 63 | 64 | for p in tqdm(synthetic): 65 | dse = find_hamming(p, evaluation) 66 | dst = find_hamming(p, train) 67 | dss = find_hamming(p, synthetic) 68 | if dse > dss: 69 | val2 += 1 70 | if dst > dss: 71 | val4 += 1 72 | 73 | val1 = val1 / NUM_SAMPLES 74 | val2 = val2 / NUM_SAMPLES 75 | val3 = val3 / NUM_SAMPLES 76 | val4 = val4 / NUM_SAMPLES 77 | 78 | aaes = (0.5 * val1) + (0.5 * val2) 79 | aaet = (0.5 * val3) + (0.5 * val4) 80 | return aaes - aaet 81 | 82 | nnaar = calc_nnaar(train_ehr_dataset, test_ehr_dataset, synthetic_ehr_dataset) 83 | results = { 84 | "NNAAE": nnaar 85 | } 86 | pickle.dump(results, open("results/privacy_evaluation/nnaar_{key}.pkl", "wb")) 87 | print(results) -------------------------------------------------------------------------------- /evaluate_privacy_attribute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from collections import Counter 7 | from config import HALOConfig 8 | 9 | SEED = 4 10 | random.seed(SEED) 11 | np.random.seed(SEED) 12 | torch.manual_seed(SEED) 13 | 14 | local_rank = -1 15 | fp16 = False 16 | if local_rank == 
-1: 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | n_gpu = torch.cuda.device_count() 19 | else: 20 | torch.cuda.set_device(local_rank) 21 | device = torch.device("cuda", local_rank) 22 | n_gpu = 1 23 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 24 | torch.distributed.init_process_group(backend='nccl') 25 | if torch.cuda.is_available(): 26 | torch.cuda.manual_seed_all(SEED) 27 | 28 | config = HALOConfig() 29 | test_ehr_dataset = pickle.load(open('./data/testDataset.pkl', 'rb')) 30 | test_ehr_dataset = [{'labels': p['labels'], 'visits': set([c for v in p['visits'] for c in v])} for p in test_ehr_dataset] 31 | train_ehr_dataset = pickle.load(open('./data/trainDataset.pkl', 'rb')) 32 | train_ehr_dataset = [{'labels': p['labels'], 'visits': set([c for v in p['visits'] for c in v])} for p in train_ehr_dataset] 33 | train_ehr_dataset = np.random.choice(train_ehr_dataset, len(test_ehr_dataset), replace=False) 34 | synthetic_ehr_dataset = pickle.load(open('./results/datasets/haloDataset.pkl', 'rb')) 35 | synthetic_ehr_dataset = [{'labels': p['labels'], 'visits': set([c for v in p['visits'] for c in v])} for p in synthetic_ehr_dataset if len(p['visits']) > 0] 36 | synthetic_ehr_dataset = np.random.choice(synthetic_ehr_dataset, len(test_ehr_dataset), replace=False) 37 | 38 | common_codes = set([cd for cd, _ in Counter([c for p in train_ehr_dataset for c in p['visits']]).most_common()[0:100]]) 39 | 40 | test_ehr_dataset = [{'labels': set([c for c in p['labels'].nonzero()[0].tolist()] + [c + config.label_vocab_size for c in p['visits'] if c in common_codes]), 'codes': set([c for c in p['visits'] if c not in common_codes])} for p in test_ehr_dataset] 41 | train_ehr_dataset = [{'labels': set([c for c in p['labels'].nonzero()[0].tolist()] + [c + config.label_vocab_size for c in p['visits'] if c in common_codes]), 'codes': set([c for c in p['visits'] if c not in common_codes])} for p in train_ehr_dataset] 42 | synthetic_ehr_dataset = [{'labels': set([c for c in p['labels'].nonzero()[0].tolist()] + [c + config.label_vocab_size for c in p['visits'] if c in common_codes]), 'codes': set([c for c in p['visits'] if c not in common_codes])} for p in synthetic_ehr_dataset] 43 | 44 | def calc_dist(lab1, lab2): 45 | return len(lab1.union(lab2)) - len(lab1.intersection(lab2)) 46 | 47 | def find_closest(patient, data, k): 48 | cond = patient['labels'] 49 | dists = [(calc_dist(cond, ehr['labels']), ehr['codes']) for ehr in data] 50 | dists.sort(key= lambda x: x[0], reverse=False) 51 | options = [o[1] for o in dists[:k]] 52 | return options 53 | 54 | def calc_attribute_risk(train_dataset, reference_dataset, k): 55 | tp = 0 56 | fp = 0 57 | fn = 0 58 | for p in tqdm(train_dataset): 59 | closest_k = find_closest(p, reference_dataset, k) 60 | pred_codes = set([cd for cd, cnt in Counter([c for p in closest_k for c in p]).items() if cnt > k/2]) 61 | true_pos = len(pred_codes.intersection(p['codes'])) 62 | false_pos = len(pred_codes) - true_pos 63 | false_neg = len(p['codes']) - true_pos 64 | tp += true_pos 65 | fp += false_pos 66 | fn += false_neg 67 | 68 | f1 = tp / (tp + (0.5 * (fp + fn))) 69 | return f1 70 | 71 | K = 1 72 | att_risk = calc_attribute_risk(train_ehr_dataset, synthetic_ehr_dataset, K) 73 | baseline_risk = calc_attribute_risk(train_ehr_dataset, test_ehr_dataset, K) 74 | results = { 75 | "Attribute Attack F1 Score": att_risk, 76 | "Baseline Attack F1 Score": baseline_risk 77 | } 78 | pickle.dump(results, 
open("results/privacy_evaluation/attribute_inference.pkl", "wb")) 79 | print(results) -------------------------------------------------------------------------------- /baselines/evaluate_privacy_attribute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from collections import Counter 7 | from config import HALOConfig 8 | 9 | SEED = 4 10 | random.seed(SEED) 11 | np.random.seed(SEED) 12 | torch.manual_seed(SEED) 13 | 14 | key = 'haloCoarse' 15 | 16 | local_rank = -1 17 | fp16 = False 18 | if local_rank == -1: 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | n_gpu = torch.cuda.device_count() 21 | else: 22 | torch.cuda.set_device(local_rank) 23 | device = torch.device("cuda", local_rank) 24 | n_gpu = 1 25 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 26 | torch.distributed.init_process_group(backend='nccl') 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(SEED) 29 | 30 | config = HALOConfig() 31 | test_ehr_dataset = pickle.load(open('./data/testDataset.pkl', 'rb')) 32 | test_ehr_dataset = [{'labels': p['labels'], 'visits': set([c for v in p['visits'] for c in v])} for p in test_ehr_dataset] 33 | train_ehr_dataset = pickle.load(open('./data/trainDataset.pkl', 'rb')) 34 | train_ehr_dataset = [{'labels': p['labels'], 'visits': set([c for v in p['visits'] for c in v])} for p in train_ehr_dataset] 35 | train_ehr_dataset = np.random.choice(train_ehr_dataset, len(test_ehr_dataset), replace=False) 36 | synthetic_ehr_dataset = pickle.load(open(f'./results/datasets/{key}Dataset.pkl', 'rb')) 37 | synthetic_ehr_dataset = [{'labels': p['labels'], 'visits': set([c for v in p['visits'] for c in v])} for p in synthetic_ehr_dataset if len(p['visits']) > 0] 38 | synthetic_ehr_dataset = np.random.choice(synthetic_ehr_dataset, len(test_ehr_dataset), replace=False) 39 | 40 | common_codes = set([cd for cd, _ in Counter([c for p in train_ehr_dataset for c in p['visits']]).most_common()[0:100]]) 41 | 42 | test_ehr_dataset = [{'labels': set([c for c in p['labels'].nonzero()[0].tolist()] + [c + config.label_vocab_size for c in p['visits'] if c in common_codes]), 'codes': set([c for c in p['visits'] if c not in common_codes])} for p in test_ehr_dataset] 43 | train_ehr_dataset = [{'labels': set([c for c in p['labels'].nonzero()[0].tolist()] + [c + config.label_vocab_size for c in p['visits'] if c in common_codes]), 'codes': set([c for c in p['visits'] if c not in common_codes])} for p in train_ehr_dataset] 44 | synthetic_ehr_dataset = [{'labels': set([c for c in p['labels'].nonzero()[0].tolist()] + [c + config.label_vocab_size for c in p['visits'] if c in common_codes]), 'codes': set([c for c in p['visits'] if c not in common_codes])} for p in synthetic_ehr_dataset] 45 | 46 | def calc_dist(lab1, lab2): 47 | return len(lab1.union(lab2)) - len(lab1.intersection(lab2)) 48 | 49 | def find_closest(patient, data, k): 50 | cond = patient['labels'] 51 | dists = [(calc_dist(cond, ehr['labels']), ehr['codes']) for ehr in data] 52 | dists.sort(key= lambda x: x[0], reverse=False) 53 | options = [o[1] for o in dists[:k]] 54 | return options 55 | 56 | def calc_attribute_risk(train_dataset, reference_dataset, k): 57 | tp = 0 58 | fp = 0 59 | fn = 0 60 | for p in tqdm(train_dataset): 61 | closest_k = find_closest(p, reference_dataset, k) 62 | pred_codes = set([cd for cd, cnt in Counter([c for p in closest_k 
for c in p]).items() if cnt > k/2]) 63 | true_pos = len(pred_codes.intersection(p['codes'])) 64 | false_pos = len(pred_codes) - true_pos 65 | false_neg = len(p['codes']) - true_pos 66 | tp += true_pos 67 | fp += false_pos 68 | fn += false_neg 69 | 70 | f1 = tp / (tp + (0.5 * (fp + fn))) 71 | return f1 72 | 73 | K = 1 74 | att_risk = calc_attribute_risk(train_ehr_dataset, synthetic_ehr_dataset, K) 75 | baseline_risk = calc_attribute_risk(train_ehr_dataset, test_ehr_dataset, K) 76 | results = { 77 | "Attribute Attack F1 Score": att_risk, 78 | "Baseline Attack F1 Score": baseline_risk 79 | } 80 | pickle.dump(results, open(f"results/privacy_evaluation/attribute_inference_{key}.pkl", "wb")) 81 | print(results) -------------------------------------------------------------------------------- /baselines/synteg/export_condition.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import pickle 5 | import numpy as np 6 | from tqdm import tqdm 7 | from config import SyntegConfig 8 | from synteg import DependencyModel 9 | from sklearn.model_selection import train_test_split 10 | 11 | SEED = 4 12 | random.seed(SEED) 13 | np.random.seed(SEED) 14 | torch.manual_seed(SEED) 15 | config = SyntegConfig() 16 | 17 | local_rank = -1 18 | fp16 = False 19 | if local_rank == -1: 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | n_gpu = torch.cuda.device_count() 22 | else: 23 | torch.cuda.set_device(local_rank) 24 | device = torch.device("cuda", local_rank) 25 | n_gpu = 1 26 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 27 | torch.distributed.init_process_group(backend='nccl') 28 | if torch.cuda.is_available(): 29 | torch.cuda.manual_seed_all(SEED) 30 | 31 | train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 32 | 33 | def get_batch(loc, batch_size, mode): 34 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 35 | # Where each patient P is [V_1, V_2, ... , V_j] 36 | # Where each visit V is [C_1, C_2, ... , C_k] 37 | # And where each Label L is a binary vector [L_1 ... 
L_11] 38 | if mode == 'train': 39 | ehr = train_ehr_dataset[loc:loc+batch_size] 40 | elif mode == 'valid': 41 | ehr = val_ehr_dataset[loc:loc+batch_size] 42 | else: 43 | ehr = test_ehr_dataset[loc:loc+batch_size] 44 | 45 | batch_ehr = np.zeros((len(ehr), config.max_num_visit, config.max_length_visit)) 46 | batch_ehr[:,:,:] = config.vocab_dim # Initialize each code to the padding code 47 | batch_lens = np.ones((len(ehr), config.max_num_visit, 1)) 48 | batch_mask = np.zeros((len(ehr), config.max_num_visit, 1)) 49 | batch_num_visits = np.zeros(len(ehr)) 50 | for i, p in enumerate(ehr): 51 | visits = p['visits'] 52 | for j, v in enumerate(visits): 53 | batch_mask[i,j+2] = 1 54 | batch_lens[i,j+2] = len(v) + 1 55 | for k, c in enumerate(v): 56 | batch_ehr[i,j+2,k+1] = c 57 | batch_ehr[i,j+2,len(v)+1] = config.code_vocab_dim + config.label_vocab_dim + 1 # Set the last code in the last visit to be the end record code 58 | batch_lens[i,j+2] = len(v) + 2 59 | for l_idx, l in enumerate(np.nonzero(p['labels'])[0]): 60 | batch_ehr[i,1,l_idx+1] = config.code_vocab_dim + l 61 | batch_lens[i,1] = l_idx+2 62 | batch_num_visits[i] = len(visits) 63 | 64 | batch_mask[:,1] = 1 # Set the mask to cover the labels 65 | batch_ehr[:,:,0] = config.code_vocab_dim + config.label_vocab_dim # Set the first code in each visit to be the start/class token 66 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 67 | return batch_ehr, batch_lens, batch_mask, batch_num_visits 68 | 69 | LR = 1e-4 70 | model = DependencyModel(config).to(device) 71 | optimizer = torch.optim.Adam(model.parameters(), lr=LR) 72 | checkpoint = torch.load("../../save/synteg_dependency_model", map_location=torch.device(device)) 73 | model.load_state_dict(checkpoint['model']) 74 | optimizer.load_state_dict(checkpoint['optimizer']) 75 | 76 | condition_dataset = [] 77 | for i in tqdm(range(0, len(train_ehr_dataset), config.dependency_batchsize)): 78 | model.train() 79 | 80 | batch_ehr, batch_lens, _, batch_num_visits = get_batch(i, config.dependency_batchsize, 'train') 81 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.int).to(device) # bs * visit * code 82 | batch_lens = torch.tensor(batch_lens, dtype=torch.int).to(device) # bs * visit 83 | condition_vector = model(batch_ehr, batch_lens, export=True) # bs * visit * 256 84 | batch_ehr = batch_ehr.detach().cpu().numpy() 85 | condition_vector = condition_vector.detach().cpu().numpy() 86 | 87 | for b, num_visits in enumerate(batch_num_visits-1): 88 | for v in range(int(num_visits+1)): 89 | ehr_tmp = batch_ehr[b, v+1, :] 90 | condition_vector_tmp = condition_vector[b, v, :] 91 | datum = {"ehr": ehr_tmp, "condition": condition_vector_tmp} 92 | condition_dataset.append(datum) 93 | 94 | pickle.dump(condition_dataset, open("data/conditionDataset.pkl", "wb")) 95 | -------------------------------------------------------------------------------- /continuous_variables/generate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import pickle 4 | import random 5 | import numpy as np 6 | from sys import argv 7 | from tqdm import tqdm 8 | from discretized_model import HALOModel 9 | from discretized_config import HALOConfig 10 | 11 | config = HALOConfig() 12 | device = torch.device('cuda' if torch.cuda.is_available() else "cpu") 13 | model = HALOModel(config).to(device) 14 | checkpoint = torch.load('./save/halo_model', map_location=torch.device(device)) 15 | 
model.load_state_dict(checkpoint['model']) 16 | 17 | labelProbs = pickle.load(open('./discretized_data/labelProbs.pkl', 'rb')) 18 | idxToId = pickle.load(open('discretized_data/idxToId.pkl', 'rb')) 19 | idToLab = pickle.load(open('discretized_data/idToLab.pkl', 'rb')) 20 | beginPos = pickle.load(open('discretized_data/beginPos.pkl', 'rb')) 21 | isCategorical = pickle.load(open('discretized_data/isCategorical.pkl', 'rb')) 22 | possible_values = pickle.load(open('discretized_data/possibleValues.pkl', 'rb')) 23 | discretization = pickle.load(open('discretized_data/discretization.pkl', 'rb')) 24 | indexToCode = pickle.load(open('discretized_data/indexToCode.pkl', 'rb')) 25 | idToLabel = pickle.load(open('discretized_data/idToLabel.pkl', 'rb')) 26 | 27 | def sample_sequence(model, length, context, batch_size, device='cuda', sample=True): 28 | empty = torch.zeros((1,1,config.total_vocab_size), device=device, dtype=torch.float32).repeat(batch_size, 1, 1) 29 | context = torch.tensor(context, device=device, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1) 30 | prev = context.unsqueeze(1) 31 | context = None 32 | with torch.no_grad(): 33 | for _ in range(length-1): 34 | prev = model.sample(torch.cat((prev,empty), dim=1), sample) 35 | if torch.sum(torch.sum(prev[:,:,config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size+config.label_vocab_size+1], dim=1).bool().int(), dim=0).item() == batch_size: 36 | break 37 | ehr = prev.cpu().detach().numpy() 38 | prev = None 39 | empty = None 40 | return ehr 41 | 42 | def convert_ehr(ehrs, index_to_code=None): 43 | ehr_outputs = [] 44 | for i in range(len(ehrs)): 45 | ehr = ehrs[i] 46 | ehr_output = [] 47 | 48 | labels_output = ehr[1][config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size:config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size+config.label_vocab_size] 49 | if index_to_code is not None: 50 | labels_output = [idToLabel[idx] for idx in np.nonzero(labels_output)[0]] 51 | 52 | for j in range(2, len(ehr)): 53 | visit = ehr[j] 54 | visit_output = [] 55 | lab_mask = [] 56 | lab_values = [] 57 | cont_idx = -1 58 | indices = np.nonzero(visit)[0] 59 | end = False 60 | for idx in indices: 61 | if idx < config.code_vocab_size: 62 | visit_output.append(index_to_code[idx] if index_to_code is not None else idx) 63 | elif idx < config.code_vocab_size+config.lab_vocab_size: 64 | lab_idx = idx - (config.code_vocab_size) 65 | lab_num = idxToId[lab_idx] 66 | if lab_num in lab_mask: 67 | continue 68 | else: 69 | lab_mask.append(lab_num) 70 | lab_values.append(lab_idx - beginPos[lab_num]) 71 | elif idx < config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size: 72 | cont_idx = cont_idx if cont_idx != -1 else idx - (config.code_vocab_size+config.lab_vocab_size) 73 | elif idx == config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size+config.label_vocab_size+1: 74 | end = True 75 | 76 | if cont_idx == -1: 77 | cont_idx = random.randint(0, config.continuous_vocab_size) - 1 78 | if visit_output != [] or lab_mask != []: 79 | ehr_output.append((visit_output, lab_mask, lab_values, [cont_idx])) 80 | if end: 81 | break 82 | 83 | ehr_outputs.append({'visits': ehr_output, 'labels': labels_output}) 84 | ehr = None 85 | ehr_output = None 86 | labels_output = None 87 | visit = None 88 | visit_output = None 89 | indices = None 90 | return ehr_outputs 91 | 92 | # Generate Synthetic EHR dataset 93 | totEHRs = len(pickle.load(open('data/trainDataset.pkl', 'rb'))) 94 | stoken = 
np.zeros(config.total_vocab_size) 95 | stoken[config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size+config.label_vocab_size] = 1 96 | synthetic_ehr_dataset = [] 97 | for i in tqdm(range(0, totEHRs, config.sample_batch_size)): 98 | bs = min([totEHRs-i, config.sample_batch_size]) 99 | batch_synthetic_ehrs = sample_sequence(model, config.n_ctx, stoken, batch_size=bs, device=device, sample=True) 100 | batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs) 101 | synthetic_ehr_dataset += batch_synthetic_ehrs 102 | 103 | pickle.dump(synthetic_ehr_dataset, open(f'./results/datasets/haloDataset.pkl', 'wb')) 104 | -------------------------------------------------------------------------------- /baselines/lstm/train_lstm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import pickle 5 | import numpy as np 6 | from tqdm import tqdm 7 | from lstm import LSTMBaseline 8 | from config import LSTMConfig 9 | 10 | SEED = 4 11 | random.seed(SEED) 12 | np.random.seed(SEED) 13 | torch.manual_seed(SEED) 14 | config = LSTMConfig() 15 | 16 | local_rank = -1 17 | fp16 = False 18 | if local_rank == -1: 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | n_gpu = torch.cuda.device_count() 21 | else: 22 | torch.cuda.set_device(local_rank) 23 | device = torch.device("cuda", local_rank) 24 | n_gpu = 1 25 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 26 | torch.distributed.init_process_group(backend='nccl') 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(SEED) 29 | 30 | train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 31 | val_ehr_dataset = pickle.load(open('../../data/valDataset.pkl', 'rb')) 32 | 33 | def get_batch(loc, batch_size, mode): 34 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 35 | # Where each patient P is [V_1, V_2, ... , V_j] 36 | # Where each visit V is [C_1, C_2, ... , C_k] 37 | # And where each Label L is a binary vector [L_1 ... 
L_n] 38 | if mode == 'train': 39 | ehr = train_ehr_dataset[loc:loc+batch_size] 40 | elif mode == 'valid': 41 | ehr = val_ehr_dataset[loc:loc+batch_size] 42 | else: 43 | ehr = test_ehr_dataset[loc:loc+batch_size] 44 | 45 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 46 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 47 | for i, p in enumerate(ehr): 48 | visits = p['visits'] 49 | for j, v in enumerate(visits): 50 | batch_ehr[i,j+2][v] = 1 51 | batch_mask[i,j+2] = 1 52 | batch_ehr[i,1,config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 53 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 54 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 55 | 56 | batch_mask[:,1] = 1 # Set the mask to cover the labels 57 | batch_ehr[:,0,config.code_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 58 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 59 | return batch_ehr, batch_mask 60 | 61 | def shuffle_training_data(train_ehr_dataset): 62 | np.random.shuffle(train_ehr_dataset) 63 | 64 | model = LSTMBaseline(config).to(device) 65 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 66 | if os.path.exists("../../save/lstm_model"): 67 | print("Loading previous model") 68 | checkpoint = torch.load('../../save/lstm_model', map_location=torch.device(device)) 69 | model.load_state_dict(checkpoint['model']) 70 | optimizer.load_state_dict(checkpoint['optimizer']) 71 | 72 | # Train 73 | global_loss = 1e10 74 | for e in tqdm(range(config.epoch)): 75 | shuffle_training_data(train_ehr_dataset) 76 | for i in range(0, len(train_ehr_dataset), config.batch_size): 77 | model.train() 78 | 79 | batch_ehr, batch_mask = get_batch(i, config.batch_size, 'train') 80 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 81 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 82 | 83 | optimizer.zero_grad() 84 | loss, _, _ = model(batch_ehr, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight) 85 | loss.backward() 86 | optimizer.step() 87 | 88 | if i % (100*config.batch_size) == 0: 89 | print("Epoch %d, Iter %d: Training Loss:%.6f"%(e, i, loss)) 90 | if i % (100*config.batch_size) == 0: 91 | if i == 0: 92 | continue 93 | 94 | model.eval() 95 | with torch.no_grad(): 96 | val_l = [] 97 | for v_i in range(0, len(val_ehr_dataset), config.batch_size): 98 | batch_ehr, batch_mask = get_batch(v_i, config.batch_size, 'valid') 99 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 100 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 101 | 102 | val_loss, _, _ = model(batch_ehr, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight) 103 | val_l.append((val_loss).cpu().detach().numpy()) 104 | 105 | cur_val_loss = np.mean(val_l) 106 | print("Epoch %d Validation Loss:%.7f"%(e, cur_val_loss)) 107 | if cur_val_loss < global_loss: 108 | global_loss = cur_val_loss 109 | state = { 110 | 'model': model.state_dict(), 111 | 'optimizer': optimizer.state_dict(), 112 | 'iteration': i 113 | } 114 | torch.save(state, '../../save/lstm_model') 115 | print('\n------------ Save best model ------------\n') 
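# Illustrative sketch: a minimal, runnable example of the multi-hot visit encoding
# that the get_batch function above builds for the LSTM (and HALO) training loops.
# Position 0 is the start token, position 1 holds the patient label vector, each
# later position is a visit encoded as a multi-hot vector over the code vocabulary,
# the end-record token is set on the final visit, and remaining positions are padded
# visits. The tiny vocabulary sizes below are hypothetical; the real ones come from
# LSTMConfig / HALOConfig.
import numpy as np

code_vocab_size = 10     # hypothetical
label_vocab_size = 3     # hypothetical
special_vocab_size = 3   # start, end, pad
total_vocab_size = code_vocab_size + label_vocab_size + special_vocab_size
n_ctx = 6

patient = {'labels': np.array([1, 0, 1]), 'visits': [[2, 5], [0, 7, 9]]}

batch_ehr = np.zeros((1, n_ctx, total_vocab_size))
batch_mask = np.zeros((1, n_ctx, 1))
for j, v in enumerate(patient['visits']):
    batch_ehr[0, j + 2, v] = 1   # multi-hot codes for visit j
    batch_mask[0, j + 2] = 1     # mark the visit position as real
batch_ehr[0, 1, code_vocab_size:code_vocab_size + label_vocab_size] = patient['labels']
batch_ehr[0, len(patient['visits']) + 1, code_vocab_size + label_vocab_size + 1] = 1   # end-record token on the final visit
batch_ehr[0, len(patient['visits']) + 2:, code_vocab_size + label_vocab_size + 2] = 1  # padded visits
batch_ehr[0, 0, code_vocab_size + label_vocab_size] = 1   # start token
batch_mask[0, 1] = 1               # the label position is also predicted
batch_mask = batch_mask[:, 1:, :]  # shift to line up with the model's shifted predictions
print(batch_ehr.shape, batch_mask.shape)  # (1, 6, 16) (1, 5, 1)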
-------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import pickle 6 | from tqdm import tqdm 7 | from model import HALOModel 8 | from config import HALOConfig 9 | 10 | SEED = 4 11 | random.seed(SEED) 12 | np.random.seed(SEED) 13 | torch.manual_seed(SEED) 14 | config = HALOConfig() 15 | 16 | local_rank = -1 17 | fp16 = False 18 | if local_rank == -1: 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | n_gpu = torch.cuda.device_count() 21 | else: 22 | torch.cuda.set_device(local_rank) 23 | device = torch.device("cuda", local_rank) 24 | n_gpu = 1 25 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 26 | torch.distributed.init_process_group(backend='nccl') 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(SEED) 29 | 30 | train_ehr_dataset = pickle.load(open('./data/trainDataset.pkl', 'rb')) 31 | val_ehr_dataset = pickle.load(open('./data/valDataset.pkl', 'rb')) 32 | 33 | def get_batch(loc, batch_size, mode): 34 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 35 | # Where each patient P is [V_1, V_2, ... , V_j] 36 | # Where each visit V is [C_1, C_2, ... , C_k] 37 | # And where each Label L is a binary vector [L_1 ... L_n] 38 | if mode == 'train': 39 | ehr = train_ehr_dataset[loc:loc+batch_size] 40 | elif mode == 'valid': 41 | ehr = val_ehr_dataset[loc:loc+batch_size] 42 | else: 43 | ehr = test_ehr_dataset[loc:loc+batch_size] 44 | 45 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 46 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 47 | for i, p in enumerate(ehr): 48 | visits = p['visits'] 49 | for j, v in enumerate(visits): 50 | batch_ehr[i,j+2][v] = 1 51 | batch_mask[i,j+2] = 1 52 | batch_ehr[i,1,config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 53 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 54 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 55 | 56 | batch_mask[:,1] = 1 # Set the mask to cover the labels 57 | batch_ehr[:,0,config.code_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 58 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 59 | return batch_ehr, batch_mask 60 | 61 | def shuffle_training_data(train_ehr_dataset): 62 | np.random.shuffle(train_ehr_dataset) 63 | 64 | model = HALOModel(config).to(device) 65 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 66 | if os.path.exists("./save/halo_model"): 67 | print("Loading previous model") 68 | checkpoint = torch.load('./save/halo_model', map_location=torch.device(device)) 69 | model.load_state_dict(checkpoint['model']) 70 | optimizer.load_state_dict(checkpoint['optimizer']) 71 | 72 | # Train 73 | global_loss = 1e10 74 | for e in tqdm(range(config.epoch)): 75 | shuffle_training_data(train_ehr_dataset) 76 | for i in range(0, len(train_ehr_dataset), config.batch_size): 77 | model.train() 78 | 79 | batch_ehr, batch_mask = get_batch(i, config.batch_size, 'train') 80 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 81 | batch_mask = torch.tensor(batch_mask, 
dtype=torch.float32).to(device) 82 | 83 | optimizer.zero_grad() 84 | loss, _, _ = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight) 85 | loss.backward() 86 | optimizer.step() 87 | 88 | if i % (500*config.batch_size) == 0: 89 | print("Epoch %d, Iter %d: Training Loss:%.6f"%(e, i, loss * 8)) 90 | if i % (500*config.batch_size) == 0: 91 | if i == 0: 92 | continue 93 | 94 | model.eval() 95 | with torch.no_grad(): 96 | val_l = [] 97 | for v_i in range(0, len(val_ehr_dataset), config.batch_size): 98 | batch_ehr, batch_mask = get_batch(v_i, config.batch_size, 'valid') 99 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 100 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 101 | 102 | val_loss, _, _ = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight) 103 | val_l.append((val_loss).cpu().detach().numpy()) 104 | 105 | cur_val_loss = np.mean(val_l) 106 | print("Epoch %d Validation Loss:%.7f"%(e, cur_val_loss)) 107 | if cur_val_loss < global_loss: 108 | global_loss = cur_val_loss 109 | state = { 110 | 'model': model.state_dict(), 111 | 'optimizer': optimizer.state_dict(), 112 | 'iteration': i 113 | } 114 | torch.save(state, './save/halo_model') 115 | print('\n------------ Save best model ------------\n') -------------------------------------------------------------------------------- /baselines/haloCoarse/train_gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import pickle 6 | from tqdm import tqdm 7 | from gpt import HALOCoarseModel 8 | from config import HALOCoarseConfig 9 | 10 | SEED = 4 11 | random.seed(SEED) 12 | np.random.seed(SEED) 13 | torch.manual_seed(SEED) 14 | config = HALOCoarseConfig() 15 | 16 | local_rank = -1 17 | fp16 = False 18 | if local_rank == -1: 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | n_gpu = torch.cuda.device_count() 21 | else: 22 | torch.cuda.set_device(local_rank) 23 | device = torch.device("cuda", local_rank) 24 | n_gpu = 1 25 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 26 | torch.distributed.init_process_group(backend='nccl') 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(SEED) 29 | 30 | train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 31 | val_ehr_dataset = pickle.load(open('../../data/valDataset.pkl', 'rb')) 32 | 33 | def get_batch(loc, batch_size, mode): 34 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 35 | # Where each patient P is [V_1, V_2, ... , V_j] 36 | # Where each visit V is [C_1, C_2, ... , C_k] 37 | # And where each Label L is a binary vector [L_1 ... 
L_n] 38 | if mode == 'train': 39 | ehr = train_ehr_dataset[loc:loc+batch_size] 40 | elif mode == 'valid': 41 | ehr = val_ehr_dataset[loc:loc+batch_size] 42 | else: 43 | ehr = test_ehr_dataset[loc:loc+batch_size] 44 | 45 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 46 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 47 | for i, p in enumerate(ehr): 48 | visits = p['visits'] 49 | for j, v in enumerate(visits): 50 | batch_ehr[i,j+2][v] = 1 51 | batch_mask[i,j+2] = 1 52 | batch_ehr[i,1,config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 53 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 54 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 55 | 56 | batch_mask[:,1] = 1 # Set the mask to cover the labels 57 | batch_ehr[:,0,config.code_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 58 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 59 | return batch_ehr, batch_mask 60 | 61 | def shuffle_training_data(train_ehr_dataset): 62 | np.random.shuffle(train_ehr_dataset) 63 | 64 | model = HALOCoarseModel(config).to(device) 65 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 66 | if os.path.exists("../../save/haloCoarse_model"): 67 | print("Loading previous model") 68 | checkpoint = torch.load('../../save/haloCoarse_model', map_location=torch.device(device)) 69 | model.load_state_dict(checkpoint['model']) 70 | optimizer.load_state_dict(checkpoint['optimizer']) 71 | model.set_tied() 72 | 73 | # Train 74 | global_loss = 1e10 75 | for e in tqdm(range(config.epoch)): 76 | shuffle_training_data(train_ehr_dataset) 77 | for i in range(0, len(train_ehr_dataset), config.batch_size): 78 | model.train() 79 | 80 | batch_ehr, batch_mask = get_batch(i, config.batch_size, 'train') 81 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 82 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 83 | 84 | optimizer.zero_grad() 85 | loss, _, _ = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight) 86 | loss.backward() 87 | optimizer.step() 88 | 89 | if i % (100*config.batch_size) == 0: 90 | print("Epoch %d, Iter %d: Training Loss:%.6f"%(e, i, loss)) 91 | if i % (100*config.batch_size) == 0: 92 | if i == 0: 93 | continue 94 | 95 | model.eval() 96 | with torch.no_grad(): 97 | val_l = [] 98 | for v_i in range(0, len(val_ehr_dataset), config.batch_size): 99 | batch_ehr, batch_mask = get_batch(v_i, config.batch_size, 'valid') 100 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 101 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 102 | 103 | val_loss, _, _ = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight) 104 | val_l.append((val_loss).cpu().detach().numpy()) 105 | 106 | cur_val_loss = np.mean(val_l) 107 | print("Epoch %d Validation Loss:%.7f"%(e, cur_val_loss)) 108 | if cur_val_loss < global_loss: 109 | global_loss = cur_val_loss 110 | state = { 111 | 'model': model.state_dict(), 112 | 'optimizer': optimizer.state_dict(), 113 | 'iteration': i 114 | } 115 | torch.save(state, '../../save/haloCoarse_model') 116 | print('\n------------ Save best model 
------------\n') -------------------------------------------------------------------------------- /baselines/gpt/train_gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import pickle 6 | from tqdm import tqdm 7 | from gpt import GPTModel 8 | from config import GPTConfig 9 | 10 | SEED = 4 11 | random.seed(SEED) 12 | np.random.seed(SEED) 13 | torch.manual_seed(SEED) 14 | config = GPTConfig() 15 | 16 | local_rank = -1 17 | fp16 = False 18 | if local_rank == -1: 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | n_gpu = torch.cuda.device_count() 21 | else: 22 | torch.cuda.set_device(local_rank) 23 | device = torch.device("cuda", local_rank) 24 | n_gpu = 1 25 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 26 | torch.distributed.init_process_group(backend='nccl') 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(SEED) 29 | 30 | orig_train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 31 | orig_val_ehr_dataset = pickle.load(open('../../data/valDataset.pkl', 'rb')) 32 | 33 | train_ehr_dataset = [] 34 | for orig_ehr in orig_train_ehr_dataset: 35 | new_ehr = [config.total_vocab_size - 1] * config.n_ctx # Pad Codes 36 | new_ehr[0] = config.code_vocab_size + config.label_vocab_size # Start Record 37 | idx = 1 38 | 39 | # Add Labels 40 | for l in orig_ehr['labels'].nonzero()[0]: 41 | new_ehr[idx] = l + config.code_vocab_size 42 | idx += 1 43 | 44 | new_ehr[idx] = config.code_vocab_size + config.label_vocab_size + 1 # End Labels 45 | idx += 1 46 | 47 | # Add Visits 48 | for v in orig_ehr['visits']: 49 | for c in v: 50 | new_ehr[idx] = c 51 | idx += 1 52 | new_ehr[idx] = config.code_vocab_size + config.label_vocab_size + 2 # End Visit 53 | idx += 1 54 | 55 | new_ehr[idx] = config.code_vocab_size + config.label_vocab_size + 3 # End Record 56 | train_ehr_dataset.append(new_ehr) 57 | 58 | val_ehr_dataset = [] 59 | for orig_ehr in orig_val_ehr_dataset: 60 | new_ehr = [config.total_vocab_size - 1] * config.n_ctx # Pad Codes 61 | new_ehr[0] = config.code_vocab_size + config.label_vocab_size # Start Record 62 | idx = 1 63 | 64 | # Add Labels 65 | for l in orig_ehr['labels'].nonzero()[0]: 66 | new_ehr[idx] = l + config.code_vocab_size 67 | idx += 1 68 | 69 | new_ehr[idx] = config.code_vocab_size + config.label_vocab_size + 1 # End Labels 70 | idx += 1 71 | 72 | # Add Visits 73 | for v in orig_ehr['visits']: 74 | for c in v: 75 | new_ehr[idx] = c 76 | idx += 1 77 | new_ehr[idx] = config.code_vocab_size + config.label_vocab_size + 2 # End Visit 78 | idx += 1 79 | 80 | new_ehr[idx] = config.code_vocab_size + config.label_vocab_size + 3 # End Record 81 | val_ehr_dataset.append(new_ehr) 82 | 83 | def get_batch(loc, batch_size, mode): 84 | if mode == 'train': 85 | ehr = train_ehr_dataset[loc:loc+batch_size] 86 | elif mode == 'valid': 87 | ehr = val_ehr_dataset[loc:loc+batch_size] 88 | else: 89 | ehr = test_ehr_dataset[loc:loc+batch_size] 90 | 91 | batch_ehr = np.array(ehr) 92 | return batch_ehr 93 | 94 | def shuffle_training_data(train_ehr_dataset): 95 | np.random.shuffle(train_ehr_dataset) 96 | 97 | model = GPTModel(config).to(device) 98 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 99 | if os.path.exists("../../save/gpt_model"): 100 | print("Loading previous model") 101 | checkpoint = torch.load('../../save/gpt_model', map_location=torch.device(device)) 102 | 
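# Editorial sketch (not part of the original script; sizes are assumed): the loops above
# flatten each hierarchical record into a single token sequence for the GPT baseline.
# Assuming a toy configuration with code_vocab_size = 5 and label_vocab_size = 2
# (so start-record = 7, end-labels = 8, end-visit = 9, end-record = 10, and pad =
# total_vocab_size - 1), a patient with labels [0, 1] and visits [[0, 3], [2]] would be
# encoded as
#   [7, 6, 8, 0, 3, 9, 2, 9, 10, pad, pad, ...]
# i.e. the start token, one token per positive label, end-labels, each visit's codes
# separated by end-visit tokens, end-record, then padding out to n_ctx.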
model.load_state_dict(checkpoint['model']) 103 | optimizer.load_state_dict(checkpoint['optimizer']) 104 | model.set_tied() 105 | 106 | # Train 107 | global_loss = 1e10 108 | for e in tqdm(range(config.epoch)): 109 | shuffle_training_data(train_ehr_dataset) 110 | for i in range(0, len(train_ehr_dataset), config.batch_size): 111 | model.train() 112 | 113 | batch_ehr = get_batch(i, config.batch_size, 'train') 114 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.long).to(device) 115 | 116 | optimizer.zero_grad() 117 | loss, _, _ = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr) 118 | loss.backward() 119 | optimizer.step() 120 | 121 | if i % (100*config.batch_size) == 0: 122 | print("Epoch %d, Iter %d: Training Loss:%.6f"%(e, i, loss)) 123 | if i % (250*config.batch_size) == 0: 124 | if i == 0: 125 | continue 126 | 127 | model.eval() 128 | with torch.no_grad(): 129 | val_l = [] 130 | for v_i in range(0, len(val_ehr_dataset), config.batch_size): 131 | batch_ehr = get_batch(v_i, config.batch_size, 'valid') 132 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.long).to(device) 133 | 134 | val_loss, _, _ = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr) 135 | val_l.append((val_loss).cpu().detach().numpy()) 136 | 137 | cur_val_loss = np.mean(val_l) 138 | print("Epoch %d Validation Loss:%.7f"%(e, cur_val_loss)) 139 | if cur_val_loss < global_loss: 140 | global_loss = cur_val_loss 141 | state = { 142 | 'model': model.state_dict(), 143 | 'optimizer': optimizer.state_dict(), 144 | 'iteration': i 145 | } 146 | torch.save(state, '../../save/gpt_model') 147 | print('\n------------ Save best model ------------\n') -------------------------------------------------------------------------------- /baselines/eva/train_eva.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import pickle 5 | import numpy as np 6 | from tqdm import tqdm 7 | from eva import Eva 8 | from config import EVAConfig 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" 11 | 12 | SEED = 4 13 | random.seed(SEED) 14 | np.random.seed(SEED) 15 | torch.manual_seed(SEED) 16 | config = EVAConfig() 17 | 18 | local_rank = -1 19 | fp16 = False 20 | if local_rank == -1: 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | n_gpu = torch.cuda.device_count() 23 | else: 24 | torch.cuda.set_device(local_rank) 25 | device = torch.device("cuda", local_rank) 26 | n_gpu = 1 27 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 28 | torch.distributed.init_process_group(backend='nccl') 29 | if torch.cuda.is_available(): 30 | torch.cuda.manual_seed_all(SEED) 31 | 32 | train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 33 | val_ehr_dataset = pickle.load(open('../../data/valDataset.pkl', 'rb')) 34 | 35 | def get_batch(loc, batch_size, mode): 36 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 37 | # Where each patient P is [V_1, V_2, ... , V_j] 38 | # Where each visit V is [C_1, C_2, ... , C_k] 39 | # And where each Label L is a binary vector [L_1 ... 
L_n] 40 | if mode == 'train': 41 | ehr = train_ehr_dataset[loc:loc+batch_size] 42 | elif mode == 'valid': 43 | ehr = val_ehr_dataset[loc:loc+batch_size] 44 | else: 45 | ehr = test_ehr_dataset[loc:loc+batch_size] 46 | 47 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 48 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 49 | batch_lens = np.zeros(len(ehr)) 50 | for i, p in enumerate(ehr): 51 | visits = p['visits'] 52 | batch_lens[i] = len(visits) 53 | for j, v in enumerate(visits): 54 | batch_ehr[i,j+2][v] = 1 55 | batch_mask[i,j+2] = 1 56 | batch_ehr[i,1,config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 57 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 58 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 59 | 60 | batch_mask[:,1] = 1 # Set the mask to cover the labels 61 | batch_ehr[:,0,config.code_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 62 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 63 | return batch_ehr, batch_lens, batch_mask 64 | 65 | def shuffle_training_data(train_ehr_dataset): 66 | np.random.shuffle(train_ehr_dataset) 67 | 68 | model = Eva(config).to(device) 69 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 70 | if os.path.exists("../../save/eva_model"): 71 | print("Loading previous model") 72 | checkpoint = torch.load('../../save/eva_model', map_location=torch.device(device)) 73 | model.load_state_dict(checkpoint['model']) 74 | optimizer.load_state_dict(checkpoint['optimizer']) 75 | 76 | # Train 77 | global_loss = 1e10 78 | kl_schedule = [0.1, 0.15, 0.25, 0.325, 0.5, 0.75, 0.9, 1.0, 1.0, 1.0] 79 | for e in tqdm(range(config.epoch)): 80 | klw = kl_schedule[e] if e < len(kl_schedule) else 1 81 | shuffle_training_data(train_ehr_dataset) 82 | for i in range(0, len(train_ehr_dataset), config.batch_size): 83 | model.train() 84 | 85 | batch_ehr, batch_lens, batch_mask = get_batch(i, config.batch_size, 'train') 86 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 87 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 88 | 89 | optimizer.zero_grad() 90 | loss, _, _ = model(batch_ehr, batch_lens, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight, kl_weight=klw) 91 | loss.backward() 92 | optimizer.step() 93 | 94 | if i % (100*config.batch_size) == 0: 95 | print("Epoch %d, Iter %d: Training Loss:%.6f"%(e, i, loss)) 96 | if i % (200*config.batch_size) == 0: 97 | if i == 0: 98 | continue 99 | 100 | model.eval() 101 | with torch.no_grad(): 102 | val_l = [] 103 | for v_i in range(0, len(val_ehr_dataset), config.batch_size): 104 | batch_ehr, batch_lens, batch_mask = get_batch(v_i, config.batch_size, 'valid') 105 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 106 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 107 | 108 | val_loss, _, _ = model(batch_ehr, batch_lens, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight, kl_weight=klw) 109 | val_l.append((val_loss).cpu().detach().numpy()) 110 | 111 | cur_val_loss = np.mean(val_l) 112 | print("Epoch %d Validation Loss:%.7f"%(e, cur_val_loss)) 113 | if cur_val_loss < global_loss: 114 | global_loss = cur_val_loss 115 | state 
= { 116 | 'model': model.state_dict(), 117 | 'optimizer': optimizer.state_dict(), 118 | 'iteration': i 119 | } 120 | torch.save(state, '../../save/eva_model') 121 | print('\n------------ Save best model ------------\n') -------------------------------------------------------------------------------- /baselines/evaluate_privacy_membership.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import itertools 5 | import numpy as np 6 | from tqdm import tqdm 7 | import torch.nn as nn 8 | from sklearn import metrics 9 | from model import HALOModel 10 | import matplotlib.pyplot as plt 11 | from config import HALOConfig 12 | from scipy.spatial.distance import hamming 13 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 14 | 15 | SEED = 4 16 | random.seed(SEED) 17 | np.random.seed(SEED) 18 | torch.manual_seed(SEED) 19 | LR = 0.00001 20 | EPOCHS = 50 21 | BATCH_SIZE = 512 22 | LSTM_HIDDEN_DIM = 32 23 | EMBEDDING_DIM = 64 24 | NUM_TEST_EXAMPLES = 7500 25 | NUM_TOT_EXAMPLES = 7500 26 | NUM_VAL_EXAMPLES = 2500 27 | 28 | key = 'haloCoarse' 29 | 30 | local_rank = -1 31 | fp16 = False 32 | if local_rank == -1: 33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | n_gpu = torch.cuda.device_count() 35 | else: 36 | torch.cuda.set_device(local_rank) 37 | device = torch.device("cuda", local_rank) 38 | n_gpu = 1 39 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 40 | torch.distributed.init_process_group(backend='nccl') 41 | if torch.cuda.is_available(): 42 | torch.cuda.manual_seed_all(SEED) 43 | 44 | config = HALOConfig() 45 | train_ehr_dataset = pickle.load(open('./data/trainDataset.pkl', 'rb')) 46 | train_ehr_dataset = [(p,1) for p in train_ehr_dataset] 47 | test_ehr_dataset = pickle.load(open('./data/testDataset.pkl', 'rb')) 48 | test_ehr_dataset = [(p,0) for p in test_ehr_dataset] 49 | synthetic_ehr_dataset = pickle.load(open(f'./results/datasets/{key}Dataset.pkl', 'rb')) 50 | synthetic_ehr_dataset = [p for p in synthetic_ehr_dataset if len(p['visits']) > 0] 51 | 52 | attack_dataset_pos = list(random.sample(train_ehr_dataset, NUM_TOT_EXAMPLES)) 53 | attack_dataset_neg = list(random.sample(test_ehr_dataset, NUM_TOT_EXAMPLES)) 54 | np.random.shuffle(attack_dataset_pos) 55 | np.random.shuffle(attack_dataset_neg) 56 | test_attack_dataset = attack_dataset_pos[:NUM_TEST_EXAMPLES] + attack_dataset_neg[:NUM_TEST_EXAMPLES] 57 | val_attack_dataset = attack_dataset_pos[NUM_TEST_EXAMPLES:NUM_TEST_EXAMPLES+NUM_VAL_EXAMPLES] + attack_dataset_neg[NUM_TEST_EXAMPLES:NUM_TEST_EXAMPLES+NUM_VAL_EXAMPLES] 58 | np.random.shuffle(test_attack_dataset) 59 | np.random.shuffle(val_attack_dataset) 60 | attack_dataset_pos = attack_dataset_pos[NUM_TEST_EXAMPLES+NUM_VAL_EXAMPLES:] 61 | attack_dataset_neg = attack_dataset_neg[NUM_TEST_EXAMPLES+NUM_VAL_EXAMPLES:] 62 | 63 | def get_batch(loc, batch_size, dataset): 64 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 65 | # Where each patient P is [V_1, V_2, ... , V_j] 66 | # Where each visit V is [C_1, C_2, ... , C_k] 67 | # And where each Label L is a binary vector [L_1 ... 
L_11] 68 | ehr = dataset[loc:loc+batch_size] 69 | attack_labels = [l for (e,l) in ehr] 70 | ehr = [e for (e,l) in ehr] 71 | 72 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 73 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 74 | 75 | for i, p in enumerate(ehr): 76 | visits = p['visits'] 77 | for j, v in enumerate(visits): 78 | batch_ehr[i,j+2][v] = 1 79 | batch_mask[i,j+2] = 1 80 | batch_ehr[i,1,config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 81 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 82 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 83 | 84 | batch_mask[:,1] = 1 # Set the mask to cover the labels 85 | batch_ehr[:,0,config.code_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 86 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 87 | return batch_ehr, batch_mask, attack_labels 88 | 89 | 90 | 91 | def find_hamming(ehr, dataset): 92 | min_d = 1e10 93 | visits = ehr['visits'] 94 | labels = ehr['labels'] 95 | for p in dataset: 96 | d = 0 if len(visits) == len(p['visits']) else 1 97 | l = p['labels'] 98 | d += ((labels + l) == 1).sum() 99 | for i in range(len(visits)): 100 | v = visits[i] 101 | if i >= len(p['visits']): 102 | d += len(v) 103 | else: 104 | v2 = p['visits'][i] 105 | d += len(v) + len(v2) - (2 * len(set(v) & set(v2))) 106 | 107 | min_d = d if d < min_d else min_d 108 | return min_d 109 | 110 | 111 | 112 | # Perform the Hamming Distance experiment 113 | ds = [(find_hamming(ehr, synthetic_ehr_dataset), l) for (ehr, l) in tqdm(test_attack_dataset)] 114 | median_dist = np.median([d for (d,l) in ds]) 115 | preds = [1 if d < median_dist else 0 for (d,l) in ds] 116 | labels = [l for (d,l) in ds] 117 | results = { 118 | "Accuracy": metrics.accuracy_score(labels, preds), 119 | "Precision": metrics.precision_score(labels, preds), 120 | "Recall": metrics.recall_score(labels, preds), 121 | "F1": metrics.f1_score(labels, preds) 122 | } 123 | pickle.dump(results, open(f"results/privacy_evaluation/hamming_model_{key}.pkl", "wb")) 124 | print(results) -------------------------------------------------------------------------------- /baselines/synteg/condition_simulation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import random 5 | import pickle 6 | import numpy as np 7 | from tqdm import tqdm 8 | from config import SyntegConfig 9 | import torch.nn.functional as F 10 | from synteg import Generator, Discriminator 11 | from torch.autograd import grad as torch_grad 12 | 13 | SEED = 4 14 | random.seed(SEED) 15 | np.random.seed(SEED) 16 | torch.manual_seed(SEED) 17 | config = SyntegConfig() 18 | 19 | local_rank = -1 20 | fp16 = False 21 | if local_rank == -1: 22 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 23 | n_gpu = torch.cuda.device_count() 24 | else: 25 | torch.cuda.set_device(local_rank) 26 | device = torch.device("cuda", local_rank) 27 | n_gpu = 1 28 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 29 | torch.distributed.init_process_group(backend='nccl') 30 | if torch.cuda.is_available(): 31 | torch.cuda.manual_seed_all(SEED) 32 | 33 | condition_dataset = 
pickle.load(open('data/conditionDataset.pkl', 'rb')) 34 | 35 | def get_batch(loc, batch_size): 36 | data = condition_dataset[loc:loc+batch_size] 37 | visits = [d['ehr'] for d in data] 38 | conditions = [d['condition'] for d in data] 39 | visits = torch.tensor(visits, dtype=torch.int64).to(device) 40 | conditions = torch.tensor(conditions).to(device) 41 | return (visits, conditions) 42 | 43 | def shuffle_dataset(dataset): 44 | np.random.shuffle(dataset) 45 | 46 | EPOCHS = 600 47 | generator = Generator(config).to(device) 48 | discriminator = Discriminator(config).to(device) 49 | generator_optimizer = torch.optim.Adam(generator.parameters(), lr=4e-6, weight_decay=1e-5) 50 | discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=2e-5, weight_decay=1e-5) 51 | if os.path.exists("../../save/synteg_condition_model"): 52 | print("Loading previous model") 53 | checkpoint = torch.load("../../save/synteg_condition_model", map_location=torch.device(device)) 54 | generator.load_state_dict(checkpoint['generator']) 55 | generator_optimizer.load_state_dict(checkpoint['generator_optimizer']) 56 | discriminator.load_state_dict(checkpoint['discriminator']) 57 | discriminator_optimizer.load_state_dict(checkpoint['discriminator_optimizer']) 58 | 59 | def d_step(visits, conditions): 60 | discriminator.train() 61 | generator.eval() 62 | discriminator_optimizer.zero_grad() 63 | 64 | real = visits 65 | z = torch.randn((len(visits), config.z_dim)).to(device) 66 | epsilon = torch.rand((len(visits), 1)).to(device) 67 | 68 | synthetic = generator(z, conditions) 69 | real_output = discriminator(real, conditions) 70 | fake_output = discriminator(synthetic, conditions) 71 | w_distance = -torch.mean(real_output) + torch.mean(fake_output) 72 | 73 | interpolate = real + epsilon * (synthetic - real) 74 | interpolate_output = discriminator(interpolate, conditions) 75 | 76 | gradients = torch_grad(outputs=interpolate_output, inputs=interpolate, 77 | grad_outputs=torch.ones(interpolate_output.size()).to(device), 78 | create_graph=True, retain_graph=True)[0] 79 | gradients = gradients.view(len(visits), -1) 80 | gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) 81 | gradient_penalty = config.gp_weight * ((gradients_norm - 1) ** 2).mean() 82 | 83 | disc_loss = gradient_penalty + w_distance 84 | disc_loss.backward() 85 | discriminator_optimizer.step() 86 | return disc_loss, w_distance 87 | 88 | def g_step(conditions): 89 | z = torch.randn((len(conditions), config.z_dim)).to(device) 90 | generator.train() 91 | discriminator.eval() 92 | generator_optimizer.zero_grad() 93 | synthetic = generator(z, conditions) 94 | fake_output = discriminator(synthetic, conditions) 95 | gen_loss = -torch.mean(fake_output) 96 | gen_loss.backward() 97 | generator_optimizer.step() 98 | 99 | def train_step(batch): 100 | visits, conditions = batch # bs * codes, bs * condition 101 | visits = torch.sum(F.one_hot(visits, num_classes=config.vocab_dim+1), dim=-2)[:,:-1] # bs * vocab 102 | disc_loss, w_distance = d_step(visits, conditions) 103 | g_step(conditions) 104 | return disc_loss, w_distance 105 | 106 | print('training start') 107 | for e in tqdm(range(EPOCHS)): 108 | total_loss = 0 109 | total_w = 0 110 | step = 0 111 | shuffle_dataset(condition_dataset) 112 | for i in range(0, len(condition_dataset), config.gan_batchsize): 113 | batch = get_batch(i, config.gan_batchsize) 114 | loss, w = train_step(batch) 115 | total_loss += loss 116 | total_w += w 117 | step += 1 118 | format_str = 'epoch: %d, loss = %f, w = %f' 
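# Editorial note (not part of the original script): d_step and g_step above implement a
# conditional WGAN-GP. For real visit vectors x, generated visits x_tilde = G(z, c), and
# interpolates x_hat = x + eps * (x_tilde - x), the critic loss computed in d_step is
#   L_D = E[D(x_tilde, c)] - E[D(x, c)] + gp_weight * E[(||grad_{x_hat} D(x_hat, c)||_2 - 1)^2]
# (the w_distance plus gradient_penalty terms), while g_step updates the generator to
# minimize -E[D(G(z, c), c)].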
119 | print(format_str % (e, -total_loss / step, -total_w / step)) 120 | if e % 50 == 49: 121 | state = { 122 | 'generator': generator.state_dict(), 123 | 'generator_optimizer': generator_optimizer.state_dict(), 124 | 'discriminator': discriminator.state_dict(), 125 | 'discriminator_optimizer': discriminator_optimizer.state_dict(), 126 | 'epoch': e 127 | } 128 | torch.save(state, '../../save/synteg_condition_model') 129 | print('\n------------ Save newest model ------------\n') -------------------------------------------------------------------------------- /baselines/gpt/test_gpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from sklearn import metrics 7 | from gpt import GPTModel 8 | from config import GPTConfig 9 | import torch.nn.functional as F 10 | 11 | SEED = 4 12 | random.seed(SEED) 13 | np.random.seed(SEED) 14 | torch.manual_seed(SEED) 15 | config = GPTConfig() 16 | 17 | local_rank = -1 18 | fp16 = False 19 | if local_rank == -1: 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | n_gpu = torch.cuda.device_count() 22 | else: 23 | torch.cuda.set_device(local_rank) 24 | device = torch.device("cuda", local_rank) 25 | n_gpu = 1 26 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 27 | torch.distributed.init_process_group(backend='nccl') 28 | if torch.cuda.is_available(): 29 | torch.cuda.manual_seed_all(SEED) 30 | 31 | train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 32 | index_to_code = pickle.load(open("../../data/indexToCode.pkl", "rb")) 33 | 34 | # Add the labels to the index_to_code mapping 35 | index_to_code[config.code_vocab_size] = "Chronic Condition: Alzheimer or related disorders or senile" 36 | index_to_code[config.code_vocab_size+1] = "Chronic Condition: Heart Failure" 37 | index_to_code[config.code_vocab_size+2] = "Chronic Condition: Chronic Kidney Disease" 38 | index_to_code[config.code_vocab_size+3] = "Chronic Condition: Cancer" 39 | index_to_code[config.code_vocab_size+4] = "Chronic Condition: Chronic Obstructive Pulmonary Disease" 40 | index_to_code[config.code_vocab_size+5] = "Chronic Condition: Depression" 41 | index_to_code[config.code_vocab_size+6] = "Chronic Condition: Diabetes" 42 | index_to_code[config.code_vocab_size+7] = "Chronic Condition: Ischemic Heart Disease" 43 | index_to_code[config.code_vocab_size+8] = "Chronic Condition: Osteoporosis" 44 | index_to_code[config.code_vocab_size+9] = "Chronic Condition: rheumatoid arthritis and osteoarthritis (RA/OA)" 45 | index_to_code[config.code_vocab_size+10] = "Chronic Condition: Stroke/transient Ischemic Attack" 46 | 47 | model = GPTModel(config).to(device) 48 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 49 | 50 | checkpoint = torch.load('../../save/gpt_model', map_location=torch.device(device)) 51 | model.load_state_dict(checkpoint['model']) 52 | optimizer.load_state_dict(checkpoint['optimizer']) 53 | 54 | def sample_sequence(model, length, context, batch_size=None, device='cuda', sample=True): 55 | context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1) 56 | prev = context 57 | ehr = context 58 | past = None 59 | with torch.no_grad(): 60 | for _ in range(length): 61 | code_logits, past = model(prev, past=past) 62 | code_logits = code_logits[:, -1, :] 63 | log_probs = F.softmax(code_logits, dim=-1) 64 | if sample: 65 | 
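# Editorial note: torch.multinomial draws the next token from the softmax distribution
# (despite its name, log_probs here holds softmax probabilities), giving stochastic
# generation; the else-branch below instead takes the argmax for greedy decoding.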
prev = torch.multinomial(log_probs, num_samples=1) 66 | else: 67 | prev = torch.argmax(log_probs, dim=1) 68 | ehr = torch.cat((ehr, prev), dim=1) 69 | 70 | if all([config.code_vocab_size + config.label_vocab_size + 3 in ehr[i] for i in range(batch_size)]): # early stopping 71 | break 72 | ehr = ehr.cpu().detach().numpy() 73 | next = None 74 | prev = None 75 | return ehr 76 | 77 | def convert_ehr(ehrs, index_to_code=None): 78 | ehr_outputs = [] 79 | for i in range(len(ehrs)): 80 | ehr = ehrs[i] 81 | ehr_output = [] 82 | visit_output = [] 83 | labels_output = np.zeros(config.label_vocab_size) 84 | started_visits = False 85 | for j in range(1, len(ehr)): 86 | code = ehr[j] 87 | if not started_visits: 88 | if code == config.code_vocab_size + config.label_vocab_size + 1: 89 | started_visits = True 90 | elif code >= config.code_vocab_size and code < config.code_vocab_size + config.label_vocab_size: 91 | labels_output[code - config.code_vocab_size] = 1 92 | 93 | else: 94 | if code < config.code_vocab_size: 95 | if code not in visit_output: 96 | visit_output.append(index_to_code[code] if index_to_code is not None else code) 97 | elif code == config.code_vocab_size + config.label_vocab_size + 2: 98 | if visit_output != []: 99 | ehr_output.append(visit_output) 100 | visit_output = [] 101 | elif code == config.code_vocab_size + config.label_vocab_size + 3: 102 | break 103 | 104 | if visit_output != []: 105 | ehr_output.append(visit_output) 106 | 107 | if index_to_code is not None: 108 | labels_output = [index_to_code[idx + config.code_vocab_size] for idx in np.nonzero(labels_output)[0]] 109 | 110 | ehr_outputs.append({'visits': ehr_output, 'labels': labels_output}) 111 | ehr = None 112 | ehr_output = None 113 | labels_output = None 114 | visit_output = None 115 | return ehr_outputs 116 | 117 | # Generate Synthetic EHR dataset 118 | synthetic_ehr_dataset = [] 119 | stoken = [config.code_vocab_size+config.label_vocab_size] 120 | for i in tqdm(range(0, len(train_ehr_dataset), 2*config.batch_size)): 121 | bs = min([len(train_ehr_dataset)-i, 2*config.batch_size]) 122 | batch_synthetic_ehrs = sample_sequence(model, config.n_ctx, stoken, batch_size=bs, device=device, sample=True) 123 | batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs) 124 | synthetic_ehr_dataset += batch_synthetic_ehrs 125 | 126 | pickle.dump(synthetic_ehr_dataset, open(f'../../results/datasets/gptDataset.pkl', 'wb')) -------------------------------------------------------------------------------- /continuous_variables/train_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import pickle 5 | import numpy as np 6 | from tqdm import tqdm 7 | from model import HALOModel 8 | from config import HALOConfig 9 | from collections import Counter 10 | 11 | SEED = 4 12 | random.seed(SEED) 13 | np.random.seed(SEED) 14 | torch.manual_seed(SEED) 15 | config = HALOConfig() 16 | 17 | local_rank = -1 18 | fp16 = False 19 | if local_rank == -1: 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | n_gpu = torch.cuda.device_count() 22 | else: 23 | torch.cuda.set_device(local_rank) 24 | device = torch.device("cuda", local_rank) 25 | n_gpu = 1 26 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 27 | torch.distributed.init_process_group(backend='nccl') 28 | if torch.cuda.is_available(): 29 | torch.cuda.manual_seed_all(SEED) 30 | 31 | train_ehr_dataset = 
pickle.load(open('discretized_data/trainDataset.pkl', 'rb')) 32 | val_ehr_dataset = pickle.load(open('discretized_data/valDataset.pkl', 'rb')) 33 | 34 | # Convert to fully codes 35 | beginPos = pickle.load(open('discretized_data/beginPos.pkl', 'rb')) 36 | for p in (train_ehr_dataset + val_ehr_dataset): 37 | new_visits = [] 38 | for v in p['visits']: 39 | new_idx = v[0] 40 | for l, val in zip(v[1], v[2]): 41 | new_idx.append(config.code_vocab_size + beginPos[l] + val) 42 | new_idx.append(config.code_vocab_size + config.lab_vocab_size + v[3][-1]) 43 | new_visits.append(new_idx) 44 | 45 | p['visits'] = new_visits 46 | 47 | labelCounts = Counter([tuple(p['labels']) for p in train_ehr_dataset]) 48 | tot = len(train_ehr_dataset) 49 | labelProbs = {l: c / tot for (l, c) in labelCounts.items()} 50 | pickle.dump(labelProbs, open('data/labelProbs.pkl', 'wb')) 51 | 52 | def get_batch(loc, batch_size, mode): 53 | if mode == 'train': 54 | ehr = train_ehr_dataset[loc:loc+batch_size] 55 | elif mode == 'valid': 56 | ehr = val_ehr_dataset[loc:loc+batch_size] 57 | else: 58 | ehr = test_ehr_dataset[loc:loc+batch_size] 59 | 60 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 61 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 62 | for i, p in enumerate(ehr): 63 | visits = p['visits'] 64 | for j, v in enumerate(visits): 65 | batch_ehr[i,j+2][v] = 1 66 | batch_mask[i,j+2] = 1 67 | batch_ehr[i,1,config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size:config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 68 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 69 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 70 | 71 | batch_mask[:,1] = 1 # Set the mask to cover the labels 72 | batch_ehr[:,0,config.code_vocab_size+config.lab_vocab_size+config.continuous_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 73 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 74 | return batch_ehr, batch_mask 75 | 76 | def shuffle_training_data(train_ehr_dataset): 77 | np.random.shuffle(train_ehr_dataset) 78 | 79 | model = HALOModel(config).to(device) 80 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 81 | if os.path.exists("./save/halo_model"): 82 | print("Loading previous model") 83 | checkpoint = torch.load('./save/halo_model', map_location=torch.device(device)) 84 | model.load_state_dict(checkpoint['model']) 85 | optimizer.load_state_dict(checkpoint['optimizer']) 86 | 87 | # Train Model 88 | global_loss = 1e10 89 | for e in tqdm(range(config.epoch)): 90 | shuffle_training_data(train_ehr_dataset) 91 | for i in range(0, len(train_ehr_dataset), config.batch_size): 92 | model.train() 93 | 94 | batch_ehr, batch_mask = get_batch(i, config.batch_size, 'train') 95 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 96 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 97 | 98 | optimizer.zero_grad() 99 | loss, _, _ = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask) 100 | loss.backward() 101 | optimizer.step() 102 | 103 | if i % (50*config.batch_size) == 0: 104 | 
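# Editorial sketch (not part of the original script; values are illustrative): the
# "Convert to fully codes" block near the top of this file flattens each visit tuple
# (codes, lab ids, bucketed lab values, [time bucket]) into one index list over the
# combined vocabulary. Assuming code_vocab_size = 100 and beginPos = [0, 17, ...],
# a visit ([5, 9], [0], [3], [2]) would become
#   [5, 9, 100 + 0 + 3, 100 + config.lab_vocab_size + 2]
# i.e. the raw codes, then one offset index per discretized lab value, then the
# inter-visit time bucket.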
print("Epoch %d, Iter %d: Training Loss:%.6f"%(e, i, loss)) 105 | if i % (250*config.batch_size) == 0: 106 | if i == 0: 107 | continue 108 | 109 | model.eval() 110 | with torch.no_grad(): 111 | val_l = [] 112 | for v_i in range(0, len(val_ehr_dataset), config.batch_size): 113 | batch_ehr, batch_mask = get_batch(v_i, config.batch_size, 'valid') 114 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 115 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 116 | 117 | val_loss, _, _ = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask) 118 | val_l.append((val_loss).cpu().detach().numpy()) 119 | 120 | cur_val_loss = np.mean(val_l) 121 | print("Epoch %d Validation Loss:%.7f"%(e, cur_val_loss)) 122 | if cur_val_loss < global_loss: 123 | global_loss = cur_val_loss 124 | state = { 125 | 'model': model.state_dict(), 126 | 'optimizer': optimizer.state_dict(), 127 | 'iteration': i 128 | } 129 | torch.save(state, './save/halo_model') 130 | print('\n------------ Save best model ------------\n') 131 | -------------------------------------------------------------------------------- /continuous_variables/pickleToCSV.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from collections import defaultdict 4 | import csv 5 | 6 | 7 | def flatten_data(subject_id, data, indexToCode, idToLabel, d_icd_diagnoses, d_icd_procedures): 8 | '''Function to flatten data such that each row consists of one visit per subject 9 | ''' 10 | flattened = [] 11 | visit_number = 1 12 | 13 | # Ensure ICD9_CODE in both dataframes is string type 14 | d_icd_procedures['ICD9_CODE'] = d_icd_procedures['ICD9_CODE'].astype(str) 15 | d_icd_diagnoses['ICD9_CODE'] = d_icd_diagnoses['ICD9_CODE'].astype(str) 16 | 17 | for i, visit in enumerate(data['visits']): 18 | diagnosis_codes, lab_names, lab_values, time_since_last = visit 19 | 20 | # The first "visit gap" is the patient's age (as there's no previous visit to have a gap from) 21 | if i == 0: 22 | age = time_since_last[0] 23 | time_since_last = np.nan 24 | else: 25 | time_since_last = time_since_last[0] 26 | 27 | # Convert the diagnosis codes to true medical codes 28 | all_codes = [indexToCode[code] if code in indexToCode else code for code in diagnosis_codes] 29 | 30 | # Process labels 31 | medication_labels = [] 32 | diagnosis_labels = [] 33 | procedure_labels = [] 34 | 35 | for code in all_codes: 36 | code_str = str(code) # Convert all codes to string for consistency 37 | 38 | # Check diagnoses first (to catch E-codes and other alphanumeric codes) 39 | matching_diag = d_icd_diagnoses[d_icd_diagnoses['ICD9_CODE'] == code_str] 40 | if not matching_diag.empty: 41 | diagnosis_labels.append(matching_diag['SHORT_TITLE'].iloc[0]) 42 | continue 43 | 44 | # If not in diagnoses, check procedures 45 | matching_proc = d_icd_procedures[d_icd_procedures['ICD9_CODE'] == code_str] 46 | if not matching_proc.empty: 47 | procedure_labels.append(matching_proc['SHORT_TITLE'].iloc[0]) 48 | continue 49 | 50 | # If it's a numeric string, try to match without leading zeros 51 | if code_str.replace(".", "").isdigit(): 52 | code_int = str(int(float(code_str))) # Remove leading zeros 53 | matching_diag = d_icd_diagnoses[d_icd_diagnoses['ICD9_CODE'] == code_int] 54 | if not matching_diag.empty: 55 | diagnosis_labels.append(matching_diag['SHORT_TITLE'].iloc[0]) 56 | else: 57 | matching_proc = d_icd_procedures[d_icd_procedures['ICD9_CODE'] == code_int] 58 | if not 
matching_proc.empty: 59 | procedure_labels.append(matching_proc['SHORT_TITLE'].iloc[0]) 60 | else: 61 | # If it's not a numeric code and not found in diagnoses or procedures, assume it's a medication 62 | medication_labels.append(code_str) 63 | 64 | row = defaultdict(str) 65 | row['Visit Number'] = visit_number 66 | row['Age'] = age 67 | row['Time Since Last Visit'] = time_since_last 68 | row['Medication Labels'] = ', '.join(medication_labels) 69 | row['Diagnosis Labels'] = ', '.join(diagnosis_labels) 70 | row['Procedure Labels'] = ', '.join(procedure_labels) 71 | row['Subject ID'] = subject_id 72 | for name, value in zip(lab_names, lab_values): 73 | row[name] = value 74 | 75 | for i, label in enumerate(data['labels']): 76 | label_name = None 77 | if i in idToLabel: 78 | label_name = idToLabel[i] 79 | row[label_name] = label 80 | 81 | flattened.append(row) 82 | visit_number += 1 83 | return flattened 84 | 85 | 86 | def write_to_csv(data, output_file): 87 | '''Function to update column names and save each row to .csv 88 | ''' 89 | fieldnames = set() 90 | for subject_data in data: 91 | for row in subject_data: 92 | fieldnames.update(row.keys()) 93 | fieldnames = {str(item) for item in fieldnames} 94 | fieldnames = sorted(list(fieldnames)) 95 | 96 | with open(output_file, 'w', newline = '') as csvfile: 97 | writer = csv.DictWriter(csvfile, fieldnames = fieldnames) 98 | writer.writeheader() 99 | for subject_data in data: 100 | for row in subject_data: 101 | writer.writerow((row)) 102 | 103 | def main(): 104 | path_to_pkl = "./continuous_variables/results/datasets/haloDataset_convertedv3.pkl" #path to file generated after running discretized_convert.py 105 | subjects_data = pd.read_pickle(path_to_pkl) 106 | indexToCode = pd.read_pickle("./continuous_variables/data/indexToCode.pkl") #path to indexToCode file generated by genDatasetContinuous.py 107 | idToLabel = pd.read_pickle('./continuous_variables/data/idToLabel.pkl') #path to idToLabel file generated by genDatasetContinuous.py 108 | d_icd_diagnoses = pd.read_csv('./continuous_variables/data/D_ICD_DIAGNOSES.csv') #path to ICD9 diagnoses codes (to be downloaded and extracted for MIMIC-III from physionet) 109 | d_icd_procedures = pd.read_csv('./continuous_variables/data/D_ICD_PROCEDURES.csv') #path to ICD9 procedure codes (to be downloaded and extracted for MIMIC-III from physionet) 110 | 111 | all_subjects_data = [] 112 | for subject_id, subject_data in enumerate(subjects_data, start = 1): #Assign subject IDs starting with 1 113 | flattened_data = flatten_data(subject_id, subject_data, indexToCode, idToLabel, d_icd_diagnoses, d_icd_procedures) 114 | all_subjects_data.append(flattened_data) 115 | 116 | write_to_csv(all_subjects_data, './continuous_variables/results/datasets/haloDataset_convertedv3.csv') #path to save final .csv 117 | 118 | main() 119 | -------------------------------------------------------------------------------- /baselines/synteg/generate_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import pickle 5 | import numpy as np 6 | from tqdm import tqdm 7 | from config import SyntegConfig 8 | from synteg import Generator, DependencyModel 9 | 10 | SEED = 4 11 | random.seed(SEED) 12 | np.random.seed(SEED) 13 | torch.manual_seed(SEED) 14 | config = SyntegConfig() 15 | 16 | local_rank = -1 17 | fp16 = False 18 | if local_rank == -1: 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | n_gpu = 
torch.cuda.device_count() 21 | else: 22 | torch.cuda.set_device(local_rank) 23 | device = torch.device("cuda", local_rank) 24 | n_gpu = 1 25 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 26 | torch.distributed.init_process_group(backend='nccl') 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(SEED) 29 | 30 | train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 31 | index_to_code = pickle.load(open("../../data/indexToCode.pkl", "rb")) 32 | id_to_label = pickle.load(open("../../data/idToLabel.pkl", "rb")) 33 | model = DependencyModel(config).to(device) 34 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 35 | generator = Generator(config).to(device) 36 | generator_optimizer = torch.optim.Adam(generator.parameters(), lr=4e-6, weight_decay=1e-5) 37 | checkpoint1 = torch.load("../../save/synteg_dependency_model", map_location=torch.device(device)) 38 | checkpoint2 = torch.load("../../save/synteg_condition_model", map_location=torch.device(device)) 39 | model.load_state_dict(checkpoint1['model']) 40 | optimizer.load_state_dict(checkpoint1['optimizer']) 41 | generator.load_state_dict(checkpoint2['generator']) 42 | generator_optimizer.load_state_dict(checkpoint2['generator_optimizer']) 43 | 44 | # Add the labels to the index_to_code mapping 45 | for k, l in id_to_label.items(): 46 | index_to_code[config.code_vocab_dim+k] = f"Chronic Condition: {l}" 47 | 48 | def sample_sequence(length, context, batch_size, device='cuda'): 49 | context = torch.tensor(context, device=device, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1).to(device) 50 | ehr = context.unsqueeze(1).to(device) 51 | batch_ehr = torch.tensor(np.ones((batch_size, config.max_num_visit, config.max_length_visit)) * config.vocab_dim, dtype=torch.long).to(device) 52 | batch_ehr[:,0,0] = config.code_vocab_dim + config.label_vocab_dim 53 | batch_lens = torch.zeros((batch_size, config.max_num_visit, 1), dtype=torch.int).to(device) 54 | batch_lens[:,0,0] = 1 55 | with torch.no_grad(): 56 | for j in range(length-1): 57 | for i in range(batch_size): 58 | codes = torch.nonzero(ehr[i,j]).squeeze(1) 59 | batch_ehr[i,j,0:len(codes)] = codes[0:config.max_length_visit] 60 | batch_lens[i,j] = min(len(codes), config.max_length_visit) 61 | condition_vector = model(batch_ehr, batch_lens, export=True) 62 | condition = condition_vector[:,j,:] 63 | z = torch.randn((batch_size, config.z_dim)).to(device) 64 | visit = generator(z, condition) 65 | visit = torch.bernoulli(visit).unsqueeze(1) 66 | ehr = torch.cat((ehr, visit), dim=1) 67 | ehr = ehr.cpu().detach().numpy() 68 | return ehr 69 | 70 | def convert_ehr(ehrs, index_to_code=None): 71 | ehr_outputs = [] 72 | for i in range(len(ehrs)): 73 | ehr = ehrs[i] 74 | ehr_output = [] 75 | labels_output = ehr[1][config.code_vocab_dim:config.code_vocab_dim+config.label_vocab_dim] 76 | if index_to_code is not None: 77 | labels_output = [index_to_code[idx + config.code_vocab_dim] for idx in np.nonzero(labels_output)[0]] 78 | for j in range(2, len(ehr)): 79 | visit = ehr[j] 80 | visit_output = [] 81 | indices = np.nonzero(visit)[0] 82 | end = False 83 | for idx in indices: 84 | if idx < config.code_vocab_dim: 85 | visit_output.append(index_to_code[idx] if index_to_code is not None else idx) 86 | elif idx == config.code_vocab_dim+config.label_vocab_dim+1: 87 | end = True 88 | if visit_output != []: 89 | ehr_output.append(visit_output) 90 | if end: 91 | break 92 | ehr_outputs.append({'visits': ehr_output, 'labels': 
labels_output}) 93 | ehr = None 94 | ehr_output = None 95 | labels_output = None 96 | visit = None 97 | visit_output = None 98 | indices = None 99 | return ehr_outputs 100 | 101 | # Generate a few sampled EHR for examinations 102 | #stoken = np.zeros(config.vocab_dim) 103 | #synthetic_ehrs = sample_sequence(config.max_num_visit, stoken, batch_size=3, device=device) 104 | #synthetic_ehrs = convert_ehr(synthetic_ehrs, index_to_code) 105 | #print("Sampled Synthetic EHRs: ") 106 | #for i in range(3): 107 | # print("Labels: ") 108 | # print(synthetic_ehrs[i]['labels']) 109 | # print("Visits: ") 110 | # for v in synthetic_ehrs[i]['visits']: 111 | # print(v) 112 | # print("\n\n") 113 | 114 | # Generate Synthetic EHR dataset 115 | synthetic_ehr_dataset = [] 116 | count = 0 117 | stoken = np.zeros(config.vocab_dim) 118 | for i in tqdm(range(0, len(train_ehr_dataset), config.dependency_batchsize)): 119 | bs = min([len(train_ehr_dataset)-i, config.dependency_batchsize]) 120 | batch_synthetic_ehrs = sample_sequence(config.max_num_visit, stoken, batch_size=bs, device=device) 121 | batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs) 122 | synthetic_ehr_dataset += batch_synthetic_ehrs 123 | if len(synthetic_ehr_dataset) > 10000: 124 | pickle.dump(synthetic_ehr_dataset, open(f'../../temp_synteg/syntegDataset_{count}.pkl', 'wb')) 125 | synthetic_ehr_dataset = [] 126 | count += 1 127 | 128 | pickle.dump(synthetic_ehr_dataset, open(f'../../temp_synteg/syntegDataset_{count}.pkl', 'wb')) -------------------------------------------------------------------------------- /continuous_variables/discretize.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | trainData = pickle.load(open('data/trainDataset.pkl', 'rb')) 4 | valData = pickle.load(open('data/valDataset.pkl', 'rb')) 5 | 6 | idToLab = pickle.load(open('data/idx_to_lab.pkl', 'rb')) 7 | labToNumber = {l: i for (i,l) in enumerate(pickle.load(open('data/id_to_channel.pkl', 'rb')))} 8 | isCategorical = pickle.load(open('data/is_categorical_channel.pkl', 'rb')) 9 | beginPos = pickle.load(open('data/begin_pos.pkl', 'rb')) 10 | possibleValues = pickle.load(open('data/possible_values.pkl', 'rb')) 11 | variableRanges = pickle.load(open('data/variable_ranges.pkl', 'rb')) 12 | 13 | discretization = { 14 | 'Diastolic blood pressure': [0, 40, 50, 60, 65, 70, 75, 80, 85, 90, 95, 100, 105, 110, 120, 130, 375], 15 | 'Fraction inspired oxygen': [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.001, 1.1], 16 | 'Glucose': [0, 40, 60, 80, 100, 110, 120, 130, 140, 150, 160, 170, 180, 200, 225, 275, 325, 400, 600, 800, 1000, 2200], 17 | 'Heart Rate': [0, 40, 50, 60, 70, 80, 90, 100, 110, 120, 140, 160, 180, 200, 390], 18 | 'Height': [0, 145, 150, 155, 160, 165, 170, 175, 180, 185, 190, 195, 230], 19 | 'Mean blood pressure': [0, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 180, 200, 375], 20 | 'Oxygen saturation': [0, 30, 40, 50, 55, 60, 65, 70, 75, 80, 85, 90, 100, 100.001, 150], 21 | 'pH': [6.3, 6.7, 7.1, 7.35, 7.45, 7.6, 8.0, 8.3, 10], 22 | 'Respiratory rate': [0, 6, 8, 10, 12, 14, 16, 18, 20, 25, 30, 35, 330], 23 | 'Systolic blood pressure': [0, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200, 210, 230, 375], 24 | 'Temperature': [14.2, 30, 32, 33, 33.5, 34, 34.5, 35, 35.5, 36, 36.5, 37, 37.5, 38, 38.5, 39, 39.5, 40, 47], 25 | 'Weight': [0, 30, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100, 105, 110, 115, 120, 125, 130, 135, 140, 145, 150, 160, 170, 190, 210, 250], 26 
| 'Age': [18, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90], 27 | 'Days': [0, 11, 16, 21, 25, 30.1, 35.1, 43, 48, 54, 60, 66, 72, 81, 90, 100.1], 28 | 'Hours': [0, 0.5, 1.5, 2.5, 3.5, 6.5, 10.5, 16.5, 26.5, 48.0, 48.1, 60.1, 80.1, 110.1, 150.1, 200.1] 29 | } 30 | 31 | formatMap = { 32 | 'Diastolic blood pressure': ('.0f', int), 33 | 'Fraction inspired oxygen': ('.2f', float), 34 | 'Glucose': ('.0f', int), 35 | 'Heart Rate': ('.0f', int), 36 | 'Height': ('.0f', int), 37 | 'Mean blood pressure': ('.0f', int), 38 | 'Oxygen saturation': ('.0f', int), 39 | 'pH': ('.2f', float), 40 | 'Respiratory rate': ('.0f', int), 41 | 'Systolic blood pressure': ('.0f', int), 42 | 'Temperature': ('.1f', float), 43 | 'Weight': ('.1f', float), 44 | 'Age': ('.2f', float), 45 | 'Days': ('.2f', float), 46 | 'Hours': ('.1f', float) 47 | } 48 | 49 | def get_index(mapping, key, value): 50 | possible_values = mapping[key] 51 | for i in range(len(possible_values) - 1): 52 | if value < possible_values[i + 1]: 53 | return i 54 | 55 | print(f"{value} for {key} not in {possible_values}") 56 | return len(possible_values) - 2 57 | 58 | # Convert to New Data Format 59 | for p in (trainData + valData): 60 | new_visits = [] 61 | firstVisit = True 62 | for v in p['visits']: 63 | if v[1] == []: 64 | new_cont = get_index(discretization, 'Age' if firstVisit else 'Days', v[3][-1]) 65 | firstVisit = False 66 | new_visits.append((v[0], [], [], [new_cont])) 67 | else: 68 | new_labs = [] 69 | new_values = [] 70 | for l, val in zip(v[1], v[2]): 71 | if isCategorical[idToLab[l]]: 72 | if val == 1: 73 | new_labs.append(labToNumber[idToLab[l]]) 74 | new_values.append(beginPos[labToNumber[idToLab[l]]] - l) 75 | else: 76 | if val < variableRanges[idToLab[l]][0] or val >= variableRanges[idToLab[l]][1]: 77 | continue 78 | 79 | new_labs.append(labToNumber[idToLab[l]]) 80 | new_values.append(get_index(discretization, idToLab[l], val)) 81 | 82 | if not new_labs: 83 | continue 84 | new_cont = get_index(discretization, 'Hours', v[3][-1]) 85 | new_visits.append((v[0], new_labs, new_values, [new_cont])) 86 | 87 | p['visits'] = new_visits 88 | 89 | pickle.dump(trainData, open('discretized_data/trainDataset.pkl', 'wb')) 90 | pickle.dump(valData, open('discretized_data/valDataset.pkl', 'wb')) 91 | 92 | newIdToLab = {i:l for (l,i) in labToNumber.items()} 93 | newBeginPos = [] 94 | seenContinuous = False 95 | for i in range(len(newIdToLab)): 96 | if not seenContinuous: 97 | newBeginPos.append(beginPos[i]) 98 | if not isCategorical[newIdToLab[i]]: 99 | seenContinuous = True 100 | currPos = newBeginPos[i] + len(discretization[newIdToLab[i]]) - 1 101 | else: 102 | newBeginPos.append(currPos) 103 | currPos += len(discretization[newIdToLab[i]]) - 1 104 | 105 | newIdxToId = {} 106 | for i in range(len(newBeginPos) - 1): 107 | for j in range(newBeginPos[i], newBeginPos[i+1]): 108 | newIdxToId[j] = i 109 | for j in range(newBeginPos[-1], newBeginPos[-1] + len(discretization[newIdToLab[len(newBeginPos) - 1]]) - 1): 110 | newIdxToId[j] = len(newBeginPos) - 1 111 | 112 | pickle.dump(newIdxToId, open('discretized_data/idxToId.pkl', 'wb')) 113 | pickle.dump(formatMap, open('discretized_data/formatMap.pkl', 'wb')) 114 | pickle.dump(newIdToLab, open('discretized_data/idToLab.pkl', 'wb')) 115 | pickle.dump(newBeginPos, open('discretized_data/beginPos.pkl', 'wb')) 116 | pickle.dump(isCategorical, open('discretized_data/isCategorical.pkl', 'wb')) 117 | pickle.dump(possibleValues, open('discretized_data/possibleValues.pkl', 'wb')) 118 | 
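# Editorial note (not part of the original script): get_index above maps a raw value to
# its discretization bucket by returning the first i with value < boundaries[i + 1], so
# B boundary points define B - 1 buckets. For example, with the 'Heart Rate' boundaries
# [0, 40, 50, 60, 70, 80, 90, 100, ...], a reading of 93 falls in bucket 6 (the [90, 100)
# bin); values outside the listed range trigger the warning print and fall back to the
# last bucket.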
pickle.dump(discretization, open('discretized_data/discretization.pkl', 'wb')) 119 | 120 | print(f"NUM LABS: {newBeginPos[-1] + len(discretization[newIdToLab[16]]) - 1}") 121 | print(f"NUM CONTINUOUS: {len(discretization['Age']) - 1}") 122 | -------------------------------------------------------------------------------- /baselines/eva/test_eva.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from sklearn import metrics 7 | from eva import Eva 8 | from config import EVAConfig 9 | 10 | SEED = 4 11 | random.seed(SEED) 12 | np.random.seed(SEED) 13 | torch.manual_seed(SEED) 14 | config = EVAConfig() 15 | 16 | local_rank = -1 17 | fp16 = False 18 | if local_rank == -1: 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | n_gpu = torch.cuda.device_count() 21 | else: 22 | torch.cuda.set_device(local_rank) 23 | device = torch.device("cuda", local_rank) 24 | n_gpu = 1 25 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 26 | torch.distributed.init_process_group(backend='nccl') 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(SEED) 29 | 30 | train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 31 | test_ehr_dataset = pickle.load(open('../../data/testDataset.pkl', 'rb')) 32 | index_to_code = pickle.load(open("../../data/indexToCode.pkl", "rb")) 33 | id_to_label = pickle.load(open("../../data/idToLabel.pkl", "rb")) 34 | train_c = set([c for p in train_ehr_dataset for v in p['visits'] for c in v]) 35 | test_ehr_dataset = [{'labels': p['labels'], 'visits': [[c for c in v if c in train_c] for v in p['visits']]} for p in test_ehr_dataset] 36 | 37 | # Add the labels to the index_to_code mapping 38 | for k, l in id_to_label.items(): 39 | index_to_code[config.code_vocab_size+k] = f"Chronic Condition: {l}" 40 | 41 | # Add the labels to the index_to_code mapping 42 | index_to_code[config.code_vocab_size] = "Chronic Condition: Alzheimer or related disorders or senile" 43 | index_to_code[config.code_vocab_size+1] = "Chronic Condition: Heart Failure" 44 | index_to_code[config.code_vocab_size+2] = "Chronic Condition: Chronic Kidney Disease" 45 | index_to_code[config.code_vocab_size+3] = "Chronic Condition: Cancer" 46 | index_to_code[config.code_vocab_size+4] = "Chronic Condition: Chronic Obstructive Pulmonary Disease" 47 | index_to_code[config.code_vocab_size+5] = "Chronic Condition: Depression" 48 | index_to_code[config.code_vocab_size+6] = "Chronic Condition: Diabetes" 49 | index_to_code[config.code_vocab_size+7] = "Chronic Condition: Ischemic Heart Disease" 50 | index_to_code[config.code_vocab_size+8] = "Chronic Condition: Osteoporosis" 51 | index_to_code[config.code_vocab_size+9] = "Chronic Condition: rheumatoid arthritis and osteoarthritis (RA/OA)" 52 | index_to_code[config.code_vocab_size+10] = "Chronic Condition: Stroke/transient Ischemic Attack" 53 | 54 | def get_batch(loc, batch_size, mode): 55 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 56 | # Where each patient P is [V_1, V_2, ... , V_j] 57 | # Where each visit V is [C_1, C_2, ... , C_k] 58 | # And where each Label L is a binary vector [L_1 ... 
L_n] 59 | if mode == 'train': 60 | ehr = train_ehr_dataset[loc:loc+batch_size] 61 | elif mode == 'valid': 62 | ehr = val_ehr_dataset[loc:loc+batch_size] 63 | else: 64 | ehr = test_ehr_dataset[loc:loc+batch_size] 65 | 66 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 67 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 68 | batch_lens = np.zeros(len(ehr)) 69 | for i, p in enumerate(ehr): 70 | visits = p['visits'] 71 | batch_lens[i] = len(visits) 72 | for j, v in enumerate(visits): 73 | batch_ehr[i,j+2][v] = 1 74 | batch_mask[i,j+2] = 1 75 | batch_ehr[i,1,config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 76 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 77 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 78 | 79 | batch_mask[:,1] = 1 # Set the mask to cover the labels 80 | batch_ehr[:,0,config.code_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 81 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 82 | return batch_ehr, batch_lens, batch_mask 83 | 84 | model = Eva(config).to(device) 85 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 86 | 87 | checkpoint = torch.load('../../save/eva_model', map_location=torch.device(device)) 88 | model.load_state_dict(checkpoint['model']) 89 | optimizer.load_state_dict(checkpoint['optimizer']) 90 | 91 | def convert_ehr(ehrs, index_to_code=None): 92 | ehr_outputs = [] 93 | for i in range(len(ehrs)): 94 | ehr = ehrs[i] 95 | ehr_output = [] 96 | labels_output = ehr[0][config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] 97 | if index_to_code is not None: 98 | labels_output = [index_to_code[idx + config.code_vocab_size] for idx in np.nonzero(labels_output)[0]] 99 | for j in range(1, len(ehr)): 100 | visit = ehr[j] 101 | visit_output = [] 102 | indices = np.nonzero(visit) 103 | if len(indices) > 0: 104 | indices = indices[0] 105 | else: 106 | continue 107 | end = False 108 | for idx in indices: 109 | if idx < config.code_vocab_size: 110 | visit_output.append(index_to_code[idx] if index_to_code is not None else idx) 111 | elif idx == config.code_vocab_size+config.label_vocab_size+1: 112 | end = True 113 | if visit_output != []: 114 | ehr_output.append(visit_output) 115 | if end: 116 | break 117 | ehr_outputs.append({'visits': ehr_output, 'labels': labels_output}) 118 | ehr = None 119 | ehr_output = None 120 | labels_output = None 121 | visit = None 122 | visit_output = None 123 | indices = None 124 | return ehr_outputs 125 | 126 | # Generate Synthetic EHR dataset 127 | synthetic_ehr_dataset = [] 128 | for i in tqdm(range(0, len(train_ehr_dataset), config.batch_size)): 129 | bs = min([len(train_ehr_dataset)-i, config.batch_size]) 130 | batch_synthetic_ehrs = model.sample(bs, device) 131 | batch_synthetic_ehrs = torch.bernoulli(batch_synthetic_ehrs) 132 | batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs.detach().cpu().numpy()) 133 | synthetic_ehr_dataset += batch_synthetic_ehrs 134 | 135 | pickle.dump(synthetic_ehr_dataset, open(f'../../results/datasets/evaDataset.pkl', 'wb')) -------------------------------------------------------------------------------- /baselines/synteg/synteg.py: -------------------------------------------------------------------------------- 1 | ''' 2 
| code by Brandon Theodorou 3 | Original SynTEG Paper Here: https://academic.oup.com/jamia/article/28/3/596/6024632 4 | SynTEG Pytorch Model Derived From: https://github.com/allhailjustice/SynTEG 5 | ''' 6 | import copy 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | 11 | ####### 12 | ### Dependency Learning Model 13 | ####### 14 | 15 | class Embedding(nn.Module): 16 | def __init__(self, config): 17 | """Construct an embedding matrix to embed sparse codes""" 18 | super(Embedding, self).__init__() 19 | self.code_embed = nn.Embedding(config.vocab_dim+1, config.embedding_dim) 20 | 21 | def forward(self, codes): # batch_size * visits * codes 22 | code_embeds = self.code_embed(codes) 23 | return code_embeds 24 | 25 | class SingleVisitTransformer(nn.Module): 26 | """An Encoder Transformer to turn code embeddings into a visit embedding""" 27 | def __init__(self, config): 28 | super(SingleVisitTransformer, self).__init__() 29 | encoderLayer = nn.TransformerEncoderLayer(config.embedding_dim, config.num_head, 30 | dim_feedforward=config.ff_dim, dropout=0.1, activation="relu", 31 | layer_norm_eps=1e-08, batch_first=True) 32 | self.transformer = nn.TransformerEncoder(encoderLayer, 2) 33 | 34 | def forward(self, code_embeddings, visit_lengths): 35 | bs, vs, cs, ed = code_embeddings.shape 36 | mask = torch.ones((bs, vs, cs)).to(code_embeddings.device) 37 | for i in range(bs): 38 | for j in range(vs): 39 | mask[i,j,:visit_lengths[i,j]] = 0 40 | visits = torch.reshape(code_embeddings, (bs*vs,cs,ed)) 41 | mask = torch.reshape(mask, (bs*vs,cs)) 42 | encodings = self.transformer(visits, src_key_padding_mask=mask) 43 | encodings = torch.reshape(encodings, (bs,vs,cs,ed)) 44 | visit_representations = encodings[:,:,0,:] 45 | return visit_representations 46 | 47 | class RecurrentLayer(nn.Module): 48 | """An Recurrent Layer to predict the next visit based on the visit embeddings""" 49 | def __init__(self, config): 50 | super(RecurrentLayer, self).__init__() 51 | self.lstm = nn.LSTM(input_size=config.lstm_dim, hidden_size=config.lstm_dim, num_layers=config.n_layer, dropout=0.1) 52 | 53 | def forward(self, visit_embeddings): 54 | output, _ = self.lstm(visit_embeddings) 55 | return output 56 | 57 | class DependencyModel(nn.Module): 58 | """The entire Dependency Model component of SynTEG""" 59 | def __init__(self, config): 60 | super(DependencyModel, self).__init__() 61 | self.embeddings = Embedding(config) 62 | self.visit_att = SingleVisitTransformer(config) 63 | self.proj1 = nn.Linear(config.embedding_dim, config.lstm_dim) 64 | self.lstm = RecurrentLayer(config) 65 | self.proj2 = nn.Linear(config.lstm_dim, config.condition_dim) 66 | self.proj3 = nn.Linear(config.condition_dim, config.vocab_dim) 67 | 68 | def forward(self, inputs_word, visit_lengths, export=False): # bs * visits * codes, bs * visits * 1 69 | inputs = self.embeddings(inputs_word) # bs * visits * codes * embedding_dim 70 | inputs = self.visit_att(inputs, visit_lengths) # bs * visits * embedding_dim 71 | inputs = self.proj1(inputs) # bs * visits * lstm_dim 72 | output = self.lstm(inputs) # bs * visits * lstm_dim 73 | if export: 74 | return self.proj2(output) # bs * visit * condition 75 | else: 76 | output = self.proj3(torch.relu(self.proj2(output))) # bs * visits * vocab_dim 77 | sig = nn.Sigmoid() 78 | diagnosis_output = sig(output[:, :-1, :]) 79 | return diagnosis_output 80 | 81 | ####### 82 | ### Conditional GAN Model 83 | ####### 84 | 85 | class PointWiseLayer(nn.Module): 86 | def __init__(self, num_outputs): 87 | """Construct an 
embedding matrix to embed sparse codes""" 88 | super(PointWiseLayer, self).__init__() 89 | self.bias = nn.Parameter(torch.zeros(num_outputs).uniform_(-math.sqrt(num_outputs), math.sqrt(num_outputs))) 90 | 91 | def forward(self, x1, x2): 92 | return x1 * x2 + self.bias 93 | 94 | class Generator(nn.Module): 95 | def __init__(self, config): 96 | super(Generator, self).__init__() 97 | self.dense_layers = nn.Sequential(*[nn.Linear(config.g_dims[i-1] if i > 0 else config.z_dim, config.g_dims[i]) for i in range(len(config.g_dims[:-1]))]) 98 | self.batch_norm_layers = nn.Sequential(*[nn.BatchNorm1d(dim, eps=1e-5) for dim in config.g_dims[:-1]]) 99 | self.output_layer = nn.Linear(config.g_dims[-2], config.g_dims[-1]) 100 | self.output_sigmoid = nn.Sigmoid() 101 | self.condition_layers = nn.Sequential(*[nn.Linear(config.condition_dim, dim) for dim in config.g_dims[:-1]]) 102 | self.pointwiselayers = nn.Sequential(*[PointWiseLayer(dim) for dim in config.g_dims[:-1]]) 103 | 104 | def forward(self, x, condition): 105 | for i in range(len(self.dense_layers)): 106 | h = self.dense_layers[i](x) 107 | x = nn.functional.relu(self.pointwiselayers[i](self.batch_norm_layers[i](h), self.condition_layers[i](condition))) 108 | x = self.output_layer(x) 109 | x = self.output_sigmoid(x) 110 | return x 111 | 112 | class Discriminator(nn.Module): 113 | def __init__(self, config): 114 | super(Discriminator, self).__init__() 115 | self.dense_layers = nn.Sequential(*[nn.Linear(config.d_dims[i-1] if i > 0 else config.g_dims[-1] + 1, config.d_dims[i]) for i in range(len(config.d_dims))]) 116 | self.layer_norm_layers = nn.Sequential(*[nn.LayerNorm(dim, eps=1e-5) for dim in config.d_dims]) 117 | self.output_layer = nn.Linear(config.d_dims[-1], 1) 118 | self.condition_layers = nn.Sequential(*[nn.Linear(config.condition_dim, dim) for dim in config.d_dims]) 119 | self.pointwiselayers = nn.Sequential(*[PointWiseLayer(dim) for dim in config.d_dims]) 120 | 121 | def forward(self, x, condition): 122 | a = (2 * x) ** 15 123 | sparsity = torch.sum(a / (a + 1), axis=-1, keepdim=True) 124 | x = torch.cat((x, sparsity), axis=-1) 125 | for i in range(len(self.dense_layers)): 126 | h = self.dense_layers[i](x) 127 | x = self.pointwiselayers[i](self.layer_norm_layers[i](h), self.condition_layers[i](condition)) 128 | x = self.output_layer(x) 129 | return x -------------------------------------------------------------------------------- /baselines/synteg/dependency_learning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import pickle 5 | import numpy as np 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | from config import SyntegConfig 9 | from synteg import DependencyModel 10 | 11 | SEED = 4 12 | random.seed(SEED) 13 | np.random.seed(SEED) 14 | torch.manual_seed(SEED) 15 | config = SyntegConfig() 16 | 17 | local_rank = -1 18 | fp16 = False 19 | if local_rank == -1: 20 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | n_gpu = torch.cuda.device_count() 22 | else: 23 | torch.cuda.set_device(local_rank) 24 | device = torch.device("cuda", local_rank) 25 | n_gpu = 1 26 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 27 | torch.distributed.init_process_group(backend='nccl') 28 | if torch.cuda.is_available(): 29 | torch.cuda.manual_seed_all(SEED) 30 | 31 | train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 32 | val_ehr_dataset = 
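# Note on the overall SynTEG pipeline (a reader-facing summary added here, not part of the original script):
# training happens in two stages. This script fits the DependencyModel defined in synteg.py with a
# next-visit BCE objective; once trained, calling that model with export=True returns the
# config.condition_dim condition vectors, which the conditional Generator and Discriminator in
# synteg.py then consume through their `condition` argument during the second (GAN) stage.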
pickle.load(open('../../data/valDataset.pkl', 'rb')) 33 | 34 | def get_batch(loc, batch_size, mode): 35 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 36 | # Where each patient P is [V_1, V_2, ... , V_j] 37 | # Where each visit V is [C_1, C_2, ... , C_k] 38 | # And where each Label L is a binary vector [L_1 ... L_11] 39 | if mode == 'train': 40 | ehr = train_ehr_dataset[loc:loc+batch_size] 41 | elif mode == 'valid': 42 | ehr = val_ehr_dataset[loc:loc+batch_size] 43 | else: 44 | ehr = test_ehr_dataset[loc:loc+batch_size] 45 | 46 | batch_ehr = np.zeros((len(ehr), config.max_num_visit, config.max_length_visit)) 47 | batch_ehr[:,:,:] = config.vocab_dim # Initialize each code to the padding code 48 | batch_lens = np.ones((len(ehr), config.max_num_visit, 1)) 49 | batch_mask = np.zeros((len(ehr), config.max_num_visit, 1)) 50 | batch_num_visits = np.zeros(len(ehr)) 51 | for i, p in enumerate(ehr): 52 | visits = p['visits'] 53 | for j, v in enumerate(visits): 54 | batch_mask[i,j+2] = 1 55 | batch_lens[i,j+2] = len(v) + 1 56 | for k, c in enumerate(v): 57 | batch_ehr[i,j+2,k+1] = c 58 | batch_ehr[i,j+2,len(v)+1] = config.code_vocab_dim + config.label_vocab_dim + 1 # Set the last code in the last visit to be the end record code 59 | batch_lens[i,j+2] = len(v) + 2 60 | for l_idx, l in enumerate(np.nonzero(p['labels'])[0]): 61 | batch_ehr[i,1,l_idx+1] = config.code_vocab_dim + l 62 | batch_lens[i,1] = l_idx+2 63 | batch_num_visits[i] = len(visits) 64 | 65 | batch_mask[:,1] = 1 # Set the mask to cover the labels 66 | batch_ehr[:,:,0] = config.code_vocab_dim + config.label_vocab_dim # Set the first code in each visit to be the start/class token 67 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 68 | return batch_ehr, batch_lens, batch_mask, batch_num_visits 69 | 70 | def shuffle_training_data(train_ehr_dataset): 71 | np.random.shuffle(train_ehr_dataset) 72 | 73 | EPOCHS = 50 74 | LR = 1e-4 75 | model = DependencyModel(config).to(device) 76 | optimizer = torch.optim.Adam(model.parameters(), lr=LR) 77 | if os.path.exists("../../save/synteg_dependency_model"): 78 | print("Loading previous model") 79 | checkpoint = torch.load("../../save/synteg_dependency_model", map_location=torch.device(device)) 80 | model.load_state_dict(checkpoint['model']) 81 | optimizer.load_state_dict(checkpoint['optimizer']) 82 | 83 | # Train 84 | global_loss = 1e10 85 | for e in tqdm(range(EPOCHS)): 86 | shuffle_training_data(train_ehr_dataset) 87 | for i in range(0, len(train_ehr_dataset), config.dependency_batchsize): 88 | model.train() 89 | 90 | batch_ehr, batch_lens, batch_mask, _ = get_batch(i, config.dependency_batchsize, 'train') 91 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.int).to(device) 92 | batch_lens = torch.tensor(batch_lens, dtype=torch.int).to(device) 93 | batch_mask = torch.tensor(batch_mask, dtype=torch.int).to(device) 94 | 95 | optimizer.zero_grad() 96 | diagnosis_output = model(batch_ehr, batch_lens) 97 | labels = torch.sum(nn.functional.one_hot(batch_ehr.long(), num_classes=config.vocab_dim+1), dim=2).float() 98 | labels = labels[..., 1:, :-1].contiguous() 99 | diagnosis_output = diagnosis_output * batch_mask 100 | labels = labels * batch_mask 101 | bce = nn.BCELoss() 102 | loss = bce(diagnosis_output, labels) 103 | loss.backward() 104 | optimizer.step() 105 | 106 | if i % (500*config.dependency_batchsize) == 0: 107 | print("Epoch %d, Iter %d: Training Loss:%.6f"%(e, i, loss)) 108 | if i % 
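# How the BCE targets above are formed (explanatory note with hypothetical toy values):
# batch_ehr holds padded lists of code indices per visit, so
#   nn.functional.one_hot(batch_ehr.long(), num_classes=config.vocab_dim + 1).sum(dim=2)
# collapses each visit into a multi-hot vector over vocab_dim + 1 classes; e.g. a padded visit
# [5, 3, PAD, PAD] contributes ones at positions 3 and 5 (the PAD class also accumulates a count,
# but the trailing [..., :-1] slice removes that last class entirely).
# The slice labels[..., 1:, :-1] then drops the padding class and the first time step, so the
# target at position t is "which codes appear in visit t+1", matching diagnosis_output, which the
# DependencyModel already truncates to its first T-1 positions.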
(500*config.dependency_batchsize) == 0: 109 | if i == 0: 110 | continue 111 | 112 | model.eval() 113 | with torch.no_grad(): 114 | val_l = [] 115 | for v_i in range(0, len(val_ehr_dataset), config.dependency_batchsize): 116 | batch_ehr, batch_lens, batch_mask, _ = get_batch(v_i, config.dependency_batchsize, 'valid') 117 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.int).to(device) 118 | batch_lens = torch.tensor(batch_lens, dtype=torch.int).to(device) 119 | batch_mask = torch.tensor(batch_mask, dtype=torch.int).to(device) 120 | 121 | diagnosis_output = model(batch_ehr, batch_lens) 122 | labels = torch.sum(nn.functional.one_hot(batch_ehr.long(), num_classes=config.vocab_dim+1), dim=2).float() 123 | labels = labels[..., 1:, :-1].contiguous() 124 | diagnosis_output = diagnosis_output * batch_mask 125 | labels = labels * batch_mask 126 | bce = nn.BCELoss() 127 | val_loss = bce(diagnosis_output, labels) 128 | val_l.append(val_loss.item()) 129 | 130 | cur_val_loss = np.mean(val_l) 131 | print("Epoch %d Validation Loss:%.7f"%(e, cur_val_loss)) 132 | if cur_val_loss < global_loss: 133 | global_loss = cur_val_loss 134 | state = { 135 | 'model': model.state_dict(), 136 | 'optimizer': optimizer.state_dict(), 137 | 'iteration': i 138 | } 139 | torch.save(state, '../../save/synteg_dependency_model') 140 | print('\n------------ Save best model ------------\n') -------------------------------------------------------------------------------- /evaluate_privacy_membership.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from sklearn import metrics 7 | from model import HALOModel 8 | from config import HALOConfig 9 | 10 | SEED = 4 11 | random.seed(SEED) 12 | np.random.seed(SEED) 13 | torch.manual_seed(SEED) 14 | LR = 0.00001 15 | EPOCHS = 50 16 | BATCH_SIZE = 512 17 | LSTM_HIDDEN_DIM = 32 18 | EMBEDDING_DIM = 64 19 | NUM_TEST_EXAMPLES = 7500 20 | NUM_TOT_EXAMPLES = 7500 21 | NUM_VAL_EXAMPLES = 2500 22 | 23 | local_rank = -1 24 | fp16 = False 25 | if local_rank == -1: 26 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | n_gpu = torch.cuda.device_count() 28 | else: 29 | torch.cuda.set_device(local_rank) 30 | device = torch.device("cuda", local_rank) 31 | n_gpu = 1 32 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 33 | torch.distributed.init_process_group(backend='nccl') 34 | if torch.cuda.is_available(): 35 | torch.cuda.manual_seed_all(SEED) 36 | 37 | config = HALOConfig() 38 | train_ehr_dataset = pickle.load(open('./data/trainDataset.pkl', 'rb')) 39 | train_ehr_dataset = [(p,1) for p in train_ehr_dataset] 40 | test_ehr_dataset = pickle.load(open('./data/testDataset.pkl', 'rb')) 41 | test_ehr_dataset = [(p,0) for p in test_ehr_dataset] 42 | synthetic_ehr_dataset = pickle.load(open('./results/datasets/haloDataset.pkl', 'rb')) 43 | synthetic_ehr_dataset = [p for p in synthetic_ehr_dataset if len(p['visits']) > 0] 44 | 45 | attack_dataset_pos = list(random.sample(train_ehr_dataset, NUM_TOT_EXAMPLES)) 46 | attack_dataset_neg = list(random.sample(test_ehr_dataset, NUM_TOT_EXAMPLES)) 47 | np.random.shuffle(attack_dataset_pos) 48 | np.random.shuffle(attack_dataset_neg) 49 | test_attack_dataset = attack_dataset_pos[:NUM_TEST_EXAMPLES] + attack_dataset_neg[:NUM_TEST_EXAMPLES] 50 | val_attack_dataset = attack_dataset_pos[NUM_TEST_EXAMPLES:NUM_TEST_EXAMPLES+NUM_VAL_EXAMPLES] + 
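# Set-up of the membership-inference evaluation (summary comment added for readability):
# NUM_TOT_EXAMPLES real training records are labelled 1 ("member") and an equal number of held-out
# test records are labelled 0 ("non-member"). The attacker never sees these labels; it only sees
# the synthetic dataset (the Hamming-distance attack below) or the trained HALO model (the
# log-likelihood attack further down) and must decide which records were used for training. In
# both attacks the decision rule is a simple median threshold over the attack scores, and
# accuracy, precision, recall, and F1 are then reported against the true membership labels.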
attack_dataset_neg[NUM_TEST_EXAMPLES:NUM_TEST_EXAMPLES+NUM_VAL_EXAMPLES] 51 | np.random.shuffle(test_attack_dataset) 52 | np.random.shuffle(val_attack_dataset) 53 | attack_dataset_pos = attack_dataset_pos[NUM_TEST_EXAMPLES+NUM_VAL_EXAMPLES:] 54 | attack_dataset_neg = attack_dataset_neg[NUM_TEST_EXAMPLES+NUM_VAL_EXAMPLES:] 55 | 56 | def get_batch(loc, batch_size, dataset): 57 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 58 | # Where each patient P is [V_1, V_2, ... , V_j] 59 | # Where each visit V is [C_1, C_2, ... , C_k] 60 | # And where each Label L is a binary vector [L_1 ... L_11] 61 | ehr = dataset[loc:loc+batch_size] 62 | attack_labels = [l for (e,l) in ehr] 63 | ehr = [e for (e,l) in ehr] 64 | 65 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 66 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 67 | 68 | for i, p in enumerate(ehr): 69 | visits = p['visits'] 70 | for j, v in enumerate(visits): 71 | batch_ehr[i,j+2][v] = 1 72 | batch_mask[i,j+2] = 1 73 | batch_ehr[i,1,config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 74 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 75 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 76 | 77 | batch_mask[:,1] = 1 # Set the mask to cover the labels 78 | batch_ehr[:,0,config.code_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 79 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 80 | return batch_ehr, batch_mask, attack_labels 81 | 82 | def find_hamming(ehr, dataset): 83 | min_d = 1e10 84 | visits = ehr['visits'] 85 | labels = ehr['labels'] 86 | for p in dataset: 87 | d = 0 if len(visits) == len(p['visits']) else 1 88 | l = p['labels'] 89 | d += ((labels + l) == 1).sum() 90 | for i in range(len(visits)): 91 | v = visits[i] 92 | if i >= len(p['visits']): 93 | d += len(v) 94 | else: 95 | v2 = p['visits'][i] 96 | d += len(v) + len(v2) - (2 * len(set(v) & set(v2))) 97 | 98 | min_d = d if d < min_d else min_d 99 | return min_d 100 | 101 | # Perform the Hamming Distance experiment 102 | ds = [(find_hamming(ehr, synthetic_ehr_dataset), l) for (ehr, l) in tqdm(test_attack_dataset)] 103 | median_dist = np.median([d for (d,l) in ds]) 104 | preds = [1 if d < median_dist else 0 for (d,l) in ds] 105 | labels = [l for (d,l) in ds] 106 | results = { 107 | "Accuracy": metrics.accuracy_score(labels, preds), 108 | "Precision": metrics.precision_score(labels, preds), 109 | "Recall": metrics.recall_score(labels, preds), 110 | "F1": metrics.f1_score(labels, preds) 111 | } 112 | pickle.dump(results, open("results/privacy_evaluation/hamming_model.pkl", "wb")) 113 | print(results) 114 | 115 | 116 | 117 | # Perform the Log Likelihood experiment 118 | model = HALOModel(config).to(device) 119 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 120 | checkpoint = torch.load('./save/halo_model', map_location=torch.device(device)) 121 | model.load_state_dict(checkpoint['model']) 122 | optimizer.load_state_dict(checkpoint['optimizer']) 123 | model.eval() 124 | 125 | probabilities = [] 126 | with torch.no_grad(): 127 | for i in tqdm(range(0, len(test_attack_dataset), config.batch_size)): 128 | # Get batch inputs 129 | batch_ehr, batch_mask, attack_labels = get_batch(i, 2*config.batch_size, 
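# Why torch.abs(labels - 1.0 + predictions) in the loop below recovers per-code probabilities
# (explanatory note): for a binary target y and predicted probability p,
#   |y - 1 + p| = p      when y = 1   (e.g. y = 1, p = 0.9 -> 0.9)
#   |y - 1 + p| = 1 - p  when y = 0   (e.g. y = 0, p = 0.9 -> 0.1)
# so summing the logs over every code position gives the log-likelihood the model assigns to the
# whole record. Masked positions have label 0 and prediction 0, contribute |0 - 1 + 0| = 1, and
# therefore add nothing to the log-sum. Records scoring above the median log-likelihood are
# predicted to be training members.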
test_attack_dataset) 130 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 131 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 132 | 133 | # Get batch outputs 134 | test_loss, predictions, labels = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight) 135 | 136 | # Calculate and add probabilities 137 | # Note that the masked codes will have probability 1 and be ignored 138 | label_probs = torch.abs(labels - 1.0 + predictions) 139 | log_prob = torch.log(label_probs).sum(dim=[1,2]).cpu().detach().numpy().tolist() 140 | final_probs = [(p,l) for p,l in zip(log_prob, attack_labels)] 141 | probabilities.extend(final_probs) 142 | 143 | median_prob = np.median([p for (p,l) in probabilities]) 144 | preds = [1 if p > median_prob else 0 for (p,l) in probabilities] 145 | labels = [l for (p,l) in probabilities] 146 | results = { 147 | "Accuracy": metrics.accuracy_score(labels, preds), 148 | "Precision": metrics.precision_score(labels, preds), 149 | "Recall": metrics.recall_score(labels, preds), 150 | "F1": metrics.f1_score(labels, preds) 151 | } 152 | pickle.dump(results, open("results/privacy_evaluation/perplexity_model.pkl", "wb")) 153 | print(results) -------------------------------------------------------------------------------- /baselines/eva/eva.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 5 | 6 | class CausalConv1d(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size, dilation=1, **kwargs): 8 | super(CausalConv1d, self).__init__() 9 | self.pad = (kernel_size - 1) * dilation 10 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.pad, dilation=dilation, **kwargs) 11 | 12 | def forward(self, input): 13 | return self.conv(input)[:,:,:-self.conv.padding[0]] 14 | 15 | def connector(mu, log_var): 16 | std = torch.exp(0.5 * log_var) 17 | eps = torch.randn_like(std) 18 | return eps * std + mu 19 | 20 | class Encoder(nn.Module): 21 | def __init__(self, config): 22 | super(Encoder, self).__init__() 23 | self.hidden_dim = config.n_embd 24 | self.embedding_matrix = nn.Linear(config.total_vocab_size, config.n_embd, bias=False) 25 | self.lstm = nn.LSTM(input_size=config.n_embd, 26 | hidden_size=config.n_embd, 27 | num_layers=config.n_lstm_layer, 28 | bidirectional=True, 29 | batch_first=True) 30 | self.latent_encoder = nn.Linear(2*config.n_embd, 2*config.latent_dim) 31 | 32 | def forward(self, input, lengths): 33 | visit_emb = self.embedding_matrix(input) 34 | packed_input = pack_padded_sequence(visit_emb, lengths, batch_first=True, enforce_sorted=False) 35 | packed_output, _ = self.lstm(packed_input) 36 | output, _ = pad_packed_sequence(packed_output, batch_first=True) 37 | out_forward = output[range(len(output)), lengths - 1, :self.hidden_dim] 38 | out_reverse = output[:, 0, self.hidden_dim:] 39 | out_combined = torch.cat((out_forward, out_reverse), 1) 40 | mean_logvar = self.latent_encoder(out_combined) 41 | return mean_logvar 42 | 43 | class Decoder(nn.Module): 44 | def __init__(self, config): 45 | super(Decoder, self).__init__() 46 | self.deconv1 = nn.ConvTranspose1d(config.latent_dim, 64, 4, stride=2) 47 | self.deconv2 = nn.ConvTranspose1d(64, 64, 3, stride=2) 48 | self.deconv3 = nn.ConvTranspose1d(64, 64, 3, stride=2) 49 | self.deconv4 = nn.ConvTranspose1d(64, 128, 3, 
stride=3) 50 | self.causal_conv1 = CausalConv1d(128, 256, 5, dilation=2) 51 | self.causal_conv2 = CausalConv1d(256, 512, 5, dilation=2) 52 | self.causal_conv3 = CausalConv1d(512, 4096, 5, dilation=2) 53 | self.causal_conv4 = CausalConv1d(4096, config.total_vocab_size, 5, dilation=2) 54 | 55 | def forward(self, input): 56 | input = input.unsqueeze(2) 57 | out = self.deconv1(input) 58 | out = self.deconv2(out) 59 | out = self.deconv3(out) 60 | out = self.deconv4(out) 61 | out = self.causal_conv1(out) 62 | out = self.causal_conv2(out) 63 | out = self.causal_conv3(out) 64 | out = self.causal_conv4(out) 65 | out = out.transpose(1, 2) 66 | return out 67 | 68 | class Eva(nn.Module): 69 | def __init__(self, config): 70 | super(Eva, self).__init__() 71 | self.latent_dim = config.latent_dim 72 | self.encoder = Encoder(config) 73 | self.decoder = Decoder(config) 74 | 75 | def forward(self, input_visits, input_lengths, ehr_labels=None, ehr_masks=None, pos_loss_weight=None, kl_weight=1): #kl_weight 0.1 to 1 over a couple epochs 76 | mean_logvar = self.encoder(input_visits, input_lengths) 77 | mu = mean_logvar[:,:self.latent_dim] 78 | log_var = mean_logvar[:,self.latent_dim:] 79 | kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) 80 | decoder_inputs = connector(mu, log_var) 81 | code_logits = self.decoder(decoder_inputs) 82 | sig = nn.Sigmoid() 83 | code_probs = sig(code_logits) 84 | if ehr_labels is not None: 85 | shift_probs = code_probs[..., :-1, :].contiguous() 86 | shift_labels = ehr_labels[..., 1:, :].contiguous() 87 | loss_weights = None 88 | if pos_loss_weight is not None: 89 | loss_weights = torch.ones(shift_probs.shape, device=code_probs.device) 90 | loss_weights = loss_weights + (pos_loss_weight-1) * shift_labels 91 | if ehr_masks is not None: 92 | shift_probs = shift_probs * ehr_masks 93 | shift_labels = shift_labels * ehr_masks 94 | if pos_loss_weight is not None: 95 | loss_weights = loss_weights * ehr_masks 96 | 97 | bce = nn.BCELoss(weight=loss_weights) 98 | rc_loss = bce(shift_probs, shift_labels) 99 | loss = rc_loss + kl_weight * kl_loss 100 | return loss, shift_probs, shift_labels 101 | 102 | return code_probs 103 | 104 | def sample(self, batch_size, device): 105 | decoder_inputs = torch.randn((batch_size, self.latent_dim)).to(device) 106 | code_logits = self.decoder(decoder_inputs) 107 | sig = nn.Sigmoid() 108 | code_probs = sig(code_logits) 109 | return code_probs 110 | 111 | def marginal_log_likelihood(self, input_ehr, input_lens, input_mask, num_samples): 112 | bs = input_ehr.size(0) 113 | mean_logvar = self.encoder(input_ehr, input_lens) 114 | mu = mean_logvar[:,:self.latent_dim] 115 | log_var = mean_logvar[:,self.latent_dim:] 116 | 117 | rep_mu = mu.unsqueeze(1).repeat(1,num_samples,1) 118 | rep_log_var = log_var.unsqueeze(1).repeat(1,num_samples,1) 119 | rep_sigma = torch.exp(0.5 * rep_log_var) 120 | 121 | latent_samples = connector(rep_mu, rep_log_var) 122 | latent_samples = latent_samples.reshape((bs * num_samples, self.latent_dim)) 123 | rep_mu = rep_mu.reshape((bs * num_samples, self.latent_dim)) 124 | rep_sigma = rep_sigma.reshape((bs * num_samples, self.latent_dim)) 125 | 126 | log2pi = np.log(2*np.pi) 127 | logp_z = -log2pi * self.latent_dim / 2 - torch.sum(torch.square(latent_samples), axis=1) / 2 128 | logq_z_x = -log2pi * self.latent_dim / 2 - torch.sum(torch.square((latent_samples - rep_mu) / rep_sigma) + 2 * torch.log(rep_sigma), axis=1) / 2 129 | code_logits = self.decoder(latent_samples) 130 | sig = nn.Sigmoid() 131 | 
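# The remainder of this method is a standard importance-sampling estimate of the marginal
# likelihood (comment added for clarity): with z_s ~ q(z|x),
#   log p(x) ~= log (1/S) * sum_s [ p(x|z_s) p(z_s) / q(z_s|x) ]
# logp_z and logq_z_x above are the diagonal-Gaussian log-densities of the prior and the encoder
# posterior, logp_x_z below is the Bernoulli log-likelihood of the observed codes, and the final
# m + log(mean(exp(logp_x - m))) is a numerically stable log-mean-exp; assuming only the imports
# already present in this file, it could equivalently be written as
#   torch.logsumexp(logp_x, dim=1, keepdim=True) - np.log(num_samples)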
code_probs = sig(code_logits) 132 | code_probs = code_probs[..., :-1, :].contiguous() 133 | ehr_labels = input_ehr[..., 1:, :].contiguous() 134 | ehr_labels = ehr_labels.repeat((num_samples, 1, 1)) 135 | logp_x_z = torch.sum((ehr_labels * torch.log(code_probs) + (1 - ehr_labels) * torch.log(1 - code_probs)) * input_mask.repeat((num_samples, 1, 1)), axis=(1,2)) # bs*ns 136 | logp_x = logp_x_z + logp_z - logq_z_x 137 | logp_x = logp_x.reshape(bs, num_samples) 138 | m, _ = torch.max(logp_x, dim=1, keepdim=True) 139 | logprob = m + torch.log(torch.mean(torch.exp(logp_x - m), axis=1, keepdim=True)) 140 | return torch.sum(logprob) -------------------------------------------------------------------------------- /baselines/gpt/gpt.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Brandon Theodorou 3 | Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 4 | Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT 5 | GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch 6 | ''' 7 | import copy 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | 12 | def gelu(x): 13 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 14 | 15 | class LayerNorm(nn.Module): 16 | def __init__(self, hidden_size, eps=1e-12): 17 | """Construct a layernorm module in the TF style (epsilon inside the square root).""" 18 | super(LayerNorm, self).__init__() 19 | self.weight = nn.Parameter(torch.ones(hidden_size)) 20 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 21 | self.variance_epsilon = eps 22 | 23 | def forward(self, x): 24 | u = x.mean(-1, keepdim=True) 25 | s = (x - u).pow(2).mean(-1, keepdim=True) 26 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 27 | return self.weight * x + self.bias 28 | 29 | class Conv1D(nn.Module): 30 | def __init__(self, nf, nx): 31 | super(Conv1D, self).__init__() 32 | self.nf = nf 33 | w = torch.empty(nx, nf) 34 | nn.init.normal_(w, std=0.02) 35 | self.weight = nn.Parameter(w) 36 | self.bias = nn.Parameter(torch.zeros(nf)) 37 | 38 | def forward(self, x): 39 | size_out = x.size()[:-1] + (self.nf,) 40 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 41 | x = x.view(*size_out) 42 | return x 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, nx, n_ctx, config, scale=False): 46 | super(Attention, self).__init__() 47 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 48 | assert n_state % config.n_head == 0 49 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 50 | self.n_head = config.n_head 51 | self.split_size = n_state 52 | self.scale = scale 53 | self.c_attn = Conv1D(n_state * 3, nx) 54 | self.c_proj = Conv1D(n_state, nx) 55 | 56 | def _attn(self, q, k, v): 57 | w = torch.matmul(q, k) 58 | if self.scale: 59 | w = w / math.sqrt(v.size(-1)) 60 | nd, ns = w.size(-2), w.size(-1) 61 | b = self.bias[:, :, ns-nd:ns, :ns] 62 | w = w * b - 1e10 * (1 - b) 63 | w = nn.Softmax(dim=-1)(w) 64 | return torch.matmul(w, v) 65 | 66 | def merge_heads(self, x): 67 | x = x.permute(0, 2, 1, 3).contiguous() 68 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 69 | return x.view(*new_x_shape) 70 | 71 | def split_heads(self, x, k=False): 72 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 73 | x = x.view(*new_x_shape) 74 | if k: 75 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 76 | else: 77 | 
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 78 | 79 | def forward(self, x, layer_past=None): 80 | x = self.c_attn(x) 81 | query, key, value = x.split(self.split_size, dim=2) 82 | query = self.split_heads(query) 83 | key = self.split_heads(key, k=True) 84 | value = self.split_heads(value) 85 | if layer_past is not None: 86 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 87 | key = torch.cat((past_key, key), dim=-1) 88 | value = torch.cat((past_value, value), dim=-2) 89 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 90 | a = self._attn(query, key, value) 91 | a = self.merge_heads(a) 92 | a = self.c_proj(a) 93 | return a, present 94 | 95 | class MLP(nn.Module): 96 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 97 | super(MLP, self).__init__() 98 | nx = config.n_embd 99 | self.c_fc = Conv1D(n_state, nx) 100 | self.c_proj = Conv1D(nx, n_state) 101 | self.act = gelu 102 | 103 | def forward(self, x): 104 | h = self.act(self.c_fc(x)) 105 | h2 = self.c_proj(h) 106 | return h2 107 | 108 | class Block(nn.Module): 109 | def __init__(self, n_ctx, config, scale=False): 110 | super(Block, self).__init__() 111 | nx = config.n_embd 112 | self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) 113 | self.attn = Attention(nx, n_ctx, config, scale) 114 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) 115 | self.mlp = MLP(4 * nx, config) 116 | 117 | def forward(self, x, layer_past=None): 118 | a, present = self.attn(self.ln_1(x), layer_past=layer_past) 119 | x = x + a 120 | m = self.mlp(self.ln_2(x)) 121 | x = x + m 122 | return x, present 123 | 124 | class GPT2Model(nn.Module): 125 | def __init__(self, config): 126 | super(GPT2Model, self).__init__() 127 | self.n_layer = config.n_layer 128 | self.n_embd = config.n_embd 129 | self.n_vocab = config.total_vocab_size 130 | 131 | self.code_embed_mat = nn.Embedding(config.total_vocab_size, config.n_embd) 132 | self.pos_embed_mat = nn.Embedding(config.n_positions, config.n_embd) 133 | block = Block(config.n_ctx, config, scale=True) 134 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) 135 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 136 | 137 | def forward(self, input_ids, position_ids=None, past=None): 138 | if past is None: 139 | past_length = 0 140 | past = [None] * len(self.h) 141 | else: 142 | past_length = past[0][0].size(-2) 143 | if position_ids is None: 144 | position_ids = torch.arange(past_length, input_ids.size(1) + past_length, dtype=torch.long, device=input_ids.device) 145 | position_ids = position_ids.unsqueeze(0).expand_as(input_ids) 146 | 147 | inputs_embeds = self.code_embed_mat(input_ids) 148 | position_embeds = self.pos_embed_mat(position_ids) 149 | hidden_states = inputs_embeds + position_embeds 150 | presents = [] 151 | for block, layer_past in zip(self.h, past): 152 | hidden_states, present = block(hidden_states, layer_past) 153 | presents.append(present) 154 | hidden_states = self.ln_f(hidden_states) 155 | return hidden_states, presents 156 | 157 | class GPTHead(nn.Module): 158 | def __init__(self, model_embeddings_weights, config): 159 | super(GPTHead, self).__init__() 160 | self.n_embd = config.n_embd 161 | self.set_embeddings_weights(model_embeddings_weights) 162 | 163 | def set_embeddings_weights(self, model_embeddings_weights): 164 | embed_shape = model_embeddings_weights.shape 165 | self.decoder = 
nn.Linear(embed_shape[1], embed_shape[0], bias=False) 166 | self.decoder.weight = nn.Parameter(model_embeddings_weights) # Tied weights 167 | 168 | def forward(self, hidden_state): 169 | code_logits = self.decoder(hidden_state) 170 | return code_logits 171 | 172 | class GPTModel(nn.Module): 173 | def __init__(self, config): 174 | super(GPTModel, self).__init__() 175 | self.transformer = GPT2Model(config) 176 | self.ehr_head = GPTHead(self.transformer.code_embed_mat.weight, config) 177 | self.config = config 178 | 179 | def set_tied(self): 180 | """Make sure we are sharing the embeddings""" 181 | self.ehr_head.set_embeddings_weights(self.transformer.code_embed_mat.weight) 182 | 183 | def forward(self, input_ids, position_ids=None, ehr_labels=None, past=None): 184 | hidden_states, presents = self.transformer(input_ids, position_ids, past) 185 | code_logits = self.ehr_head(hidden_states) 186 | if ehr_labels is not None: 187 | code_logits = code_logits[:, :-1, :].contiguous() 188 | ehr_labels = ehr_labels[:, 1:].contiguous() 189 | ce = nn.CrossEntropyLoss(ignore_index=self.config.total_vocab_size - 1) 190 | loss = ce(code_logits.view(-1, code_logits.size(-1)), ehr_labels.view(-1)) 191 | return loss, code_logits, ehr_labels 192 | 193 | return code_logits, presents -------------------------------------------------------------------------------- /baselines/haloCoarse/haloCoarse.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Brandon Theodorou 3 | Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 4 | Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT 5 | GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch 6 | ''' 7 | import copy 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | 12 | def gelu(x): 13 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 14 | 15 | class LayerNorm(nn.Module): 16 | def __init__(self, hidden_size, eps=1e-12): 17 | """Construct a layernorm module in the TF style (epsilon inside the square root).""" 18 | super(LayerNorm, self).__init__() 19 | self.weight = nn.Parameter(torch.ones(hidden_size)) 20 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 21 | self.variance_epsilon = eps 22 | 23 | def forward(self, x): 24 | u = x.mean(-1, keepdim=True) 25 | s = (x - u).pow(2).mean(-1, keepdim=True) 26 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 27 | return self.weight * x + self.bias 28 | 29 | class Conv1D(nn.Module): 30 | def __init__(self, nf, nx): 31 | super(Conv1D, self).__init__() 32 | self.nf = nf 33 | w = torch.empty(nx, nf) 34 | nn.init.normal_(w, std=0.02) 35 | self.weight = nn.Parameter(w) 36 | self.bias = nn.Parameter(torch.zeros(nf)) 37 | 38 | def forward(self, x): 39 | size_out = x.size()[:-1] + (self.nf,) 40 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 41 | x = x.view(*size_out) 42 | return x 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, nx, n_ctx, config, scale=False): 46 | super(Attention, self).__init__() 47 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 48 | assert n_state % config.n_head == 0 49 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 50 | self.n_head = config.n_head 51 | self.split_size = n_state 52 | self.scale = scale 53 | self.c_attn = Conv1D(n_state * 3, nx) 54 | self.c_proj = Conv1D(n_state, nx) 55 | 56 | def _attn(self, q, k, v): 57 | w = torch.matmul(q, 
k) 58 | if self.scale: 59 | w = w / math.sqrt(v.size(-1)) 60 | nd, ns = w.size(-2), w.size(-1) 61 | b = self.bias[:, :, ns-nd:ns, :ns] 62 | w = w * b - 1e10 * (1 - b) 63 | w = nn.Softmax(dim=-1)(w) 64 | return torch.matmul(w, v) 65 | 66 | def merge_heads(self, x): 67 | x = x.permute(0, 2, 1, 3).contiguous() 68 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 69 | return x.view(*new_x_shape) 70 | 71 | def split_heads(self, x, k=False): 72 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 73 | x = x.view(*new_x_shape) 74 | if k: 75 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 76 | else: 77 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 78 | 79 | def forward(self, x, layer_past=None): 80 | x = self.c_attn(x) 81 | query, key, value = x.split(self.split_size, dim=2) 82 | query = self.split_heads(query) 83 | key = self.split_heads(key, k=True) 84 | value = self.split_heads(value) 85 | if layer_past is not None: 86 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 87 | key = torch.cat((past_key, key), dim=-1) 88 | value = torch.cat((past_value, value), dim=-2) 89 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 90 | a = self._attn(query, key, value) 91 | a = self.merge_heads(a) 92 | a = self.c_proj(a) 93 | return a, present 94 | 95 | class MLP(nn.Module): 96 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 97 | super(MLP, self).__init__() 98 | nx = config.n_embd 99 | self.c_fc = Conv1D(n_state, nx) 100 | self.c_proj = Conv1D(nx, n_state) 101 | self.act = gelu 102 | 103 | def forward(self, x): 104 | h = self.act(self.c_fc(x)) 105 | h2 = self.c_proj(h) 106 | return h2 107 | 108 | class Block(nn.Module): 109 | def __init__(self, n_ctx, config, scale=False): 110 | super(Block, self).__init__() 111 | nx = config.n_embd 112 | self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) 113 | self.attn = Attention(nx, n_ctx, config, scale) 114 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) 115 | self.mlp = MLP(4 * nx, config) 116 | 117 | def forward(self, x, layer_past=None): 118 | a, present = self.attn(self.ln_1(x), layer_past=layer_past) 119 | x = x + a 120 | m = self.mlp(self.ln_2(x)) 121 | x = x + m 122 | return x, present 123 | 124 | class CoarseTransformerModel(nn.Module): 125 | def __init__(self, config): 126 | super(CoarseTransformerModel, self).__init__() 127 | self.n_layer = config.n_layer 128 | self.n_embd = config.n_embd 129 | self.n_vocab = config.total_vocab_size 130 | 131 | self.vis_embed_mat = nn.Linear(config.total_vocab_size, config.n_embd, bias=False) 132 | self.pos_embed_mat = nn.Embedding(config.n_positions, config.n_embd) 133 | block = Block(config.n_ctx, config, scale=True) 134 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) 135 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 136 | 137 | def forward(self, input_visits, position_ids=None, past=None): 138 | if past is None: 139 | past_length = 0 140 | past = [None] * len(self.h) 141 | else: 142 | past_length = past[0][0].size(-2) 143 | if position_ids is None: 144 | position_ids = torch.arange(past_length, input_visits.size(1) + past_length, dtype=torch.long, 145 | device=input_visits.device) 146 | position_ids = position_ids.unsqueeze(0).expand(input_visits.size(0), input_visits.size(1)) 147 | 148 | inputs_embeds = self.vis_embed_mat(input_visits) 149 | 
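# Note on the line above (added commentary): unlike the GPT baseline, which feeds one code index
# at a time through an nn.Embedding lookup, HALO Coarse represents an entire visit as a multi-hot
# vector over config.total_vocab_size. vis_embed_mat is therefore a bias-free nn.Linear, and
# applying it to a 0/1 visit vector is equivalent to summing the embedding rows of every code
# present in that visit.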
position_embeds = self.pos_embed_mat(position_ids) 150 | hidden_states = inputs_embeds + position_embeds 151 | presents = [] 152 | for block, layer_past in zip(self.h, past): 153 | hidden_states, present = block(hidden_states, layer_past) 154 | presents.append(present) 155 | hidden_states = self.ln_f(hidden_states) 156 | return hidden_states, presents 157 | 158 | class SimpleHead(nn.Module): 159 | def __init__(self, model_embeddings_weights, config): 160 | super(SimpleHead, self).__init__() 161 | self.n_embd = config.n_embd 162 | self.set_embeddings_weights(model_embeddings_weights) 163 | 164 | def set_embeddings_weights(self, model_embeddings_weights): 165 | embed_shape = model_embeddings_weights.shape 166 | self.decoder = nn.Linear(embed_shape[0], embed_shape[1], bias=False) 167 | self.decoder.weight = nn.Parameter(model_embeddings_weights.transpose(0, 1)) # Tied weights 168 | 169 | def forward(self, hidden_state): 170 | code_logits = self.decoder(hidden_state) 171 | return code_logits 172 | 173 | class HALOCoarseModel(nn.Module): 174 | def __init__(self, config): 175 | super(HALOCoarseModel, self).__init__() 176 | self.transformer = CoarseTransformerModel(config) 177 | self.ehr_head = SimpleHead(self.transformer.vis_embed_mat.weight, config) 178 | 179 | def set_tied(self): 180 | """Make sure we are sharing the embeddings""" 181 | self.ehr_head.set_embeddings_weights(self.transformer.vis_embed_mat.weight) 182 | 183 | def forward(self, input_visits, position_ids=None, ehr_labels=None, ehr_masks=None, past=None, pos_loss_weight=None): 184 | hidden_states, presents = self.transformer(input_visits, position_ids, past) 185 | code_logits = self.ehr_head(hidden_states) 186 | sig = nn.Sigmoid() 187 | code_probs = sig(code_logits) 188 | if ehr_labels is not None: 189 | shift_probs = code_probs[..., :-1, :].contiguous() 190 | shift_labels = ehr_labels[..., 1:, :].contiguous() 191 | loss_weights = None 192 | if pos_loss_weight is not None: 193 | loss_weights = torch.ones(shift_probs.shape, device=code_probs.device) 194 | loss_weights = loss_weights + (pos_loss_weight-1) * shift_labels 195 | if ehr_masks is not None: 196 | shift_probs = shift_probs * ehr_masks 197 | shift_labels = shift_labels * ehr_masks 198 | if pos_loss_weight is not None: 199 | loss_weights = loss_weights * ehr_masks 200 | 201 | bce = nn.BCELoss(weight=loss_weights) 202 | loss = bce(shift_probs, shift_labels) 203 | return loss, shift_probs, shift_labels 204 | 205 | return code_probs, presents -------------------------------------------------------------------------------- /hcup_ccs_2015_definitions_benchmark.yaml: -------------------------------------------------------------------------------- 1 | "Septicemia (except in labor)": 2 | use_in_benchmark: True 3 | type: "acute" 4 | id: 2 5 | codes: [ "0031", "0202", "0223", "0362", "0380", "0381", "03810", "03811", "03812", "03819", "0382", "0383", "03840", "03841", "03842", "03843", "03844", "03849", "0388", "0389", "0545", "449", "77181", "7907", "99591", "99592" ] 6 | 7 | "Diabetes mellitus without complication": 8 | use_in_benchmark: True 9 | type: "chronic" 10 | id: 49 11 | codes: [ "24900", "25000", "25001", "7902", "79021", "79022", "79029", "7915", "7916", "V4585", "V5391", "V6546" ] 12 | 13 | "Diabetes mellitus with complications": 14 | use_in_benchmark: True 15 | type: "chronic" 16 | id: 50 17 | codes: [ "24901", "24910", "24911", "24920", "24921", "24930", "24931", "24940", "24941", "24950", "24951", "24960", "24961", "24970", "24971", "24980", "24981", "24990", "24991", 
"25002", "25003", "25010", "25011", "25012", "25013", "25020", "25021", "25022", "25023", "25030", "25031", "25032", "25033", "25040", "25041", "25042", "25043", "25050", "25051", "25052", "25053", "25060", "25061", "25062", "25063", "25070", "25071", "25072", "25073", "25080", "25081", "25082", "25083", "25090", "25091", "25092", "25093" ] 18 | 19 | "Disorders of lipid metabolism": 20 | use_in_benchmark: True 21 | type: "chronic" 22 | id: 53 23 | codes: [ "2720", "2721", "2722", "2723", "2724" ] 24 | 25 | "Fluid and electrolyte disorders": 26 | use_in_benchmark: True 27 | type: "acute" 28 | id: 55 29 | codes: [ "2760", "2761", "2762", "2763", "2764", "2765", "27650", "27651", "27652", "2766", "27669", "2767", "2768", "2769", "9951" ] 30 | 31 | "Essential hypertension": 32 | use_in_benchmark: True 33 | type: "chronic" 34 | id: 98 35 | codes: [ "4011", "4019" ] 36 | 37 | "Hypertension with complications and secondary hypertension": 38 | use_in_benchmark: True 39 | type: "chronic" 40 | id: 99 41 | codes: [ "4010", "40200", "40201", "40210", "40211", "40290", "40291", "4030", "40300", "40301", "4031", "40310", "40311", "4039", "40390", "40391", "4040", "40400", "40401", "40402", "40403", "4041", "40410", "40411", "40412", "40413", "4049", "40490", "40491", "40492", "40493", "40501", "40509", "40511", "40519", "40591", "40599", "4372" ] 42 | 43 | "Acute myocardial infarction": 44 | use_in_benchmark: True 45 | type: "acute" 46 | id: 100 47 | codes: [ "4100", "41000", "41001", "41002", "4101", "41010", "41011", "41012", "4102", "41020", "41021", "41022", "4103", "41030", "41031", "41032", "4104", "41040", "41041", "41042", "4105", "41050", "41051", "41052", "4106", "41060", "41061", "41062", "4107", "41070", "41071", "41072", "4108", "41080", "41081", "41082", "4109", "41090", "41091", "41092" ] 48 | 49 | "Coronary atherosclerosis and other heart disease": 50 | use_in_benchmark: True 51 | type: "chronic" 52 | id: 101 53 | codes: [ "4110", "4111", "4118", "41181", "41189", "412", "4130", "4131", "4139", "4140", "41400", "41401", "41406", "4142", "4143", "4144", "4148", "4149", "V4581", "V4582" ] 54 | 55 | "Conduction disorders": 56 | use_in_benchmark: True 57 | type: "chronic" 58 | id: 105 59 | codes: [ "4260", "42610", "42611", "42612", "42613", "4262", "4263", "4264", "42650", "42651", "42652", "42653", "42654", "4266", "4267", "42681", "42682", "42689", "4269", "V450", "V4500", "V4501", "V4502", "V4509", "V533", "V5331", "V5332", "V5339" ] 60 | 61 | "Cardiac dysrhythmias": 62 | use_in_benchmark: True 63 | type: "chronic" 64 | id: 106 65 | codes: [ "4270", "4271", "4272", "42731", "42732", "42760", "42761", "42769", "42781", "42789", "4279", "7850", "7851" ] 66 | 67 | "Congestive heart failure; nonhypertensive": 68 | use_in_benchmark: True 69 | type: "acute" 70 | id: 108 71 | codes: [ "39891", "4280", "4281", "42820", "42821", "42822", "42823", "42830", "42831", "42832", "42833", "42840", "42841", "42842", "42843", "4289" ] 72 | 73 | "Acute cerebrovascular disease": 74 | use_in_benchmark: True 75 | type: "acute" 76 | id: 109 77 | codes: [ "34660", "34661", "34662", "34663", "430", "431", "4320", "4321", "4329", "43301", "43311", "43321", "43331", "43381", "43391", "4340", "43400", "43401", "4341", "43410", "43411", "4349", "43490", "43491", "436" ] 78 | 79 | "Pneumonia (except that caused by tuberculosis or sexually transmitted disease)": 80 | use_in_benchmark: True 81 | type: "acute" 82 | id: 122 83 | codes: [ "00322", "0203", "0204", "0205", "0212", "0221", "0310", "0391", "0521", "0551", 
"0730", "0830", "1124", "1140", "1144", "1145", "11505", "11515", "11595", "1304", "1363", "4800", "4801", "4802", "4803", "4808", "4809", "481", "4820", "4821", "4822", "4823", "48230", "48231", "48232", "48239", "4824", "48240", "48241", "48242", "48249", "4828", "48281", "48282", "48283", "48284", "48289", "4829", "483", "4830", "4831", "4838", "4841", "4843", "4845", "4846", "4847", "4848", "485", "486", "5130", "5171" ] 84 | 85 | "Chronic obstructive pulmonary disease and bronchiectasis": 86 | use_in_benchmark: True 87 | type: "chronic" 88 | id: 127 89 | codes: [ "490", "4910", "4911", "4912", "49120", "49121", "49122", "4918", "4919", "4920", "4928", "494", "4940", "4941", "496" ] 90 | 91 | "Pleurisy; pneumothorax; pulmonary collapse": 92 | use_in_benchmark: True 93 | type: "acute" 94 | id: 130 95 | codes: [ "5100", "5109", "5110", "5111", "5118", "51189", "5119", "5120", "5128", "51281", "51282", "51283", "51284", "51289", "5180", "5181", "5182" ] 96 | 97 | "Respiratory failure; insufficiency; arrest (adult)": 98 | use_in_benchmark: True 99 | type: "acute" 100 | id: 131 101 | codes: [ "5173", "5185", "51851", "51852", "51853", "51881", "51882", "51883", "51884", "7991", "V461", "V4611", "V4612", "V4613", "V4614", "V462" ] 102 | 103 | "Other lower respiratory disease": 104 | use_in_benchmark: True 105 | type: "acute" 106 | id: 133 107 | codes: [ "5131", "514", "515", "5160", "5161", "5162", "5163", "51630", "51631", "51632", "51633", "51634", "51635", "51636", "51637", "5164", "5165", "51661", "51662", "51663", "51664", "51669", "5168", "5169", "5172", "5178", "5183", "5184", "51889", "5194", "5198", "5199", "7825", "78600", "78601", "78602", "78603", "78604", "78605", "78606", "78607", "78609", "7862", "7863", "78630", "78631", "78639", "7864", "78652", "7866", "7867", "7868", "7869", "7931", "79311", "79319", "7942", "V126", "V1260", "V1261", "V1269", "V426" ] 108 | 109 | "Other upper respiratory disease": 110 | use_in_benchmark: True 111 | type: "acute" 112 | id: 134 113 | codes: [ "470", "4710", "4711", "4718", "4719", "4720", "4721", "4722", "4760", "4761", "4770", "4772", "4778", "4779", "4780", "4781", "47811", "47819", "47820", "47821", "47822", "47824", "47825", "47826", "47829", "47830", "47831", "47832", "47833", "47834", "4784", "4785", "4786", "47870", "47871", "47874", "47875", "47879", "4788", "4789", "5191", "51911", "51919", "5192", "5193", "7841", "78440", "78441", "78442", "78443", "78444", "78449", "7847", "7848", "7849", "78499", "7861", "V414", "V440", "V550" ] 114 | 115 | "Other liver diseases": 116 | use_in_benchmark: True 117 | type: "acute" 118 | id: 151 119 | codes: [ "570", "5715", "5716", "5718", "5719", "5720", "5721", "5722", "5723", "5724", "5728", "5730", "5734", "5735", "5738", "5739", "7824", "7891", "7895", "78959", "7904", "7905", "7948", "V427" ] 120 | 121 | "Gastrointestinal hemorrhage": 122 | use_in_benchmark: True 123 | type: "acute" 124 | id: 153 125 | codes: [ "4560", "45620", "5307", "53082", "53100", "53101", "53120", "53121", "53140", "53141", "53160", "53161", "53200", "53201", "53220", "53221", "53240", "53241", "53260", "53261", "53300", "53301", "53320", "53321", "53340", "53341", "53360", "53361", "53400", "53401", "53420", "53421", "53440", "53441", "53460", "53461", "5693", "5780", "5781", "5789" ] 126 | 127 | "Acute and unspecified renal failure": 128 | use_in_benchmark: True 129 | type: "acute" 130 | id: 157 131 | codes: [ "5845", "5846", "5847", "5848", "5849", "586" ] 132 | 133 | "Chronic kidney disease": 134 | 
use_in_benchmark: True 135 | type: "chronic" 136 | id: 158 137 | codes: [ "585", "5851", "5852", "5853", "5854", "5855", "5856", "5859", "7925", "V420", "V451", "V4511", "V4512", "V560", "V561", "V562", "V5631", "V5632", "V568" ] 138 | 139 | "Complications of surgical procedures or medical care": 140 | use_in_benchmark: True 141 | type: "acute" 142 | id: 238 143 | codes: [ "27661", "27783", "27788", "2853", "28741", "3490", "3491", "34931", "41511", "4294", "4582", "45821", "45829", "5121", "5122", "5187", "5190", "51900", "51901", "51902", "51909", "53086", "53087", "53640", "53641", "53642", "53649", "53901", "53909", "53981", "53989", "5642", "5643", "5644", "5696", "56962", "56971", "56979", "5793", "59681", "78062", "78063", "78066", "9093", "99524", "9954", "99586", "9970", "99700", "99701", "99702", "99709", "9971", "9972", "9973", "99731", "99732", "99739", "9974", "99741", "99749", "9975", "99760", "99761", "99762", "99769", "99771", "99772", "99779", "9979", "99791", "99799", "9980", "99800", "99801", "99802", "99809", "9981", "99811", "99812", "99813", "9982", "9983", "99830", "99831", "99832", "99833", "9984", "9985", "99851", "99859", "9986", "9987", "9988", "99881", "99882", "99883", "99889", "9989", "9990", "9991", "9992", "9993", "99934", "99939", "9994", "99941", "99942", "99949", "9995", "99951", "99952", "99959", "9996", "99960", "99961", "99962", "99963", "99969", "9997", "99970", "99971", "99972", "99973", "99974", "99975", "99976", "99977", "99978", "99979", "9998", "99980", "99981", "99982", "99983", "99984", "99985", "99988", "99989", "9999", "V1553", "V1580", "V1583", "V9001", "V9009" ] 144 | 145 | "Shock": 146 | use_in_benchmark: True 147 | type: "acute" 148 | id: 249 149 | codes: [ "78550", "78551", "78552", "78559" ] -------------------------------------------------------------------------------- /continuous_variables/genDatasetContinuous.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | from tqdm import tqdm 4 | import numpy as np 5 | import yaml 6 | import pickle 7 | from sklearn.model_selection import train_test_split 8 | 9 | MAX_TIME_STEPS = 150 10 | 11 | mimic_dir = "./" 12 | timeseries_dir = "../../Code/mimic3-benchmarks/data/root/all/" 13 | valid_subjects = os.listdir(timeseries_dir) 14 | patientsFile = mimic_dir + 'PATIENTS.csv' 15 | admissionFile = mimic_dir + "ADMISSIONS.csv" 16 | diagnosisFile = mimic_dir + "DIAGNOSES_ICD.csv" 17 | procedureFile = mimic_dir + "PROCEDURES_ICD.csv" 18 | medicationFile = mimic_dir + "PRESCRIPTIONS.csv" 19 | 20 | channel_to_id = pickle.load(open(mimic_dir + 'channel_to_id.pkl', 'rb')) 21 | is_categorical_channel = pickle.load(open(mimic_dir + 'is_categorical_channel.pkl', 'rb')) 22 | possible_values = pickle.load(open(mimic_dir + 'possible_values.pkl', 'rb')) 23 | begin_pos = pickle.load(open(mimic_dir + 'begin_pos.pkl', 'rb')) 24 | end_pos = pickle.load(open(mimic_dir + 'end_pos.pkl', 'rb')) 25 | 26 | print("Loading CSVs Into Dataframes") 27 | patientsDf = pd.read_csv(patientsFile, dtype=str).set_index("SUBJECT_ID") 28 | patientsDf = patientsDf[['GENDER', 'DOB']] 29 | patientsDf['DOB'] = pd.to_datetime(patientsDf['DOB']) 30 | admissionDf = pd.read_csv(admissionFile, dtype=str) 31 | admissionDf['ADMITTIME'] = pd.to_datetime(admissionDf['ADMITTIME']) 32 | admissionDf = admissionDf.sort_values('ADMITTIME') 33 | admissionDf = admissionDf.reset_index(drop=True) 34 | diagnosisDf = pd.read_csv(diagnosisFile, dtype=str).set_index("HADM_ID") 35 | 
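# Overview of the record layout this script builds (summary comment; the per-admission loop
# follows below): each hospital admission is first collected as a tuple
#   (diagnoses, procedures, medications, age, labs)
# where labs is a list of (mask_indices, values, [hours_since_previous_step]) entries taken from
# the benchmark time-series files, with categorical channels one-hot expanded via begin_pos /
# end_pos and continuous channels stored as raw values. A later pass flattens this so that one
# "visit" holds the admission's codes plus the patient's age, and each subsequent lab time step
# becomes its own pseudo-visit of (no codes, mask indices, values, time gap).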
diagnosisDf = diagnosisDf[diagnosisDf['ICD9_CODE'].notnull()] 36 | diagnosisDf = diagnosisDf[['ICD9_CODE']] 37 | procedureDf = pd.read_csv(procedureFile, dtype=str).set_index("HADM_ID") 38 | procedureDf = procedureDf[procedureDf['ICD9_CODE'].notnull()] 39 | procedureDf = procedureDf[['ICD9_CODE']] 40 | medicationDf = pd.read_csv(medicationFile, dtype=str).set_index("HADM_ID") 41 | medicationDf = medicationDf[medicationDf['NDC'].notnull()] 42 | medicationDf = medicationDf[medicationDf['NDC'] != 0] 43 | medicationDf = medicationDf[['NDC', 'DRUG']] 44 | medicationDf['NDC'] = medicationDf['NDC'].astype(int).astype(str) 45 | medicationDf['NDC'] = [('0' * (11 - len(c))) + c for c in medicationDf['NDC']] 46 | medicationDf['NDC'] = [c[0:5] + '-' + c[5:9] + '-' + c[10:12] for c in medicationDf['NDC']] 47 | 48 | print("Building Dataset") 49 | data = {} 50 | for row in tqdm(admissionDf.itertuples(), total=len(admissionDf)): 51 | hadm_id = row.HADM_ID 52 | subject_id = row.SUBJECT_ID 53 | admit_time = row.ADMITTIME 54 | 55 | if subject_id not in patientsDf.index: 56 | continue 57 | 58 | visit_count = (0 if subject_id not in data else len(data[subject_id]['visits'])) + 1 59 | 60 | tsDf = pd.read_csv(f"{timeseries_dir}{subject_id}/episode{visit_count}_timeseries.csv") if os.path.exists(f"{timeseries_dir}{subject_id}/episode{visit_count}_timeseries.csv") else None 61 | 62 | # Extract the gender and age 63 | patientRow = patientsDf.loc[[subject_id]].iloc[0] 64 | age = (admit_time.to_pydatetime() - patientRow['DOB'].to_pydatetime()).days / 365 65 | if age > 120: 66 | continue 67 | 68 | # Extracting the Diagnoses 69 | if hadm_id in diagnosisDf.index: 70 | diagnoses = list(set(diagnosisDf.loc[[hadm_id]]["ICD9_CODE"])) 71 | else: 72 | diagnoses = [] 73 | 74 | # Extracting the Procedures 75 | if hadm_id in procedureDf.index: 76 | procedures = list(set(procedureDf.loc[[hadm_id]]["ICD9_CODE"])) 77 | else: 78 | procedures = [] 79 | 80 | # Extracting the Medications 81 | if hadm_id in medicationDf.index: 82 | medications = list(set(medicationDf.loc[[hadm_id]]["NDC"])) 83 | else: 84 | medications = [] 85 | 86 | # Extract the lab timeseries 87 | labs = [] 88 | prevTime = 0 89 | currTime = int(tsDf.iloc[0]['Hours']) if tsDf is not None else 0 90 | currMask = [] 91 | currValues = [] 92 | if tsDf is not None: 93 | for i, row in tsDf.iterrows(): 94 | rowTime = int(row['Hours']) 95 | 96 | if rowTime != currTime: 97 | labs.append((currMask, currValues, [currTime - prevTime])) 98 | prevTime = currTime 99 | currTime = rowTime 100 | currMask = [] 101 | currValues = [] 102 | 103 | for col, value in row.iteritems(): 104 | if value != value or col == 'Hours': 105 | continue 106 | 107 | if is_categorical_channel[col]: 108 | if col == 'Glascow coma scale total': 109 | value = str(int(value)) 110 | elif col == 'Capillary refill rate': 111 | value = str(value) 112 | 113 | if begin_pos[channel_to_id[col]] in currMask: 114 | currValues[currMask.index(begin_pos[channel_to_id[col]] + possible_values[col].index(value))] = 1 115 | else: 116 | for j in range(begin_pos[channel_to_id[col]], end_pos[channel_to_id[col]]): 117 | currMask.append(j) 118 | currValues.append(1 if j - begin_pos[channel_to_id[col]] == possible_values[col].index(value) else 0) 119 | else: 120 | if begin_pos[channel_to_id[col]] in currMask: 121 | currValues[currMask.index(begin_pos[channel_to_id[col]])] = value 122 | else: 123 | currMask.append(begin_pos[channel_to_id[col]]) 124 | currValues.append(value) 125 | 126 | labs.append((currMask, currValues, [currTime - 
prevTime])) 127 | 128 | # Building the hospital admission data point 129 | if subject_id not in data: 130 | data[subject_id] = {'visits': [(diagnoses, procedures, medications, age, labs)]} 131 | else: 132 | data[subject_id]['visits'].append((diagnoses, procedures, medications, age, labs)) 133 | 134 | # Build the label mapping 135 | print("Adding Labels") 136 | with open("hcup_ccs_2015_definitions_benchmark.yaml") as definitions_file: 137 | definitions = yaml.full_load(definitions_file) 138 | 139 | code_to_group = {} 140 | for group in definitions: 141 | if definitions[group]['use_in_benchmark'] == False: 142 | continue 143 | codes = definitions[group]['codes'] 144 | for code in codes: 145 | if code not in code_to_group: 146 | code_to_group[code] = group 147 | else: 148 | assert code_to_group[code] == group 149 | 150 | id_to_group = sorted([k for k in definitions.keys() if definitions[k]['use_in_benchmark'] == True]) 151 | group_to_id = dict((x, i) for (i, x) in enumerate(id_to_group)) 152 | 153 | # Add Labels 154 | for p in data: 155 | label = np.zeros(len(group_to_id)) 156 | for v in data[p]['visits']: 157 | for d in v[0]: 158 | d = str(d) 159 | if d not in code_to_group: 160 | continue 161 | 162 | label[group_to_id[code_to_group[d]]] = 1 163 | 164 | data[p]['labels'] = label 165 | 166 | 167 | # Convert diagnoses, procedures, and medications to text 168 | print("Converting Codes to Text") 169 | medMapping = {row['NDC']: row['DRUG'] for _, row in medicationDf.iterrows()} 170 | for p in data: 171 | new_visits = [] 172 | for v in data[p]['visits']: 173 | new_visit = [] 174 | for c in v[0]: 175 | new_visit.append(c) 176 | for c in v[1]: 177 | new_visit.append(c) 178 | for c in v[2]: 179 | if c in medMapping: 180 | new_visit.append(medMapping[c]) 181 | else: 182 | new_visit.append(c) 183 | 184 | new_visits.append((new_visit, [], [], [v[3]])) 185 | 186 | for lab_v in v[4]: 187 | new_visits.append(([], lab_v[0], lab_v[1], lab_v[2])) 188 | data[p]['visits'] = new_visits 189 | 190 | 191 | # Convert diagnoses, procedures, and medications to indices 192 | print("Converting Codes to Indices") 193 | allCodes = list(set([c for p in data for v in data[p]['visits'] for c in v[0]])) 194 | np.random.shuffle(allCodes) 195 | code_to_index = {c: i for i, c in enumerate(allCodes)} 196 | counter = 0 197 | for p in data: 198 | new_visits = [] 199 | for v in data[p]['visits']: 200 | new_visit = [] 201 | for c in v[0]: 202 | new_visit.append(code_to_index[c]) 203 | 204 | new_visits.append((new_visit, v[1], v[2], v[3])) 205 | data[p]['visits'] = new_visits 206 | 207 | index_to_code = {v: k for k, v in code_to_index.items()} 208 | data = list(data.values()) 209 | data = [{'labels': p['labels'], 'visits': p['visits'][:MAX_TIME_STEPS - 2]} for p in data] # 2 for the start and label visits 210 | 211 | # Train-Val-Test Split 212 | print("Splitting Datasets") 213 | train_dataset, test_dataset = train_test_split(data, test_size=0.2, random_state=4, shuffle=True) 214 | train_dataset, val_dataset = train_test_split(train_dataset, test_size=0.1, random_state=4, shuffle=True) 215 | 216 | # Save Everything 217 | print("Saving Everything") 218 | print(f"CODE VOCAB SIZE: {len(index_to_code)}") 219 | print(f"LABEL VOCAB SIZE: {len(data[0]['labels'])}") 220 | pickle.dump(dict((i, x) for (x, i) in list(group_to_id.items())), open("./data/idToLabel.pkl", "wb")) 221 | pickle.dump(index_to_code, open("./data/indexToCode.pkl", "wb")) 222 | pickle.dump(data, open("./data/allData.pkl", "wb")) 223 | pickle.dump(train_dataset, 
open("./data/trainData.pkl", "wb")) 224 | pickle.dump(val_dataset, open("./data/valData.pkl", "wb")) 225 | pickle.dump(test_dataset, open("./data/testData.pkl", "wb")) 226 | -------------------------------------------------------------------------------- /continuous_variables/model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Brandon Theodorou 3 | Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 4 | Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT 5 | GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch 6 | ''' 7 | import copy 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | def gelu(x): 14 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 15 | 16 | class LayerNorm(nn.Module): 17 | def __init__(self, hidden_size, eps=1e-12): 18 | """Construct a layernorm module in the TF style (epsilon inside the square root).""" 19 | super(LayerNorm, self).__init__() 20 | self.weight = nn.Parameter(torch.ones(hidden_size)) 21 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 22 | self.variance_epsilon = eps 23 | 24 | def forward(self, x): 25 | u = x.mean(-1, keepdim=True) 26 | s = (x - u).pow(2).mean(-1, keepdim=True) 27 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 28 | return self.weight * x + self.bias 29 | 30 | class Conv1D(nn.Module): 31 | def __init__(self, nf, nx): 32 | super(Conv1D, self).__init__() 33 | self.nf = nf 34 | w = torch.empty(nx, nf) 35 | nn.init.normal_(w, std=0.02) 36 | self.weight = nn.Parameter(w) 37 | self.bias = nn.Parameter(torch.zeros(nf)) 38 | 39 | def forward(self, x): 40 | size_out = x.size()[:-1] + (self.nf,) 41 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 42 | x = x.view(*size_out) 43 | return x 44 | 45 | class Attention(nn.Module): 46 | def __init__(self, nx, n_ctx, config, scale=False): 47 | super(Attention, self).__init__() 48 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 49 | assert n_state % config.n_head == 0 50 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 51 | self.n_head = config.n_head 52 | self.split_size = n_state 53 | self.scale = scale 54 | self.c_attn = Conv1D(n_state * 3, nx) 55 | self.c_proj = Conv1D(n_state, nx) 56 | 57 | def _attn(self, q, k, v): 58 | w = torch.matmul(q, k) 59 | if self.scale: 60 | w = w / math.sqrt(v.size(-1)) 61 | nd, ns = w.size(-2), w.size(-1) 62 | b = self.bias[:, :, ns-nd:ns, :ns] 63 | w = w * b - 1e10 * (1 - b) 64 | w = nn.Softmax(dim=-1)(w) 65 | return torch.matmul(w, v) 66 | 67 | def merge_heads(self, x): 68 | x = x.permute(0, 2, 1, 3).contiguous() 69 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 70 | return x.view(*new_x_shape) 71 | 72 | def split_heads(self, x, k=False): 73 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 74 | x = x.view(*new_x_shape) 75 | if k: 76 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 77 | else: 78 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 79 | 80 | def forward(self, x, layer_past=None): 81 | x = self.c_attn(x) 82 | query, key, value = x.split(self.split_size, dim=2) 83 | query = self.split_heads(query) 84 | key = self.split_heads(key, k=True) 85 | value = self.split_heads(value) 86 | if layer_past is not None: 87 | past_key, past_value = layer_past[0].transpose(-2, -1), 
layer_past[1] # transpose back cf below 88 | key = torch.cat((past_key, key), dim=-1) 89 | value = torch.cat((past_value, value), dim=-2) 90 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 91 | a = self._attn(query, key, value) 92 | a = self.merge_heads(a) 93 | a = self.c_proj(a) 94 | return a, present 95 | 96 | class MLP(nn.Module): 97 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 98 | super(MLP, self).__init__() 99 | nx = config.n_embd 100 | self.c_fc = Conv1D(n_state, nx) 101 | self.c_proj = Conv1D(nx, n_state) 102 | self.act = gelu 103 | 104 | def forward(self, x): 105 | h = self.act(self.c_fc(x)) 106 | h2 = self.c_proj(h) 107 | return h2 108 | 109 | class Block(nn.Module): 110 | def __init__(self, n_ctx, config, scale=False): 111 | super(Block, self).__init__() 112 | nx = config.n_embd 113 | self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) 114 | self.attn = Attention(nx, n_ctx, config, scale) 115 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) 116 | self.mlp = MLP(4 * nx, config) 117 | 118 | def forward(self, x, layer_past=None): 119 | a, present = self.attn(self.ln_1(x), layer_past=layer_past) 120 | x = x + a 121 | m = self.mlp(self.ln_2(x)) 122 | x = x + m 123 | return x, present 124 | 125 | class CoarseTransformerModel(nn.Module): 126 | def __init__(self, config): 127 | super(CoarseTransformerModel, self).__init__() 128 | self.n_layer = config.n_layer 129 | self.n_embd = config.n_embd 130 | self.n_vocab = config.total_vocab_size 131 | 132 | self.vis_embed_mat = nn.Linear(config.total_vocab_size, config.n_embd, bias=False) 133 | self.pos_embed_mat = nn.Embedding(config.n_positions, config.n_embd) 134 | block = Block(config.n_ctx, config, scale=True) 135 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) 136 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 137 | 138 | def forward(self, input_visits, position_ids=None, past=None): 139 | if past is None: 140 | past_length = 0 141 | past = [None] * len(self.h) 142 | else: 143 | past_length = past[0][0].size(-2) 144 | if position_ids is None: 145 | position_ids = torch.arange(past_length, input_visits.size(1) + past_length, dtype=torch.long, 146 | device=input_visits.device) 147 | position_ids = position_ids.unsqueeze(0).expand(input_visits.size(0), input_visits.size(1)) 148 | 149 | inputs_embeds = self.vis_embed_mat(input_visits) 150 | position_embeds = self.pos_embed_mat(position_ids) 151 | hidden_states = inputs_embeds + position_embeds 152 | for block, layer_past in zip(self.h, past): 153 | hidden_states, _ = block(hidden_states, layer_past) 154 | hidden_states = self.ln_f(hidden_states) 155 | return hidden_states 156 | 157 | class AutoregressiveLinear(nn.Linear): 158 | """ same as Linear except has a configurable mask on the weights """ 159 | def __init__(self, in_features, out_features, bias=True): 160 | super().__init__(in_features, out_features, bias) 161 | self.register_buffer('mask', torch.tril(torch.ones(in_features, out_features)).int()) 162 | 163 | def forward(self, input): 164 | return F.linear(input, self.mask * self.weight, self.bias) 165 | 166 | class FineAutoregressiveHead(nn.Module): 167 | def __init__(self, config): 168 | super(FineAutoregressiveHead, self).__init__() 169 | self.n_embd = config.n_embd 170 | self.total_vocab_size = config.total_vocab_size 171 | 172 | self.auto1 = AutoregressiveLinear(config.n_embd + self.total_vocab_size, config.n_embd + 
self.total_vocab_size) 173 | self.auto2 = AutoregressiveLinear(config.n_embd + self.total_vocab_size, config.n_embd + self.total_vocab_size) 174 | 175 | def forward(self, history, input_visits): 176 | history = history[:,:-1,:] 177 | input_visits = input_visits[:,1:,:] 178 | code_logits = self.auto2(torch.relu(self.auto1(torch.cat((history, input_visits), dim=2))))[:,:,self.n_embd-1:-1] 179 | return code_logits 180 | 181 | def sample(self, history, input_visits): 182 | history = history[:,:-1,:] 183 | input_visits = input_visits[:,1:,:] 184 | currVisit = torch.cat((history, input_visits), dim=2)[:,-1,:].unsqueeze(1) 185 | code_logits = self.auto2(torch.relu(self.auto1(currVisit)))[:,:,self.n_embd-1:-1] 186 | return code_logits 187 | 188 | class HALOModel(nn.Module): 189 | def __init__(self, config): 190 | super(HALOModel, self).__init__() 191 | self.transformer = CoarseTransformerModel(config) 192 | self.ehr_head = FineAutoregressiveHead(config) 193 | self.total_vocab_size = config.total_vocab_size 194 | 195 | def forward(self, input_visits, position_ids=None, ehr_labels=None, ehr_masks=None, past=None): 196 | hidden_states = self.transformer(input_visits, position_ids, past) 197 | code_logits = self.ehr_head(hidden_states, input_visits) 198 | sig = nn.Sigmoid() 199 | code_probs = sig(code_logits) 200 | 201 | if ehr_labels is not None: 202 | shift_labels = ehr_labels[..., 1:, :].contiguous() 203 | if ehr_masks is not None: 204 | code_probs = code_probs * ehr_masks 205 | shift_labels = shift_labels * ehr_masks 206 | 207 | bce = nn.BCELoss() 208 | loss = bce(code_probs, shift_labels) 209 | return loss, code_probs, shift_labels 210 | 211 | return code_probs 212 | 213 | def sample(self, input_visits, random=True): 214 | sig = nn.Sigmoid() 215 | hidden_states = self.transformer(input_visits) 216 | i = 0 217 | while i < self.total_vocab_size: 218 | next_logits = self.ehr_head.sample(hidden_states, input_visits) 219 | next_probs = sig(next_logits) 220 | if random: 221 | visit = torch.bernoulli(next_probs) 222 | else: 223 | visit = torch.round(next_probs) 224 | 225 | remaining_visit = visit[:,0,i:] 226 | nonzero = torch.nonzero(remaining_visit, as_tuple=True)[1] 227 | if nonzero.numel() == 0: 228 | break 229 | 230 | first_nonzero = nonzero.min() 231 | input_visits[:,-1,i + first_nonzero] = visit[:,0,i + first_nonzero] 232 | i = i + first_nonzero + 1 233 | 234 | return input_visits -------------------------------------------------------------------------------- /test_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from sklearn import metrics 7 | from model import HALOModel 8 | from config import HALOConfig 9 | 10 | SEED = 4 11 | random.seed(SEED) 12 | np.random.seed(SEED) 13 | torch.manual_seed(SEED) 14 | config = HALOConfig() 15 | 16 | local_rank = -1 17 | fp16 = False 18 | if local_rank == -1: 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | n_gpu = torch.cuda.device_count() 21 | else: 22 | torch.cuda.set_device(local_rank) 23 | device = torch.device("cuda", local_rank) 24 | n_gpu = 1 25 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 26 | torch.distributed.init_process_group(backend='nccl') 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(SEED) 29 | 30 | train_ehr_dataset = pickle.load(open('./data/trainDataset.pkl', 'rb')) 31 | test_ehr_dataset = 
pickle.load(open('./data/testDataset.pkl', 'rb')) 32 | index_to_code = pickle.load(open("./data/indexToCode.pkl", "rb")) 33 | id_to_label = pickle.load(open("./data/idToLabel.pkl", "rb")) 34 | train_c = set([c for p in train_ehr_dataset for v in p['visits'] for c in v]) 35 | test_ehr_dataset = [{'labels': p['labels'], 'visits': [[c for c in v if c in train_c] for v in p['visits']]} for p in test_ehr_dataset] 36 | 37 | def get_batch(loc, batch_size, mode): 38 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 39 | # Where each patient P is [V_1, V_2, ... , V_j] 40 | # Where each visit V is [C_1, C_2, ... , C_k] 41 | # And where each Label L is a binary vector [L_1 ... L_n] 42 | if mode == 'train': 43 | ehr = train_ehr_dataset[loc:loc+batch_size] 44 | elif mode == 'valid': 45 | ehr = val_ehr_dataset[loc:loc+batch_size] 46 | else: 47 | ehr = test_ehr_dataset[loc:loc+batch_size] 48 | 49 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 50 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 51 | for i, p in enumerate(ehr): 52 | visits = p['visits'] 53 | for j, v in enumerate(visits): 54 | batch_ehr[i,j+2][v] = 1 55 | batch_mask[i,j+2] = 1 56 | batch_ehr[i,1,config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 57 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 58 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 59 | 60 | batch_mask[:,1] = 1 # Set the mask to cover the labels 61 | batch_ehr[:,0,config.code_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 62 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 63 | return batch_ehr, batch_mask 64 | 65 | def conf_mat(x, y): 66 | totaltrue = np.sum(x) 67 | totalfalse = len(x) - totaltrue 68 | truepos, totalpos = np.sum(x & y), np.sum(y) 69 | falsepos = totalpos - truepos 70 | return np.array([[totalfalse - falsepos, falsepos], #true negatives, false positives 71 | [totaltrue - truepos, truepos]]) #false negatives, true positives 72 | 73 | model = HALOModel(config).to(device) 74 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 75 | 76 | checkpoint = torch.load('./save/halo_model2', map_location=torch.device(device)) 77 | model.load_state_dict(checkpoint['model']) 78 | optimizer.load_state_dict(checkpoint['optimizer']) 79 | 80 | confusion_matrix = None 81 | probability_list = [] 82 | loss_list = [] 83 | n_visits = 0 84 | n_pos_codes = 0 85 | n_total_codes = 0 86 | model.eval() 87 | with torch.no_grad(): 88 | for v_i in tqdm(range(0, len(test_ehr_dataset), 2*config.batch_size)): 89 | # Get batch inputs 90 | batch_ehr, batch_mask = get_batch(v_i, 2*config.batch_size, 'test') 91 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 92 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 93 | 94 | # Get batch outputs 95 | test_loss, predictions, labels = model(batch_ehr, position_ids=None, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight) 96 | batch_mask_array = batch_mask.squeeze().cpu().detach().numpy() 97 | rounded_preds = np.around(predictions.squeeze().cpu().detach().numpy()).transpose((2,0,1)) 98 | rounded_preds = rounded_preds + batch_mask_array - 1 # Setting the masked visits to be -1 to be ignored by the 
confusion matrix 99 | rounded_preds = rounded_preds.flatten() 100 | true_values = labels.squeeze().cpu().detach().numpy().transpose((2,0,1)) 101 | true_values = true_values + batch_mask_array - 1 # Setting the masked visits to be -1 to be ignored by the confusion matrix 102 | true_values = true_values.flatten() 103 | 104 | # Append test lost 105 | loss_list.append(test_loss.cpu().detach().numpy()) 106 | 107 | # Add number of visits and codes 108 | n_visits += torch.sum(batch_mask).cpu().item() 109 | n_pos_codes += torch.sum(labels).cpu().item() 110 | n_total_codes += (torch.sum(batch_mask) * config.total_vocab_size).cpu().item() 111 | 112 | # Add confusion matrix 113 | batch_cmatrix = conf_mat(true_values == 1, rounded_preds == 1) 114 | batch_cmatrix[0][0] = torch.sum(batch_mask) * config.total_vocab_size - batch_cmatrix[0][1] - batch_cmatrix[1][0] - batch_cmatrix[1][1] # Remove the masked values 115 | confusion_matrix = batch_cmatrix if confusion_matrix is None else confusion_matrix + batch_cmatrix 116 | 117 | # Calculate and add probabilities 118 | # Note that the masked codes will have probability 1 and be ignored 119 | label_probs = torch.abs(labels - 1.0 + predictions) 120 | log_prob = torch.sum(torch.log(label_probs)).cpu().item() 121 | probability_list.append(log_prob) 122 | 123 | # Save intermediate values in case of error 124 | intermediate = {} 125 | intermediate["Losses"] = loss_list 126 | intermediate["Confusion Matrix"] = confusion_matrix 127 | intermediate["Probabilities"] = probability_list 128 | intermediate["Num Visits"] = n_visits 129 | intermediate["Num Positive Codes"] = n_pos_codes 130 | intermediate["Num Total Codes"] = n_total_codes 131 | pickle.dump(intermediate, open("./results/testing_stats/HALO_intermediate_results.pkl", "wb")) 132 | 133 | #Extract, save, and display test metrics 134 | avg_loss = np.nanmean(loss_list) 135 | tn, fp, fn, tp = confusion_matrix.ravel() 136 | acc = (tn + tp)/(tn+fp+fn+tp) 137 | prc = tp/(tp+fp) 138 | rec = tp/(tp+fn) 139 | f1 = (2 * prc * rec)/(prc + rec) 140 | log_probability = np.sum(probability_list) 141 | pp_visit = np.exp(-log_probability/n_visits) 142 | pp_positive = np.exp(-log_probability/n_pos_codes) 143 | pp_possible = np.exp(-log_probability/n_total_codes) 144 | 145 | metrics_dict = {} 146 | metrics_dict['Test Loss'] = avg_loss 147 | metrics_dict['Confusion Matrix'] = confusion_matrix 148 | metrics_dict['Accuracy'] = acc 149 | metrics_dict['Precision'] = prc 150 | metrics_dict['Recall'] = rec 151 | metrics_dict['F1 Score'] = f1 152 | metrics_dict['Test Log Probability'] = log_probability 153 | metrics_dict['Perplexity Per Visit'] = pp_visit 154 | metrics_dict['Perplexity Per Positive Code'] = pp_positive 155 | metrics_dict['Perplexity Per Possible Code'] = pp_possible 156 | pickle.dump(metrics_dict, open("./results/testing_stats/HALO_Metrics.pkl", "wb")) 157 | 158 | print("Average Test Loss: ", avg_loss) 159 | print("Confusion Matrix: ", confusion_matrix) 160 | print('Accuracy: ', acc) 161 | print('Precision: ', prc) 162 | print('Recall: ', rec) 163 | print('F1 Score: ', f1) 164 | print('Test Log Probability: ', log_probability) 165 | print('Perplexity Per Visit: ', pp_visit) 166 | print('Perplexity Per Positive Code: ', pp_positive) 167 | print('Perplexity Per Possible Code: ', pp_possible) 168 | 169 | def sample_sequence(model, length, context, batch_size, device='cuda', sample=True): 170 | empty = torch.zeros((1,1,config.total_vocab_size), device=device, dtype=torch.float32).repeat(batch_size, 1, 1) 171 | context = 
torch.tensor(context, device=device, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1) 172 | prev = context.unsqueeze(1) 173 | context = None 174 | with torch.no_grad(): 175 | for _ in range(length-1): 176 | prev = model.sample(torch.cat((prev,empty), dim=1), sample) 177 | if torch.sum(torch.sum(prev[:,:,config.code_vocab_size+config.label_vocab_size+1], dim=1).bool().int(), dim=0).item() == batch_size: 178 | break 179 | ehr = prev.cpu().detach().numpy() 180 | prev = None 181 | empty = None 182 | return ehr 183 | 184 | def convert_ehr(ehrs, index_to_code=None): 185 | ehr_outputs = [] 186 | for i in range(len(ehrs)): 187 | ehr = ehrs[i] 188 | ehr_output = [] 189 | labels_output = ehr[1][config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] 190 | if index_to_code is not None: 191 | labels_output = [index_to_code[idx + config.code_vocab_size] for idx in np.nonzero(labels_output)[0]] 192 | for j in range(2, len(ehr)): 193 | visit = ehr[j] 194 | visit_output = [] 195 | indices = np.nonzero(visit)[0] 196 | end = False 197 | for idx in indices: 198 | if idx < config.code_vocab_size: 199 | visit_output.append(index_to_code[idx] if index_to_code is not None else idx) 200 | elif idx == config.code_vocab_size+config.label_vocab_size+1: 201 | end = True 202 | if visit_output != []: 203 | ehr_output.append(visit_output) 204 | if end: 205 | break 206 | ehr_outputs.append({'visits': ehr_output, 'labels': labels_output}) 207 | ehr = None 208 | ehr_output = None 209 | labels_output = None 210 | visit = None 211 | visit_output = None 212 | indices = None 213 | return ehr_outputs 214 | 215 | # Generate Synthetic EHR dataset 216 | synthetic_ehr_dataset = [] 217 | stoken = np.zeros(config.total_vocab_size) 218 | stoken[config.code_vocab_size+config.label_vocab_size] = 1 219 | for i in tqdm(range(0, len(train_ehr_dataset), config.sample_batch_size)): 220 | bs = min([len(train_ehr_dataset)-i, config.sample_batch_size]) 221 | batch_synthetic_ehrs = sample_sequence(model, config.n_ctx, stoken, batch_size=bs, device=device, sample=True) 222 | batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs) 223 | synthetic_ehr_dataset += batch_synthetic_ehrs 224 | 225 | pickle.dump(synthetic_ehr_dataset, open(f'./results/datasets/haloDataset.pkl', 'wb')) -------------------------------------------------------------------------------- /baselines/lstm/test_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import numpy as np 5 | from tqdm import tqdm 6 | from sklearn import metrics 7 | from lstm import LSTMBaseline 8 | from config import LSTMConfig 9 | 10 | SEED = 4 11 | random.seed(SEED) 12 | np.random.seed(SEED) 13 | torch.manual_seed(SEED) 14 | config = LSTMConfig() 15 | 16 | local_rank = -1 17 | fp16 = False 18 | if local_rank == -1: 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | n_gpu = torch.cuda.device_count() 21 | else: 22 | torch.cuda.set_device(local_rank) 23 | device = torch.device("cuda", local_rank) 24 | n_gpu = 1 25 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 26 | torch.distributed.init_process_group(backend='nccl') 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed_all(SEED) 29 | 30 | train_ehr_dataset = pickle.load(open('../../data/trainDataset.pkl', 'rb')) 31 | test_ehr_dataset = pickle.load(open('../../data/testDataset.pkl', 'rb')) 32 | index_to_code = pickle.load(open("../../data/indexToCode.pkl", 
"rb")) 33 | id_to_label = pickle.load(open("../../data/idToLabel.pkl", "rb")) 34 | train_c = set([c for p in train_ehr_dataset for v in p['visits'] for c in v]) 35 | test_ehr_dataset = [{'labels': p['labels'], 'visits': [[c for c in v if c in train_c] for v in p['visits']]} for p in test_ehr_dataset] 36 | 37 | # Add the labels to the index_to_code mapping 38 | for k, l in id_to_label.items(): 39 | index_to_code[config.code_vocab_size+k] = f"Chronic Condition: {l}" 40 | 41 | def get_batch(loc, batch_size, mode): 42 | # EHR data saved as [(P_1, L_1), (P_2, L_2), ... , (P_i, L_i)] 43 | # Where each patient P is [V_1, V_2, ... , V_j] 44 | # Where each visit V is [C_1, C_2, ... , C_k] 45 | # And where each Label L is a binary vector [L_1 ... L_n] 46 | if mode == 'train': 47 | ehr = train_ehr_dataset[loc:loc+batch_size] 48 | elif mode == 'valid': 49 | ehr = val_ehr_dataset[loc:loc+batch_size] 50 | else: 51 | ehr = test_ehr_dataset[loc:loc+batch_size] 52 | 53 | batch_ehr = np.zeros((len(ehr), config.n_ctx, config.total_vocab_size)) 54 | batch_mask = np.zeros((len(ehr), config.n_ctx, 1)) 55 | for i, p in enumerate(ehr): 56 | visits = p['visits'] 57 | for j, v in enumerate(visits): 58 | batch_ehr[i,j+2][v] = 1 59 | batch_mask[i,j+2] = 1 60 | batch_ehr[i,1,config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] = np.array(p['labels']) # Set the patient labels 61 | batch_ehr[i,len(visits)+1,config.code_vocab_size+config.label_vocab_size+1] = 1 # Set the final visit to have the end token 62 | batch_ehr[i,len(visits)+2:,config.code_vocab_size+config.label_vocab_size+2] = 1 # Set the rest to the padded visit token 63 | 64 | batch_mask[:,1] = 1 # Set the mask to cover the labels 65 | batch_ehr[:,0,config.code_vocab_size+config.label_vocab_size] = 1 # Set the first visits to be the start token 66 | batch_mask = batch_mask[:,1:,:] # Shift the mask to match the shifted labels and predictions the model will return 67 | return batch_ehr, batch_mask 68 | 69 | def conf_mat(x, y): 70 | totaltrue = np.sum(x) 71 | totalfalse = len(x) - totaltrue 72 | truepos, totalpos = np.sum(x & y), np.sum(y) 73 | falsepos = totalpos - truepos 74 | return np.array([[totalfalse - falsepos, falsepos], #true negatives, false positives 75 | [totaltrue - truepos, truepos]]) #false negatives, true positives 76 | 77 | model = LSTMBaseline(config).to(device) 78 | optimizer = torch.optim.Adam(model.parameters(), lr=config.lr) 79 | 80 | checkpoint = torch.load('../../save/lstm_model', map_location=torch.device(device)) 81 | model.load_state_dict(checkpoint['model']) 82 | optimizer.load_state_dict(checkpoint['optimizer']) 83 | 84 | confusion_matrix = None 85 | probability_list = [] 86 | loss_list = [] 87 | n_visits = 0 88 | n_pos_codes = 0 89 | n_total_codes = 0 90 | model.eval() 91 | with torch.no_grad(): 92 | for v_i in tqdm(range(0, len(test_ehr_dataset), config.batch_size)): 93 | # Get batch inputs 94 | batch_ehr, batch_mask = get_batch(v_i, config.batch_size, 'test') 95 | batch_ehr = torch.tensor(batch_ehr, dtype=torch.float32).to(device) 96 | batch_mask = torch.tensor(batch_mask, dtype=torch.float32).to(device) 97 | 98 | # Get batch outputs 99 | test_loss, predictions, labels = model(batch_ehr, ehr_labels=batch_ehr, ehr_masks=batch_mask, pos_loss_weight=config.pos_loss_weight) 100 | batch_mask_array = batch_mask.squeeze().cpu().detach().numpy() 101 | rounded_preds = np.around(predictions.squeeze().cpu().detach().numpy()).transpose((2,0,1)) 102 | rounded_preds = rounded_preds + batch_mask_array - 1 # Setting the masked 
visits to be -1 to be ignored by the confusion matrix 103 | rounded_preds = rounded_preds.flatten() 104 | true_values = labels.squeeze().cpu().detach().numpy().transpose((2,0,1)) 105 | true_values = true_values + batch_mask_array - 1 # Setting the masked visits to be -1 to be ignored by the confusion matrix 106 | true_values = true_values.flatten() 107 | 108 | # Append test lost 109 | loss_list.append(test_loss.cpu().detach().numpy()) 110 | 111 | # Add confusion matrix 112 | batch_cmatrix = conf_mat(true_values == 1, rounded_preds == 1) 113 | batch_cmatrix[0][0] = torch.sum(batch_mask) * config.total_vocab_size - batch_cmatrix[0][1] - batch_cmatrix[1][0] - batch_cmatrix[1][1] # Remove the masked values 114 | confusion_matrix = batch_cmatrix if confusion_matrix is None else confusion_matrix + batch_cmatrix 115 | 116 | # Add number of visits and codes 117 | n_visits += torch.sum(batch_mask) 118 | n_pos_codes += torch.sum(labels) 119 | n_total_codes += torch.sum(batch_mask) * config.total_vocab_size 120 | 121 | # Calculate and add probabilities 122 | # Note that the masked codes will have probability 1 and be ignored 123 | label_probs = torch.abs(labels - 1.0 + predictions) 124 | log_prob = torch.sum(torch.log(label_probs)).cpu().item() 125 | probability_list.append(log_prob) 126 | 127 | n_visits = n_visits.cpu().item() 128 | n_pos_codes = n_pos_codes.cpu().item() 129 | n_total_codes = n_total_codes.cpu().item() 130 | 131 | # Save intermediate values in case of error 132 | intermediate = {} 133 | intermediate["Losses"] = loss_list 134 | intermediate["Confusion Matrix"] = confusion_matrix 135 | intermediate["Probabilities"] = probability_list 136 | intermediate["Num Visits"] = n_visits 137 | intermediate["Num Positive Codes"] = n_pos_codes 138 | intermediate["Num Total Codes"] = n_total_codes 139 | pickle.dump(intermediate, open("../../results/testing_stats/LSTMBaseline_intermediate_results.pkl", "wb")) 140 | 141 | #Extract, save, and display test metrics 142 | avg_loss = np.mean(loss_list) 143 | tn, fp, fn, tp = confusion_matrix.ravel() 144 | acc = (tn + tp)/(tn+fp+fn+tp) 145 | prc = tp/(tp+fp) 146 | rec = tp/(tp+fn) 147 | f1 = (2 * prc * rec)/(prc + rec) 148 | log_probability = np.sum(probability_list) 149 | pp_visit = np.exp(-log_probability/n_visits) 150 | pp_positive = np.exp(-log_probability/n_pos_codes) 151 | pp_possible = np.exp(-log_probability/n_total_codes) 152 | 153 | metrics_dict = {} 154 | metrics_dict['Test Loss'] = avg_loss 155 | metrics_dict['Confusion Matrix'] = confusion_matrix 156 | metrics_dict['Accuracy'] = acc 157 | metrics_dict['Precision'] = prc 158 | metrics_dict['Recall'] = rec 159 | metrics_dict['F1 Score'] = f1 160 | metrics_dict['Test Log Probability'] = log_probability 161 | metrics_dict['Perplexity Per Visit'] = pp_visit 162 | metrics_dict['Perplexity Per Positive Code'] = pp_positive 163 | metrics_dict['Perplexity Per Possible Code'] = pp_possible 164 | pickle.dump(metrics_dict, open("../../results/testing_stats/LSTMBaseline_Metrics.pkl", "wb")) 165 | 166 | print("Average Test Loss: ", avg_loss) 167 | print("Confusion Matrix: ", confusion_matrix) 168 | print('Accuracy: ', acc) 169 | print('Precision: ', prc) 170 | print('Recall: ', rec) 171 | print('F1 Score: ', f1) 172 | print('Test Log Probability: ', log_probability) 173 | print('Perplexity Per Visit: ', pp_visit) 174 | print('Perplexity Per Positive Code: ', pp_positive) 175 | print('Perplexity Per Possible Code: ', pp_possible) 176 | 177 | def sample_sequence(model, length, context, batch_size=None, 
device='cuda', sample=True): 178 | context = torch.tensor(context, device=device, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1) 179 | prev = context.unsqueeze(1) 180 | with torch.no_grad(): 181 | for i in range(length-1): 182 | code_probs = model(prev) 183 | code_probs = code_probs[:, -1, :].unsqueeze(1) 184 | if sample: 185 | visit = torch.bernoulli(code_probs) 186 | else: 187 | visit = torch.round(code_probs) 188 | prev = torch.cat((prev, visit), dim=1) 189 | ehr = prev.cpu().detach().numpy() 190 | visit = None 191 | prev = None 192 | return ehr 193 | 194 | def convert_ehr(ehrs, index_to_code=None): 195 | ehr_outputs = [] 196 | for i in range(len(ehrs)): 197 | ehr = ehrs[i] 198 | ehr_output = [] 199 | labels_output = ehr[1][config.code_vocab_size:config.code_vocab_size+config.label_vocab_size] 200 | if index_to_code is not None: 201 | labels_output = [index_to_code[idx + config.code_vocab_size] for idx in np.nonzero(labels_output)[0]] 202 | for j in range(2, len(ehr)): 203 | visit = ehr[j] 204 | visit_output = [] 205 | indices = np.nonzero(visit)[0] 206 | end = False 207 | for idx in indices: 208 | if idx < config.code_vocab_size: 209 | visit_output.append(index_to_code[idx] if index_to_code is not None else idx) 210 | elif idx == config.code_vocab_size+config.label_vocab_size+1: 211 | end = True 212 | if visit_output != []: 213 | ehr_output.append(visit_output) 214 | if end: 215 | break 216 | ehr_outputs.append({'visits': ehr_output, 'labels': labels_output}) 217 | ehr = None 218 | ehr_output = None 219 | labels_output = None 220 | visit = None 221 | visit_output = None 222 | indices = None 223 | return ehr_outputs 224 | 225 | # Generate Synthetic EHR dataset 226 | synthetic_ehr_dataset = [] 227 | stoken = np.zeros(config.total_vocab_size) 228 | stoken[config.code_vocab_size+config.label_vocab_size] = 1 229 | for i in tqdm(range(0, len(train_ehr_dataset), config.batch_size)): 230 | bs = min([len(train_ehr_dataset)-i, config.batch_size]) 231 | batch_synthetic_ehrs = sample_sequence(model, config.n_ctx, stoken, batch_size=bs, device=device, sample=True) 232 | batch_synthetic_ehrs = convert_ehr(batch_synthetic_ehrs) 233 | synthetic_ehr_dataset += batch_synthetic_ehrs 234 | 235 | pickle.dump(synthetic_ehr_dataset, open(f'../../results/datasets/lstmDataset.pkl', 'wb')) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | code by Brandon Theodorou 3 | Original GPT-2 Paper and repository here: https://github.com/openai/gpt-2 4 | Original GPT-2 Pytorch Model: https://github.com/huggingface/pytorch-pretrained-BERT 5 | GPT-2 Pytorch Model Derived From: https://github.com/graykode/gpt-2-Pytorch 6 | ''' 7 | import copy 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | def gelu(x): 14 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 15 | 16 | class LayerNorm(nn.Module): 17 | def __init__(self, hidden_size, eps=1e-12): 18 | """Construct a layernorm module in the TF style (epsilon inside the square root).""" 19 | super(LayerNorm, self).__init__() 20 | self.weight = nn.Parameter(torch.ones(hidden_size)) 21 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 22 | self.variance_epsilon = eps 23 | 24 | def forward(self, x): 25 | u = x.mean(-1, keepdim=True) 26 | s = (x - u).pow(2).mean(-1, keepdim=True) 27 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 28 | 
return self.weight * x + self.bias 29 | 30 | class Conv1D(nn.Module): 31 | def __init__(self, nf, nx): 32 | super(Conv1D, self).__init__() 33 | self.nf = nf 34 | w = torch.empty(nx, nf) 35 | nn.init.normal_(w, std=0.02) 36 | self.weight = nn.Parameter(w) 37 | self.bias = nn.Parameter(torch.zeros(nf)) 38 | 39 | def forward(self, x): 40 | size_out = x.size()[:-1] + (self.nf,) 41 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 42 | x = x.view(*size_out) 43 | return x 44 | 45 | class Attention(nn.Module): 46 | def __init__(self, nx, n_ctx, config, scale=False): 47 | super(Attention, self).__init__() 48 | n_state = nx # in Attention: n_state=768 (nx=n_embd) 49 | assert n_state % config.n_head == 0 50 | self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)) 51 | self.n_head = config.n_head 52 | self.split_size = n_state 53 | self.scale = scale 54 | self.c_attn = Conv1D(n_state * 3, nx) 55 | self.c_proj = Conv1D(n_state, nx) 56 | 57 | def _attn(self, q, k, v): 58 | w = torch.matmul(q, k) 59 | if self.scale: 60 | w = w / math.sqrt(v.size(-1)) 61 | nd, ns = w.size(-2), w.size(-1) 62 | b = self.bias[:, :, ns-nd:ns, :ns] 63 | w = w * b - 1e10 * (1 - b) 64 | w = nn.Softmax(dim=-1)(w) 65 | return torch.matmul(w, v) 66 | 67 | def merge_heads(self, x): 68 | x = x.permute(0, 2, 1, 3).contiguous() 69 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 70 | return x.view(*new_x_shape) 71 | 72 | def split_heads(self, x, k=False): 73 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 74 | x = x.view(*new_x_shape) 75 | if k: 76 | return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length) 77 | else: 78 | return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) 79 | 80 | def forward(self, x, layer_past=None): 81 | x = self.c_attn(x) 82 | query, key, value = x.split(self.split_size, dim=2) 83 | query = self.split_heads(query) 84 | key = self.split_heads(key, k=True) 85 | value = self.split_heads(value) 86 | if layer_past is not None: 87 | past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below 88 | key = torch.cat((past_key, key), dim=-1) 89 | value = torch.cat((past_value, value), dim=-2) 90 | present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking 91 | a = self._attn(query, key, value) 92 | a = self.merge_heads(a) 93 | a = self.c_proj(a) 94 | return a, present 95 | 96 | class MLP(nn.Module): 97 | def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd) 98 | super(MLP, self).__init__() 99 | nx = config.n_embd 100 | self.c_fc = Conv1D(n_state, nx) 101 | self.c_proj = Conv1D(nx, n_state) 102 | self.act = gelu 103 | 104 | def forward(self, x): 105 | h = self.act(self.c_fc(x)) 106 | h2 = self.c_proj(h) 107 | return h2 108 | 109 | class Block(nn.Module): 110 | def __init__(self, n_ctx, config, scale=False): 111 | super(Block, self).__init__() 112 | nx = config.n_embd 113 | self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) 114 | self.attn = Attention(nx, n_ctx, config, scale) 115 | self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) 116 | self.mlp = MLP(4 * nx, config) 117 | 118 | def forward(self, x, layer_past=None): 119 | a, present = self.attn(self.ln_1(x), layer_past=layer_past) 120 | x = x + a 121 | m = self.mlp(self.ln_2(x)) 122 | x = x + m 123 | return x, present 124 | 125 | class CoarseTransformerModel(nn.Module): 126 | def __init__(self, config): 127 | super(CoarseTransformerModel, 
self).__init__() 128 | self.n_layer = config.n_layer 129 | self.n_embd = config.n_embd 130 | self.n_vocab = config.total_vocab_size 131 | 132 | self.vis_embed_mat = nn.Linear(config.total_vocab_size, config.n_embd, bias=False) 133 | self.pos_embed_mat = nn.Embedding(config.n_positions, config.n_embd) 134 | block = Block(config.n_ctx, config, scale=True) 135 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)]) 136 | self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) 137 | 138 | def forward(self, input_visits, position_ids=None, past=None): 139 | if past is None: 140 | past_length = 0 141 | past = [None] * len(self.h) 142 | else: 143 | past_length = past[0][0].size(-2) 144 | if position_ids is None: 145 | position_ids = torch.arange(past_length, input_visits.size(1) + past_length, dtype=torch.long, 146 | device=input_visits.device) 147 | position_ids = position_ids.unsqueeze(0).expand(input_visits.size(0), input_visits.size(1)) 148 | 149 | inputs_embeds = self.vis_embed_mat(input_visits) 150 | position_embeds = self.pos_embed_mat(position_ids) 151 | hidden_states = inputs_embeds + position_embeds 152 | for block, layer_past in zip(self.h, past): 153 | hidden_states, _ = block(hidden_states, layer_past) 154 | hidden_states = self.ln_f(hidden_states) 155 | return hidden_states 156 | 157 | class AutoregressiveLinear(nn.Linear): 158 | """ same as Linear except has a configurable mask on the weights """ 159 | def __init__(self, in_features, out_features, bias=True): 160 | super().__init__(in_features, out_features, bias) 161 | self.register_buffer('mask', torch.tril(torch.ones(in_features, out_features)).int()) 162 | 163 | def forward(self, input): 164 | return F.linear(input, self.mask * self.weight, self.bias) 165 | 166 | class FineAutoregressiveHead(nn.Module): 167 | def __init__(self, config): 168 | super(FineAutoregressiveHead, self).__init__() 169 | self.auto1 = AutoregressiveLinear(config.n_embd + config.total_vocab_size, config.n_embd + config.total_vocab_size) 170 | self.auto2 = AutoregressiveLinear(config.n_embd + config.total_vocab_size, config.n_embd + config.total_vocab_size) 171 | self.n_embd = config.n_embd 172 | self.tot_vocab = config.total_vocab_size 173 | 174 | def forward(self, history, input_visits): 175 | history = history[:,:-1,:] 176 | input_visits = input_visits[:,1:,:] 177 | code_logits = self.auto2(torch.relu(self.auto1(torch.cat((history, input_visits), dim=2))))[:,:,self.n_embd-1:-1] 178 | return code_logits 179 | 180 | def sample(self, history, input_visits): 181 | history = history[:,:-1,:] 182 | input_visits = input_visits[:,1:,:] 183 | currVisit = torch.cat((history, input_visits), dim=2)[:,-1,:].unsqueeze(1) 184 | code_logits = self.auto2(torch.relu(self.auto1(currVisit)))[:,:,self.n_embd-1:-1] 185 | return code_logits 186 | 187 | class HALOModel(nn.Module): 188 | def __init__(self, config): 189 | super(HALOModel, self).__init__() 190 | self.transformer = CoarseTransformerModel(config) 191 | self.ehr_head = FineAutoregressiveHead(config) 192 | 193 | def forward(self, input_visits, position_ids=None, ehr_labels=None, ehr_masks=None, past=None, pos_loss_weight=None): 194 | hidden_states = self.transformer(input_visits, position_ids, past) 195 | code_logits = self.ehr_head(hidden_states, input_visits) 196 | sig = nn.Sigmoid() 197 | code_probs = sig(code_logits) 198 | if ehr_labels is not None: 199 | shift_labels = ehr_labels[..., 1:, :].contiguous() 200 | loss_weights = None 201 | if pos_loss_weight is not None: 202 | 
loss_weights = torch.ones(code_probs.shape, device=code_probs.device) 203 | loss_weights = loss_weights + (pos_loss_weight-1) * shift_labels 204 | if ehr_masks is not None: 205 | code_probs = code_probs * ehr_masks 206 | shift_labels = shift_labels * ehr_masks 207 | if pos_loss_weight is not None: 208 | loss_weights = loss_weights * ehr_masks 209 | 210 | bce = nn.BCELoss(weight=loss_weights) 211 | loss = bce(code_probs, shift_labels) 212 | return loss, code_probs, shift_labels 213 | 214 | return code_probs 215 | 216 | def sample(self, input_visits, random=True): 217 | sig = nn.Sigmoid() 218 | hidden_states = self.transformer(input_visits) 219 | i = 0 220 | while i < self.ehr_head.tot_vocab: 221 | next_logits = self.ehr_head.sample(hidden_states, input_visits) 222 | next_probs = sig(next_logits) 223 | if random: 224 | visit = torch.bernoulli(next_probs) 225 | else: 226 | visit = torch.round(next_probs) 227 | 228 | remaining_visit = visit[:,0,i:] 229 | nonzero = torch.nonzero(remaining_visit, as_tuple=True)[1] 230 | if nonzero.numel() == 0: 231 | break 232 | 233 | first_nonzero = nonzero.min() 234 | input_visits[:,-1,i + first_nonzero] = visit[:,0,i + first_nonzero] 235 | i = i + first_nonzero + 1 236 | 237 | return input_visits --------------------------------------------------------------------------------
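For reference, the generation loop that test_model.py wraps in sample_sequence can be exercised on its own. Below is a minimal sketch, assuming the checkpoint path ('./save/halo_model2') and the HALOConfig fields used in test_model.py; the batch size of 8 is only an illustration (test_model.py uses config.sample_batch_size), and the raw multi-hot output still needs convert_ehr from test_model.py to be mapped back into coded visits.

import torch
import numpy as np
from model import HALOModel
from config import HALOConfig

config = HALOConfig()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained model (path taken from test_model.py; adjust as needed)
model = HALOModel(config).to(device)
checkpoint = torch.load('./save/halo_model2', map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()

# The start token is a multi-hot visit vector with only the start position set
stoken = np.zeros(config.total_vocab_size)
stoken[config.code_vocab_size + config.label_vocab_size] = 1

batch_size = 8  # illustrative only
prev = torch.tensor(stoken, device=device, dtype=torch.float32).unsqueeze(0).repeat(batch_size, 1).unsqueeze(1)
empty = torch.zeros((1, 1, config.total_vocab_size), device=device, dtype=torch.float32).repeat(batch_size, 1, 1)
end_idx = config.code_vocab_size + config.label_vocab_size + 1  # end-of-record token

with torch.no_grad():
    for _ in range(config.n_ctx - 1):
        # model.sample fills in the appended empty visit one code at a time
        prev = model.sample(torch.cat((prev, empty), dim=1), random=True)
        # stop early once every record in the batch has emitted the end token
        if prev[:, :, end_idx].sum(dim=1).bool().sum().item() == batch_size:
            break

synthetic = prev.cpu().numpy()  # shape: (batch, num_visits, total_vocab_size)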