├── .gitignore ├── README.md ├── data └── icd9.txt ├── metrics.py ├── models ├── __init__.py ├── layers.py ├── model.py └── utils.py ├── preprocess ├── __init__.py ├── auxiliary.py ├── build_dataset.py ├── encode.py └── parse_csv.py ├── requirements.txt ├── run_preprocess.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # IDEA 132 | .idea/ 133 | 134 | # Data 135 | /data/mimic3/ 136 | 137 | # Baselines 138 | /baselines/ 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Context-Aware-Healthcare 2 | 3 | Codes for AAAI 2022 paper: *Context-aware Health Event Prediction via Transition Functions on Dynamic Disease Graphs* 4 | 5 | ## Download the MIMIC-III and MIMIC-IV datasets 6 | Go to [https://mimic.physionet.org/](https://mimic.physionet.org/gettingstarted/access/) for access. 
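Only a handful of the released CSV files are actually read by the preprocessing code. A sketch of the expected raw layout, with file names taken from `preprocess/parse_csv.py` (note that the ICD-10-to-ICD-9 mapping `icd10-icd9.csv` used for MIMIC-IV is not part of the MIMIC download and has to be supplied separately):

```
data/mimic3/raw/
├── ADMISSIONS.csv
└── DIAGNOSES_ICD.csv
data/mimic4/raw/
├── admissions.csv
├── diagnoses_icd.csv
├── patients.csv
└── icd10-icd9.csv
```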
Once you have the authority for the dataset, download the dataset and extract the csv files to `data/mimic3/raw/` and `data/mimic4/raw/` in this project. 7 | 8 | ## Preprocess 9 | ```bash 10 | python run_preprocess.py 11 | ``` 12 | 13 | ## Train model 14 | ```bash 15 | python train.py 16 | ``` 17 | 18 | ## Configuration 19 | Please see `train.py` for detailed configurations. 20 | -------------------------------------------------------------------------------- /data/icd9.txt: -------------------------------------------------------------------------------- 1 | 001-139 2 | 001-009 3 | 010-018 4 | 020-027 5 | 030-041 6 | 042-042 7 | 045-049 8 | 050-059 9 | 060-066 10 | 070-079 11 | 080-088 12 | 090-099 13 | 100-104 14 | 110-118 15 | 120-129 16 | 130-136 17 | 137-139 18 | 140-239 19 | 140-149 20 | 150-159 21 | 160-165 22 | 170-176 23 | 179-189 24 | 190-199 25 | 200-209 26 | 210-229 27 | 230-234 28 | 235-238 29 | 239-239 30 | 240-279 31 | 240-246 32 | 249-259 33 | 260-269 34 | 270-279 35 | 280-289 36 | 280 37 | 281 38 | 282 39 | 283 40 | 284 41 | 285 42 | 286 43 | 287 44 | 288 45 | 289 46 | 290-319 47 | 290-294 48 | 295-299 49 | 300-316 50 | 317-319 51 | 320-389 52 | 320-327 53 | 330-337 54 | 338-338 55 | 339-339 56 | 340-349 57 | 350-359 58 | 360-379 59 | 380-389 60 | 390-459 61 | 390-392 62 | 393-398 63 | 401-405 64 | 410-414 65 | 415-417 66 | 420-429 67 | 430-438 68 | 440-449 69 | 451-459 70 | 460-519 71 | 460-466 72 | 470-478 73 | 480-488 74 | 490-496 75 | 500-508 76 | 510-519 77 | 520-579 78 | 520-529 79 | 530-539 80 | 540-543 81 | 550-553 82 | 555-558 83 | 560-569 84 | 570-579 85 | 580-629 86 | 580-589 87 | 590-599 88 | 600-608 89 | 610-612 90 | 614-616 91 | 617-629 92 | 630-679 93 | 630-639 94 | 640-649 95 | 650-659 96 | 660-669 97 | 670-677 98 | 678-679 99 | 680-709 100 | 680-686 101 | 690-698 102 | 700-709 103 | 710-739 104 | 710-719 105 | 720-724 106 | 725-729 107 | 730-739 108 | 740-759 109 | 740 110 | 741 111 | 742 112 | 743 113 | 744 114 | 745 115 | 746 116 | 747 117 | 748 118 | 749 119 | 750 120 | 751 121 | 752 122 | 753 123 | 754 124 | 755 125 | 756 126 | 757 127 | 758 128 | 759 129 | 760-779 130 | 760-763 131 | 764-779 132 | 780-799 133 | 780-789 134 | 790-796 135 | 797-799 136 | 800-999 137 | 800-804 138 | 805-809 139 | 810-819 140 | 820-829 141 | 830-839 142 | 840-848 143 | 850-854 144 | 860-869 145 | 870-879 146 | 880-887 147 | 890-897 148 | 900-904 149 | 905-909 150 | 910-919 151 | 920-924 152 | 925-929 153 | 930-939 154 | 940-949 155 | 950-957 156 | 958-959 157 | 960-979 158 | 980-989 159 | 990-995 160 | 996-999 161 | V01-V9 162 | V01-V09 163 | V10-V19 164 | V20-V29 165 | V30-V39 166 | V40-V49 167 | V50-V59 168 | V60-V69 169 | V70-V82 170 | V83-V84 171 | V85-V85 172 | V86-V86 173 | V87-V87 174 | V88-V88 175 | V89-V89 176 | V90-V90 177 | V91-V91 178 | E000-E999 179 | E000-E000 180 | E001-E030 181 | E800-E807 182 | E810-E819 183 | E820-E825 184 | E826-E829 185 | E830-E838 186 | E840-E845 187 | E846-E849 188 | E850-E858 189 | E860-E869 190 | E870-E876 191 | E878-E879 192 | E880-E888 193 | E890-E899 194 | E900-E909 195 | E910-E915 196 | E916-E928 197 | E929-E929 198 | E930-E949 199 | E950-E959 200 | E960-E969 201 | E970-E979 202 | E980-E989 203 | E990-E999 -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import f1_score, roc_auc_score 4 | 5 | 6 | def f1(y_true_hot, y_pred, 
metrics='weighted'): 7 | result = np.zeros_like(y_true_hot) 8 | for i in range(len(result)): 9 | true_number = np.sum(y_true_hot[i] == 1) 10 | result[i][y_pred[i][:true_number]] = 1 11 | return f1_score(y_true=y_true_hot, y_pred=result, average=metrics, zero_division=0) 12 | 13 | 14 | def top_k_prec_recall(y_true_hot, y_pred, ks): 15 | a = np.zeros((len(ks),)) 16 | r = np.zeros((len(ks),)) 17 | for pred, true_hot in zip(y_pred, y_true_hot): 18 | true = np.where(true_hot == 1)[0].tolist() 19 | t = set(true) 20 | for i, k in enumerate(ks): 21 | p = set(pred[:k]) 22 | it = p.intersection(t) 23 | a[i] += len(it) / k 24 | # r[i] += len(it) / min(k, len(t)) 25 | r[i] += len(it) / len(t) 26 | return a / len(y_true_hot), r / len(y_true_hot) 27 | 28 | 29 | def calculate_occurred(historical, y, preds, ks): 30 | # y_occurred = np.sum(np.logical_and(historical, y), axis=-1) 31 | # y_prec = np.mean(y_occurred / np.sum(y, axis=-1)) 32 | r1 = np.zeros((len(ks), )) 33 | r2 = np.zeros((len(ks),)) 34 | n = np.sum(y, axis=-1) 35 | for i, k in enumerate(ks): 36 | # n_k = np.minimum(n, k) 37 | n_k = n 38 | pred_k = np.zeros_like(y) 39 | for T in range(len(pred_k)): 40 | pred_k[T][preds[T][:k]] = 1 41 | # pred_occurred = np.sum(np.logical_and(historical, pred_k), axis=-1) 42 | pred_occurred = np.logical_and(historical, pred_k) 43 | pred_not_occurred = np.logical_and(np.logical_not(historical), pred_k) 44 | pred_occurred_true = np.logical_and(pred_occurred, y) 45 | pred_not_occurred_true = np.logical_and(pred_not_occurred, y) 46 | r1[i] = np.mean(np.sum(pred_occurred_true, axis=-1) / n_k) 47 | r2[i] = np.mean(np.sum(pred_not_occurred_true, axis=-1) / n_k) 48 | return r1, r2 49 | 50 | 51 | def evaluate_codes(model, dataset, loss_fn, output_size, historical=None): 52 | model.eval() 53 | total_loss = 0.0 54 | labels = dataset.label() 55 | preds = [] 56 | for step in range(len(dataset)): 57 | code_x, visit_lens, divided, y, neighbors = dataset[step] 58 | output = model(code_x, divided, neighbors, visit_lens) 59 | pred = torch.argsort(output, dim=-1, descending=True) 60 | preds.append(pred) 61 | loss = loss_fn(output, y) 62 | total_loss += loss.item() * output_size * len(code_x) 63 | print('\r Evaluating step %d / %d' % (step + 1, len(dataset)), end='') 64 | avg_loss = total_loss / dataset.size() 65 | preds = torch.vstack(preds).detach().cpu().numpy() 66 | f1_score = f1(labels, preds) 67 | prec, recall = top_k_prec_recall(labels, preds, ks=[10, 20, 30, 40]) 68 | if historical is not None: 69 | r1, r2 = calculate_occurred(historical, labels, preds, ks=[10, 20, 30, 40]) 70 | print('\r Evaluation: loss: %.4f --- f1_score: %.4f --- top_k_recall: %.4f, %.4f, %.4f, %.4f --- occurred: %.4f, %.4f, %.4f, %.4f --- not occurred: %.4f, %.4f, %.4f, %.4f' 71 | % (avg_loss, f1_score, recall[0], recall[1], recall[2], recall[3], r1[0], r1[1], r1[2], r1[3], r2[0], r2[1], r2[2], r2[3])) 72 | else: 73 | print('\r Evaluation: loss: %.4f --- f1_score: %.4f --- top_k_recall: %.4f, %.4f, %.4f, %.4f' 74 | % (avg_loss, f1_score, recall[0], recall[1], recall[2], recall[3])) 75 | return avg_loss, f1_score 76 | 77 | 78 | def evaluate_hf(model, dataset, loss_fn, output_size=1, historical=None): 79 | model.eval() 80 | total_loss = 0.0 81 | labels = dataset.label() 82 | outputs = [] 83 | preds = [] 84 | for step in range(len(dataset)): 85 | code_x, visit_lens, divided, y, neighbors = dataset[step] 86 | output = model(code_x, divided, neighbors, visit_lens).squeeze() 87 | loss = loss_fn(output, y) 88 | total_loss += loss.item() * output_size * 
len(code_x) 89 | output = output.detach().cpu().numpy() 90 | outputs.append(output) 91 | pred = (output > 0.5).astype(int) 92 | preds.append(pred) 93 | print('\r Evaluating step %d / %d' % (step + 1, len(dataset)), end='') 94 | avg_loss = total_loss / dataset.size() 95 | outputs = np.concatenate(outputs) 96 | preds = np.concatenate(preds) 97 | auc = roc_auc_score(labels, outputs) 98 | f1_score_ = f1_score(labels, preds) 99 | print('\r Evaluation: loss: %.4f --- auc: %.4f --- f1_score: %.4f' % (avg_loss, auc, f1_score_)) 100 | return avg_loss, f1_score_ 101 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LuChang-CS/Chet/2cec030c09207b99023fdcd53327e8c6dc488ff4/models/__init__.py -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from models.utils import SingleHeadAttentionLayer 5 | 6 | 7 | class EmbeddingLayer(nn.Module): 8 | def __init__(self, code_num, code_size, graph_size): 9 | super().__init__() 10 | self.code_num = code_num 11 | self.c_embeddings = nn.Parameter(data=nn.init.xavier_uniform_(torch.empty(code_num, code_size))) 12 | self.n_embeddings = nn.Parameter(data=nn.init.xavier_uniform_(torch.empty(code_num, code_size))) 13 | self.u_embeddings = nn.Parameter(data=nn.init.xavier_uniform_(torch.empty(code_num, graph_size))) 14 | 15 | def forward(self): 16 | return self.c_embeddings, self.n_embeddings, self.u_embeddings 17 | 18 | 19 | class GraphLayer(nn.Module): 20 | def __init__(self, adj, code_size, graph_size): 21 | super().__init__() 22 | self.adj = adj 23 | self.dense = nn.Linear(code_size, graph_size) 24 | self.activation = nn.LeakyReLU() 25 | 26 | def forward(self, code_x, neighbor, c_embeddings, n_embeddings): 27 | center_codes = torch.unsqueeze(code_x, dim=-1) 28 | neighbor_codes = torch.unsqueeze(neighbor, dim=-1) 29 | 30 | center_embeddings = center_codes * c_embeddings 31 | neighbor_embeddings = neighbor_codes * n_embeddings 32 | cc_embeddings = center_codes * torch.matmul(self.adj, center_embeddings) 33 | cn_embeddings = center_codes * torch.matmul(self.adj, neighbor_embeddings) 34 | nn_embeddings = neighbor_codes * torch.matmul(self.adj, neighbor_embeddings) 35 | nc_embeddings = neighbor_codes * torch.matmul(self.adj, center_embeddings) 36 | 37 | co_embeddings = self.activation(self.dense(center_embeddings + cc_embeddings + cn_embeddings)) 38 | no_embeddings = self.activation(self.dense(neighbor_embeddings + nn_embeddings + nc_embeddings)) 39 | return co_embeddings, no_embeddings 40 | 41 | 42 | class TransitionLayer(nn.Module): 43 | def __init__(self, code_num, graph_size, hidden_size, t_attention_size, t_output_size): 44 | super().__init__() 45 | self.gru = nn.GRUCell(input_size=graph_size, hidden_size=hidden_size) 46 | self.single_head_attention = SingleHeadAttentionLayer(graph_size, graph_size, t_output_size, t_attention_size) 47 | self.activation = nn.Tanh() 48 | 49 | self.code_num = code_num 50 | self.hidden_size = hidden_size 51 | 52 | def forward(self, t, co_embeddings, divided, no_embeddings, unrelated_embeddings, hidden_state=None): 53 | m1, m2, m3 = divided[:, 0], divided[:, 1], divided[:, 2] 54 | m1_index = torch.where(m1 > 0)[0] 55 | m2_index = torch.where(m2 > 0)[0] 56 | m3_index = torch.where(m3 > 0)[0] 57 | 
h_new = torch.zeros((self.code_num, self.hidden_size), dtype=co_embeddings.dtype).to(co_embeddings.device) 58 | output_m1 = 0 59 | output_m23 = 0 60 | if len(m1_index) > 0: 61 | m1_embedding = co_embeddings[m1_index] 62 | h = hidden_state[m1_index] if hidden_state is not None else None 63 | h_m1 = self.gru(m1_embedding, h) 64 | h_new[m1_index] = h_m1 65 | output_m1, _ = torch.max(h_m1, dim=-2) 66 | if t > 0 and len(m2_index) + len(m3_index) > 0: 67 | q = torch.vstack([no_embeddings[m2_index], unrelated_embeddings[m3_index]]) 68 | v = torch.vstack([co_embeddings[m2_index], co_embeddings[m3_index]]) 69 | h_m23 = self.activation(self.single_head_attention(q, q, v)) 70 | h_new[m2_index] = h_m23[:len(m2_index)] 71 | h_new[m3_index] = h_m23[len(m2_index):] 72 | output_m23, _ = torch.max(h_m23, dim=-2) 73 | if len(m1_index) == 0: 74 | output = output_m23 75 | elif len(m2_index) + len(m3_index) == 0: 76 | output = output_m1 77 | else: 78 | output, _ = torch.max(torch.vstack([output_m1, output_m23]), dim=-2) 79 | return output, h_new 80 | 81 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from models.layers import EmbeddingLayer, GraphLayer, TransitionLayer 5 | from models.utils import DotProductAttention 6 | 7 | 8 | class Classifier(nn.Module): 9 | def __init__(self, input_size, output_size, dropout_rate=0., activation=None): 10 | super().__init__() 11 | self.linear = nn.Linear(input_size, output_size) 12 | self.activation = activation 13 | self.dropout = nn.Dropout(p=dropout_rate) 14 | 15 | def forward(self, x): 16 | output = self.dropout(x) 17 | output = self.linear(output) 18 | if self.activation is not None: 19 | output = self.activation(output) 20 | return output 21 | 22 | 23 | class Model(nn.Module): 24 | def __init__(self, code_num, code_size, 25 | adj, graph_size, hidden_size, t_attention_size, t_output_size, 26 | output_size, dropout_rate, activation): 27 | super().__init__() 28 | self.embedding_layer = EmbeddingLayer(code_num, code_size, graph_size) 29 | self.graph_layer = GraphLayer(adj, code_size, graph_size) 30 | self.transition_layer = TransitionLayer(code_num, graph_size, hidden_size, t_attention_size, t_output_size) 31 | self.attention = DotProductAttention(hidden_size, 32) 32 | self.classifier = Classifier(hidden_size, output_size, dropout_rate, activation) 33 | 34 | def forward(self, code_x, divided, neighbors, lens): 35 | embeddings = self.embedding_layer() 36 | c_embeddings, n_embeddings, u_embeddings = embeddings 37 | output = [] 38 | for code_x_i, divided_i, neighbor_i, len_i in zip(code_x, divided, neighbors, lens): 39 | no_embeddings_i_prev = None 40 | output_i = [] 41 | h_t = None 42 | for t, (c_it, d_it, n_it, len_it) in enumerate(zip(code_x_i, divided_i, neighbor_i, range(len_i))): 43 | co_embeddings, no_embeddings = self.graph_layer(c_it, n_it, c_embeddings, n_embeddings) 44 | output_it, h_t = self.transition_layer(t, co_embeddings, d_it, no_embeddings_i_prev, u_embeddings, h_t) 45 | no_embeddings_i_prev = no_embeddings 46 | output_i.append(output_it) 47 | output_i = self.attention(torch.vstack(output_i)) 48 | output.append(output_i) 49 | output = torch.vstack(output) 50 | output = self.classifier(output) 51 | return output 52 | -------------------------------------------------------------------------------- /models/utils.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class SingleHeadAttentionLayer(nn.Module): 8 | def __init__(self, query_size, key_size, value_size, attention_size): 9 | super().__init__() 10 | self.attention_size = attention_size 11 | self.dense_q = nn.Linear(query_size, attention_size) 12 | self.dense_k = nn.Linear(key_size, attention_size) 13 | self.dense_v = nn.Linear(query_size, value_size) 14 | 15 | def forward(self, q, k, v): 16 | query = self.dense_q(q) 17 | key = self.dense_k(k) 18 | value = self.dense_v(v) 19 | g = torch.div(torch.matmul(query, key.T), math.sqrt(self.attention_size)) 20 | score = torch.softmax(g, dim=-1) 21 | output = torch.sum(torch.unsqueeze(score, dim=-1) * value, dim=-2) 22 | return output 23 | 24 | 25 | class DotProductAttention(nn.Module): 26 | def __init__(self, value_size, attention_size): 27 | super().__init__() 28 | self.attention_size = attention_size 29 | self.context = nn.Parameter(data=nn.init.xavier_uniform_(torch.empty(attention_size, 1))) 30 | self.dense = nn.Linear(value_size, attention_size) 31 | 32 | def forward(self, x): 33 | t = self.dense(x) 34 | vu = torch.matmul(t, self.context).squeeze() 35 | score = torch.softmax(vu, dim=-1) 36 | output = torch.sum(x * torch.unsqueeze(score, dim=-1), dim=-2) 37 | return output 38 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | 6 | def save_sparse(path, x): 7 | idx = np.where(x > 0) 8 | values = x[idx] 9 | np.savez(path, idx=idx, values=values, shape=x.shape) 10 | 11 | 12 | def load_sparse(path): 13 | data = np.load(path) 14 | idx, values = data['idx'], data['values'] 15 | mat = np.zeros(data['shape'], dtype=values.dtype) 16 | mat[tuple(idx)] = values 17 | return mat 18 | 19 | 20 | def save_data(path, code_x, visit_lens, codes_y, hf_y, divided, neighbors): 21 | save_sparse(os.path.join(path, 'code_x'), code_x) 22 | np.savez(os.path.join(path, 'visit_lens'), lens=visit_lens) 23 | save_sparse(os.path.join(path, 'code_y'), codes_y) 24 | np.savez(os.path.join(path, 'hf_y'), hf_y=hf_y) 25 | save_sparse(os.path.join(path, 'divided'), divided) 26 | save_sparse(os.path.join(path, 'neighbors'), neighbors) 27 | -------------------------------------------------------------------------------- /preprocess/auxiliary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from preprocess.parse_csv import EHRParser 4 | 5 | 6 | def generate_code_code_adjacent(pids, patient_admission, admission_codes_encoded, code_num, threshold=0.01): 7 | print('generating code code adjacent matrix ...') 8 | n = code_num 9 | adj = np.zeros((n, n), dtype=int) 10 | for i, pid in enumerate(pids): 11 | print('\r\t%d / %d' % (i, len(pids)), end='') 12 | for admission in patient_admission[pid]: 13 | codes = admission_codes_encoded[admission[EHRParser.adm_id_col]] 14 | for row in range(len(codes) - 1): 15 | for col in range(row + 1, len(codes)): 16 | c_i = codes[row] 17 | c_j = codes[col] 18 | adj[c_i, c_j] += 1 19 | adj[c_j, c_i] += 1 20 | print('\r\t%d / %d' % (len(pids), len(pids))) 21 | norm_adj = normalize_adj(adj) 22 | a = norm_adj < threshold 23 | b = adj.sum(axis=-1, keepdims=True) > (1 / threshold) 24 | adj[np.logical_and(a, b)] = 0 25 | return adj 26 | 27 | 28 | def normalize_adj(adj): 
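# Row-normalizes the co-occurrence counts: each row is divided by its sum, and all-zero rows are left unchanged.
# Called by generate_code_code_adjacent above when thresholding the adjacency matrix, and again in run_preprocess.py before saving code_adj.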
29 | s = adj.sum(axis=-1, keepdims=True) 30 | s[s == 0] = 1 31 | result = adj / s 32 | return result 33 | 34 | 35 | def generate_neighbors(code_x, lens, adj): 36 | n = len(code_x) 37 | neighbors = np.zeros_like(code_x, dtype=bool) 38 | # a = 0 39 | # b = 100000 40 | # c = -1 41 | # nn = 0 42 | for i, admissions in enumerate(code_x): 43 | print('\r\t%d / %d' % (i + 1, n), end='') 44 | for j in range(lens[i]): 45 | codes_set = set(np.where(admissions[j] == 1)[0]) 46 | all_neighbors = set() 47 | for code in codes_set: 48 | code_neighbors = set(np.where(adj[code] > 0)[0]).difference(codes_set) 49 | all_neighbors.update(code_neighbors) 50 | if len(all_neighbors) > 0: 51 | neighbors[i, j, np.array(list(all_neighbors))] = 1 52 | # a += len(all_neighbors) 53 | # if b > len(all_neighbors): 54 | # b = len(all_neighbors) 55 | # if c < len(all_neighbors): 56 | # c = len(all_neighbors) 57 | # nn += 1 58 | print('\r\t%d / %d' % (n, n)) 59 | # print(b, c, a / nn);exit() 60 | return neighbors 61 | 62 | 63 | def divide_middle(code_x, neighbors, lens): 64 | n = len(code_x) 65 | divided = np.zeros((*code_x.shape, 3), dtype=bool) 66 | for i, admissions in enumerate(code_x): 67 | print('\r\t%d / %d' % (i + 1, n), end='') 68 | divided[i, 0, :, 0] = admissions[0] 69 | for j in range(1, lens[i]): 70 | codes_set = set(np.where(admissions[j] == 1)[0]) 71 | m_set = set(np.where(admissions[j - 1] == 1)[0]) 72 | n_set = set(np.where(neighbors[i][j - 1] == 1)[0]) 73 | m1 = codes_set.intersection(m_set) 74 | m2 = codes_set.intersection(n_set) 75 | m3 = codes_set.difference(m_set).difference(n_set) 76 | if len(m1) > 0: 77 | divided[i, j, np.array(list(m1)), 0] = 1 78 | if len(m2) > 0: 79 | divided[i, j, np.array(list(m2)), 1] = 1 80 | if len(m3) > 0: 81 | divided[i, j, np.array(list(m3)), 2] = 1 82 | print('\r\t%d / %d' % (n, n)) 83 | return divided 84 | 85 | 86 | def parse_icd9_range(range_: str) -> (str, str, int, int): 87 | ranges = range_.lstrip().split('-') 88 | if ranges[0][0] == 'V': 89 | prefix = 'V' 90 | format_ = '%02d' 91 | start, end = int(ranges[0][1:]), int(ranges[1][1:]) 92 | elif ranges[0][0] == 'E': 93 | prefix = 'E' 94 | format_ = '%03d' 95 | start, end = int(ranges[0][1:]), int(ranges[1][1:]) 96 | else: 97 | prefix = '' 98 | format_ = '%03d' 99 | if len(ranges) == 1: 100 | start = int(ranges[0]) 101 | end = start 102 | else: 103 | start, end = int(ranges[0]), int(ranges[1]) 104 | return prefix, format_, start, end 105 | 106 | 107 | def generate_code_levels(path, code_map: dict) -> np.ndarray: 108 | print('generating code levels ...') 109 | import os 110 | three_level_code_set = set(code.split('.')[0] for code in code_map) 111 | icd9_path = os.path.join(path, 'icd9.txt') 112 | icd9_range = list(open(icd9_path, 'r', encoding='utf-8').readlines()) 113 | three_level_dict = dict() 114 | level1, level2, level3 = (0, 0, 0) 115 | level1_can_add = False 116 | for range_ in icd9_range: 117 | range_ = range_.rstrip() 118 | if range_[0] == ' ': 119 | prefix, format_, start, end = parse_icd9_range(range_) 120 | level2_cannot_add = True 121 | for i in range(start, end + 1): 122 | code = prefix + format_ % i 123 | if code in three_level_code_set: 124 | three_level_dict[code] = [level1, level2, level3] 125 | level3 += 1 126 | level1_can_add = True 127 | level2_cannot_add = False 128 | if not level2_cannot_add: 129 | level2 += 1 130 | else: 131 | if level1_can_add: 132 | level1 += 1 133 | level1_can_add = False 134 | 135 | code_level = dict() 136 | for code, cid in code_map.items(): 137 | three_level_code = 
code.split('.')[0] 138 | three_level = three_level_dict[three_level_code] 139 | code_level[code] = three_level + [cid] 140 | 141 | code_level_matrix = np.zeros((len(code_map), 4), dtype=int) 142 | for code, cid in code_map.items(): 143 | code_level_matrix[cid] = code_level[code] 144 | 145 | return code_level_matrix 146 | -------------------------------------------------------------------------------- /preprocess/build_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from preprocess.parse_csv import EHRParser 4 | 5 | 6 | def split_patients(patient_admission, admission_codes, code_map, train_num, test_num, seed=6669): 7 | np.random.seed(seed) 8 | common_pids = set() 9 | for i, code in enumerate(code_map): 10 | print('\r\t%.2f%%' % ((i + 1) * 100 / len(code_map)), end='') 11 | for pid, admissions in patient_admission.items(): 12 | for admission in admissions: 13 | codes = admission_codes[admission[EHRParser.adm_id_col]] 14 | if code in codes: 15 | common_pids.add(pid) 16 | break 17 | else: 18 | continue 19 | break 20 | print('\r\t100%') 21 | max_admission_num = 0 22 | pid_max_admission_num = 0 23 | for pid, admissions in patient_admission.items(): 24 | if len(admissions) > max_admission_num: 25 | max_admission_num = len(admissions) 26 | pid_max_admission_num = pid 27 | common_pids.add(pid_max_admission_num) 28 | remaining_pids = np.array(list(set(patient_admission.keys()).difference(common_pids))) 29 | np.random.shuffle(remaining_pids) 30 | 31 | valid_num = len(patient_admission) - train_num - test_num 32 | train_pids = np.array(list(common_pids.union(set(remaining_pids[:(train_num - len(common_pids))].tolist())))) 33 | valid_pids = remaining_pids[(train_num - len(common_pids)):(train_num + valid_num - len(common_pids))] 34 | test_pids = remaining_pids[(train_num + valid_num - len(common_pids)):] 35 | return train_pids, valid_pids, test_pids 36 | 37 | 38 | def build_code_xy(pids, patient_admission, admission_codes_encoded, max_admission_num, code_num): 39 | n = len(pids) 40 | x = np.zeros((n, max_admission_num, code_num), dtype=bool) 41 | y = np.zeros((n, code_num), dtype=int) 42 | lens = np.zeros((n,), dtype=int) 43 | for i, pid in enumerate(pids): 44 | print('\r\t%d / %d' % (i + 1, len(pids)), end='') 45 | admissions = patient_admission[pid] 46 | for k, admission in enumerate(admissions[:-1]): 47 | codes = admission_codes_encoded[admission[EHRParser.adm_id_col]] 48 | x[i, k, codes] = 1 49 | codes = np.array(admission_codes_encoded[admissions[-1][EHRParser.adm_id_col]]) 50 | y[i, codes] = 1 51 | lens[i] = len(admissions) - 1 52 | print('\r\t%d / %d' % (len(pids), len(pids))) 53 | return x, y, lens 54 | 55 | 56 | def build_heart_failure_y(hf_prefix, codes_y, code_map): 57 | hf_list = np.array([cid for code, cid in code_map.items() if code.startswith(hf_prefix)]) 58 | hfs = np.zeros((len(code_map),), dtype=int) 59 | hfs[hf_list] = 1 60 | hf_exist = np.logical_and(codes_y, hfs) 61 | y = (np.sum(hf_exist, axis=-1) > 0).astype(int) 62 | return y 63 | -------------------------------------------------------------------------------- /preprocess/encode.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from preprocess.parse_csv import EHRParser 4 | 5 | 6 | def encode_code(patient_admission, admission_codes): 7 | code_map = OrderedDict() 8 | for pid, admissions in patient_admission.items(): 9 | for admission in admissions: 10 | codes = 
admission_codes[admission[EHRParser.adm_id_col]] 11 | for code in codes: 12 | if code not in code_map: 13 | code_map[code] = len(code_map) 14 | 15 | admission_codes_encoded = { 16 | admission_id: list(set(code_map[code] for code in codes)) 17 | for admission_id, codes in admission_codes.items() 18 | } 19 | return admission_codes_encoded, code_map 20 | -------------------------------------------------------------------------------- /preprocess/parse_csv.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | from collections import OrderedDict 4 | 5 | import pandas 6 | import pandas as pd 7 | import numpy as np 8 | 9 | 10 | class EHRParser: 11 | pid_col = 'pid' 12 | adm_id_col = 'adm_id' 13 | adm_time_col = 'adm_time' 14 | cid_col = 'cid' 15 | 16 | def __init__(self, path): 17 | self.path = path 18 | 19 | self.skip_pid_check = False 20 | 21 | self.patient_admission = None 22 | self.admission_codes = None 23 | self.admission_procedures = None 24 | self.admission_medications = None 25 | 26 | self.parse_fn = {'d': self.set_diagnosis} 27 | 28 | def set_admission(self): 29 | raise NotImplementedError 30 | 31 | def set_diagnosis(self): 32 | raise NotImplementedError 33 | 34 | @staticmethod 35 | def to_standard_icd9(code: str): 36 | raise NotImplementedError 37 | 38 | def parse_admission(self): 39 | print('parsing the csv file of admission ...') 40 | filename, cols, converters = self.set_admission() 41 | admissions = pd.read_csv(os.path.join(self.path, filename), usecols=list(cols.values()), converters=converters) 42 | admissions = self._after_read_admission(admissions, cols) 43 | all_patients = OrderedDict() 44 | for i, row in admissions.iterrows(): 45 | if i % 100 == 0: 46 | print('\r\t%d in %d rows' % (i + 1, len(admissions)), end='') 47 | pid, adm_id, adm_time = row[cols[self.pid_col]], row[cols[self.adm_id_col]], row[cols[self.adm_time_col]] 48 | if pid not in all_patients: 49 | all_patients[pid] = [] 50 | admission = all_patients[pid] 51 | admission.append({self.adm_id_col: adm_id, self.adm_time_col: adm_time}) 52 | print('\r\t%d in %d rows' % (len(admissions), len(admissions))) 53 | 54 | patient_admission = OrderedDict() 55 | for pid, admissions in all_patients.items(): 56 | if len(admissions) >= 2: 57 | patient_admission[pid] = sorted(admissions, key=lambda admission: admission[self.adm_time_col]) 58 | 59 | self.patient_admission = patient_admission 60 | 61 | def _after_read_admission(self, admissions, cols): 62 | return admissions 63 | 64 | def _parse_concept(self, concept_type): 65 | assert concept_type in self.parse_fn.keys() 66 | filename, cols, converters = self.parse_fn[concept_type]() 67 | concepts = pd.read_csv(os.path.join(self.path, filename), usecols=list(cols.values()), converters=converters) 68 | concepts = self._after_read_concepts(concepts, concept_type, cols) 69 | result = OrderedDict() 70 | for i, row in concepts.iterrows(): 71 | if i % 100 == 0: 72 | print('\r\t%d in %d rows' % (i + 1, len(concepts)), end='') 73 | pid = row[cols[self.pid_col]] 74 | if self.skip_pid_check or pid in self.patient_admission: 75 | adm_id, code = row[cols[self.adm_id_col]], row[cols[self.cid_col]] 76 | if code == '': 77 | continue 78 | if adm_id not in result: 79 | result[adm_id] = [] 80 | codes = result[adm_id] 81 | codes.append(code) 82 | print('\r\t%d in %d rows' % (len(concepts), len(concepts))) 83 | return result 84 | 85 | def _after_read_concepts(self, concepts, concept_type, cols): 86 | return concepts 87 | 
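# The remaining methods implement the rest of the pipeline driven by parse() below:
# parse_diagnoses() collects the ICD codes of each admission, calibrate_patient_by_admission()
# drops patients whose admissions have no diagnosis codes, calibrate_admission_by_patient()
# discards codes of admissions that no longer belong to any kept patient, and
# sample_patients() optionally keeps a random subset of patients (used for MIMIC-IV in run_preprocess.py).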
88 | def parse_diagnoses(self): 89 | print('parsing csv file of diagnosis ...') 90 | self.admission_codes = self._parse_concept('d') 91 | 92 | def calibrate_patient_by_admission(self): 93 | print('calibrating patients by admission ...') 94 | del_pids = [] 95 | for pid, admissions in self.patient_admission.items(): 96 | for admission in admissions: 97 | adm_id = admission[self.adm_id_col] 98 | if adm_id not in self.admission_codes: 99 | break 100 | else: 101 | continue 102 | del_pids.append(pid) 103 | for pid in del_pids: 104 | admissions = self.patient_admission[pid] 105 | for admission in admissions: 106 | adm_id = admission[self.adm_id_col] 107 | for concepts in [self.admission_codes]: 108 | if adm_id in concepts: 109 | del concepts[adm_id] 110 | del self.patient_admission[pid] 111 | 112 | def calibrate_admission_by_patient(self): 113 | print('calibrating admission by patients ...') 114 | adm_id_set = set() 115 | for admissions in self.patient_admission.values(): 116 | for admission in admissions: 117 | adm_id_set.add(admission[self.adm_id_col]) 118 | del_adm_ids = [adm_id for adm_id in self.admission_codes if adm_id not in adm_id_set] 119 | for adm_id in del_adm_ids: 120 | del self.admission_codes[adm_id] 121 | 122 | def sample_patients(self, sample_num, seed): 123 | np.random.seed(seed) 124 | keys = list(self.patient_admission.keys()) 125 | selected_pids = np.random.choice(keys, sample_num, False) 126 | self.patient_admission = {pid: self.patient_admission[pid] for pid in selected_pids} 127 | admission_codes = dict() 128 | for admissions in self.patient_admission.values(): 129 | for admission in admissions: 130 | adm_id = admission[self.adm_id_col] 131 | admission_codes[adm_id] = self.admission_codes[adm_id] 132 | self.admission_codes = admission_codes 133 | 134 | def parse(self, sample_num=None, seed=6669): 135 | self.parse_admission() 136 | self.parse_diagnoses() 137 | self.calibrate_patient_by_admission() 138 | self.calibrate_admission_by_patient() 139 | if sample_num is not None: 140 | self.sample_patients(sample_num, seed) 141 | return self.patient_admission, self.admission_codes 142 | 143 | 144 | class Mimic3Parser(EHRParser): 145 | def set_admission(self): 146 | filename = 'ADMISSIONS.csv' 147 | cols = {self.pid_col: 'SUBJECT_ID', self.adm_id_col: 'HADM_ID', self.adm_time_col: 'ADMITTIME'} 148 | converter = { 149 | 'SUBJECT_ID': int, 150 | 'HADM_ID': int, 151 | 'ADMITTIME': lambda cell: datetime.strptime(str(cell), '%Y-%m-%d %H:%M:%S') 152 | } 153 | return filename, cols, converter 154 | 155 | def set_diagnosis(self): 156 | filename = 'DIAGNOSES_ICD.csv' 157 | cols = {self.pid_col: 'SUBJECT_ID', self.adm_id_col: 'HADM_ID', self.cid_col: 'ICD9_CODE'} 158 | converter = {'SUBJECT_ID': int, 'HADM_ID': int, 'ICD9_CODE': Mimic3Parser.to_standard_icd9} 159 | return filename, cols, converter 160 | 161 | @staticmethod 162 | def to_standard_icd9(code: str): 163 | code = str(code) 164 | if code == '': 165 | return code 166 | split_pos = 4 if code.startswith('E') else 3 167 | icd9_code = code[:split_pos] + '.' 
+ code[split_pos:] if len(code) > split_pos else code 168 | return icd9_code 169 | 170 | 171 | class Mimic4Parser(EHRParser): 172 | def __init__(self, path): 173 | super().__init__(path) 174 | self.icd_ver_col = 'icd_version' 175 | self.icd_map = self._load_icd_map() 176 | self.patient_year_map = self._load_patient() 177 | 178 | def _load_icd_map(self): 179 | print('loading ICD-10 to ICD-9 map ...') 180 | filename = 'icd10-icd9.csv' 181 | cols = ['ICD10', 'ICD9'] 182 | converters = {'ICD10': str, 'ICD9': str} 183 | icd_csv = pandas.read_csv(os.path.join(self.path, filename), usecols=cols, converters=converters) 184 | icd_map = {row['ICD10']: row['ICD9'] for _, row in icd_csv.iterrows()} 185 | return icd_map 186 | 187 | def _load_patient(self): 188 | print('loading patients anchor year ...') 189 | filename = 'patients.csv' 190 | cols = ['subject_id', 'anchor_year', 'anchor_year_group'] 191 | converters = {'subject_id': int, 'anchor_year': int, 'anchor_year_group': lambda cell: int(str(cell)[:4])} 192 | patient_csv = pandas.read_csv(os.path.join(self.path, filename), usecols=cols, converters=converters) 193 | patient_year_map = {row['subject_id']: row['anchor_year'] - row['anchor_year_group'] 194 | for i, row in patient_csv.iterrows()} 195 | return patient_year_map 196 | 197 | def set_admission(self): 198 | filename = 'admissions.csv' 199 | cols = {self.pid_col: 'subject_id', self.adm_id_col: 'hadm_id', self.adm_time_col: 'admittime'} 200 | converter = { 201 | 'subject_id': int, 202 | 'hadm_id': int, 203 | 'admittime': lambda cell: datetime.strptime(str(cell), '%Y-%m-%d %H:%M:%S') 204 | } 205 | return filename, cols, converter 206 | 207 | def set_diagnosis(self): 208 | filename = 'diagnoses_icd.csv' 209 | cols = { 210 | self.pid_col: 'subject_id', 211 | self.adm_id_col: 'hadm_id', 212 | self.cid_col: 'icd_code', 213 | self.icd_ver_col: 'icd_version' 214 | } 215 | converter = {'subject_id': int, 'hadm_id': int, 'icd_code': str, 'icd_version': int} 216 | return filename, cols, converter 217 | 218 | def _after_read_admission(self, admissions, cols): 219 | print('\tselecting valid admission ...') 220 | valid_admissions = [] 221 | n = len(admissions) 222 | for i, row in admissions.iterrows(): 223 | if i % 100 == 0: 224 | print('\r\t\t%d in %d rows' % (i + 1, n), end='') 225 | pid = row[cols[self.pid_col]] 226 | year = row[cols[self.adm_time_col]].year - self.patient_year_map[pid] 227 | if year > 2012: 228 | valid_admissions.append(i) 229 | print('\r\t\t%d in %d rows' % (n, n)) 230 | print('\t\tremaining %d rows' % len(valid_admissions)) 231 | return admissions.iloc[valid_admissions] 232 | 233 | def _after_read_concepts(self, concepts, concept_type, cols): 234 | print('\tmapping ICD-10 to ICD-9 ...') 235 | n = len(concepts) 236 | if concept_type == 'd': 237 | def _10to9(i, row): 238 | if i % 100 == 0: 239 | print('\r\t\t%d in %d rows' % (i + 1, n), end='') 240 | cid = row[cid_col] 241 | if row[icd_ver_col] == 10: 242 | if cid not in self.icd_map: 243 | code = self.icd_map[cid + '1'] if cid + '1' in self.icd_map else '' 244 | else: 245 | code = self.icd_map[cid] 246 | if code == 'NoDx': 247 | code = '' 248 | else: 249 | code = cid 250 | return Mimic4Parser.to_standard_icd9(code) 251 | 252 | cid_col, icd_ver_col = cols[self.cid_col], self.icd_ver_col 253 | col = np.array([_10to9(i, row) for i, row in concepts.iterrows()]) 254 | print('\r\t\t%d in %d rows' % (n, n)) 255 | concepts[cid_col] = col 256 | return concepts 257 | 258 | @staticmethod 259 | def to_standard_icd9(code: str): 260 | return 
Mimic3Parser.to_standard_icd9(code) 261 | 262 | 263 | class EICUParser(EHRParser): 264 | def __init__(self, path): 265 | super().__init__(path) 266 | self.skip_pid_check = True 267 | 268 | def set_admission(self): 269 | filename = 'patient.csv' 270 | cols = { 271 | self.pid_col: 'patienthealthsystemstayid', 272 | self.adm_id_col: 'patientunitstayid', 273 | self.adm_time_col: 'hospitaladmitoffset' 274 | } 275 | converter = { 276 | 'patienthealthsystemstayid': int, 277 | 'patientunitstayid': int, 278 | 'hospitaladmitoffset': lambda cell: -int(cell) 279 | } 280 | return filename, cols, converter 281 | 282 | def set_diagnosis(self): 283 | filename = 'diagnosis.csv' 284 | cols = {self.pid_col: 'diagnosisid', self.adm_id_col: 'patientunitstayid', self.cid_col: 'icd9code'} 285 | converter = {'diagnosisid': int, 'patientunitstayid': int, 'icd9code': EICUParser.to_standard_icd9} 286 | return filename, cols, converter 287 | 288 | @staticmethod 289 | def to_standard_icd9(code: str): 290 | code = str(code) 291 | if code == '': 292 | return code 293 | code = code.split(',')[0] 294 | c = code[0].lower() 295 | dot = code.find('.') 296 | if dot == -1: 297 | dot = len(code) # keep the slices below empty when the code has no dot 298 | if not c.isalpha(): 299 | prefix = code[:dot] 300 | if len(prefix) < 3: 301 | code = ('%03d' % int(prefix)) + code[dot:] 302 | return code 303 | if c == 'e': 304 | prefix = code[1:dot] 305 | if len(prefix) != 3: 306 | return '' 307 | if c != 'e' and c != 'v': # only E and V codes are kept among alphabetic codes 308 | return '' 309 | return code 310 | 311 | def parse_diagnoses(self): 312 | super().parse_diagnoses() 313 | t = OrderedDict.fromkeys(self.admission_codes.keys()) 314 | for adm_id, codes in self.admission_codes.items(): 315 | t[adm_id] = list(set(codes)) 316 | self.admission_codes = t 317 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | joblib==1.0.1 2 | numpy==1.20.2 3 | pandas==1.2.3 4 | Pillow==8.2.0 5 | python-dateutil==2.8.1 6 | pytz==2021.1 7 | scikit-learn==0.24.1 8 | scipy==1.6.2 9 | six==1.15.0 10 | threadpoolctl==2.1.0 11 | torch==1.8.1+cu111 12 | torchaudio==0.8.1 13 | torchvision==0.9.1+cu111 14 | typing-extensions==3.7.4.3 15 | -------------------------------------------------------------------------------- /run_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import _pickle as pickle 3 | 4 | from preprocess import save_sparse, save_data 5 | from preprocess.parse_csv import Mimic3Parser, Mimic4Parser, EICUParser 6 | from preprocess.encode import encode_code 7 | from preprocess.build_dataset import split_patients, build_code_xy, build_heart_failure_y 8 | from preprocess.auxiliary import generate_code_code_adjacent, generate_neighbors, normalize_adj, divide_middle, generate_code_levels 9 | 10 | if __name__ == '__main__': 11 | conf = { 12 | 'mimic3': { 13 | 'parser': Mimic3Parser, 14 | 'train_num': 6000, 15 | 'test_num': 1000, 16 | 'threshold': 0.01 17 | }, 18 | 'mimic4': { 19 | 'parser': Mimic4Parser, 20 | 'train_num': 8000, 21 | 'test_num': 1000, 22 | 'threshold': 0.01, 23 | 'sample_num': 10000 24 | }, 25 | 'eicu': { 26 | 'parser': EICUParser, 27 | 'train_num': 8000, 28 | 'test_num': 1000, 29 | 'threshold': 0.01 30 | } 31 | } 32 | from_saved = True 33 | data_path = 'data' 34 | dataset = 'mimic4' # mimic3, eicu, or mimic4 35 | dataset_path = os.path.join(data_path, dataset) 36 | raw_path = os.path.join(dataset_path, 'raw') 37 | if not os.path.exists(raw_path): 38 
| os.makedirs(raw_path) 39 | print('please put the CSV files in `data/%s/raw`' % dataset) 40 | exit() 41 | parsed_path = os.path.join(dataset_path, 'parsed') 42 | if from_saved: 43 | patient_admission = pickle.load(open(os.path.join(parsed_path, 'patient_admission.pkl'), 'rb')) 44 | admission_codes = pickle.load(open(os.path.join(parsed_path, 'admission_codes.pkl'), 'rb')) 45 | else: 46 | parser = conf[dataset]['parser'](raw_path) 47 | sample_num = conf[dataset].get('sample_num', None) 48 | patient_admission, admission_codes = parser.parse(sample_num) 49 | print('saving parsed data ...') 50 | if not os.path.exists(parsed_path): 51 | os.makedirs(parsed_path) 52 | pickle.dump(patient_admission, open(os.path.join(parsed_path, 'patient_admission.pkl'), 'wb')) 53 | pickle.dump(admission_codes, open(os.path.join(parsed_path, 'admission_codes.pkl'), 'wb')) 54 | 55 | patient_num = len(patient_admission) 56 | max_admission_num = max([len(admissions) for admissions in patient_admission.values()]) 57 | avg_admission_num = sum([len(admissions) for admissions in patient_admission.values()]) / patient_num 58 | max_visit_code_num = max([len(codes) for codes in admission_codes.values()]) 59 | avg_visit_code_num = sum([len(codes) for codes in admission_codes.values()]) / len(admission_codes) 60 | print('patient num: %d' % patient_num) 61 | print('max admission num: %d' % max_admission_num) 62 | print('mean admission num: %.2f' % avg_admission_num) 63 | print('max code num in an admission: %d' % max_visit_code_num) 64 | print('mean code num in an admission: %.2f' % avg_visit_code_num) 65 | 66 | print('encoding code ...') 67 | admission_codes_encoded, code_map = encode_code(patient_admission, admission_codes) 68 | code_num = len(code_map) 69 | print('There are %d codes' % code_num) 70 | 71 | code_levels = generate_code_levels(data_path, code_map) 72 | pickle.dump({ 73 | 'code_levels': code_levels, 74 | }, open(os.path.join(parsed_path, 'code_levels.pkl'), 'wb')) 75 | 76 | train_pids, valid_pids, test_pids = split_patients( 77 | patient_admission=patient_admission, 78 | admission_codes=admission_codes, 79 | code_map=code_map, 80 | train_num=conf[dataset]['train_num'], 81 | test_num=conf[dataset]['test_num'] 82 | ) 83 | print('There are %d train, %d valid, %d test samples' % (len(train_pids), len(valid_pids), len(test_pids))) 84 | code_adj = generate_code_code_adjacent(pids=train_pids, patient_admission=patient_admission, 85 | admission_codes_encoded=admission_codes_encoded, 86 | code_num=code_num, threshold=conf[dataset]['threshold']) 87 | 88 | common_args = [patient_admission, admission_codes_encoded, max_admission_num, code_num] 89 | print('building train codes features and labels ...') 90 | (train_code_x, train_codes_y, train_visit_lens) = build_code_xy(train_pids, *common_args) 91 | print('building valid codes features and labels ...') 92 | (valid_code_x, valid_codes_y, valid_visit_lens) = build_code_xy(valid_pids, *common_args) 93 | print('building test codes features and labels ...') 94 | (test_code_x, test_codes_y, test_visit_lens) = build_code_xy(test_pids, *common_args) 95 | 96 | print('generating train neighbors ...') 97 | train_neighbors = generate_neighbors(train_code_x, train_visit_lens, code_adj) 98 | print('generating valid neighbors ...') 99 | valid_neighbors = generate_neighbors(valid_code_x, valid_visit_lens, code_adj) 100 | print('generating test neighbors ...') 101 | test_neighbors = generate_neighbors(test_code_x, test_visit_lens, code_adj) 102 | 103 | print('generating train middles 
...') 104 | train_divided = divide_middle(train_code_x, train_neighbors, train_visit_lens) 105 | print('generating valid middles ...') 106 | valid_divided = divide_middle(valid_code_x, valid_neighbors, valid_visit_lens) 107 | print('generating test middles ...') 108 | test_divided = divide_middle(test_code_x, test_neighbors, test_visit_lens) 109 | 110 | print('building train heart failure labels ...') 111 | train_hf_y = build_heart_failure_y('428', train_codes_y, code_map) 112 | print('building valid heart failure labels ...') 113 | valid_hf_y = build_heart_failure_y('428', valid_codes_y, code_map) 114 | print('building test heart failure labels ...') 115 | test_hf_y = build_heart_failure_y('428', test_codes_y, code_map) 116 | 117 | encoded_path = os.path.join(dataset_path, 'encoded') 118 | if not os.path.exists(encoded_path): 119 | os.makedirs(encoded_path) 120 | print('saving encoded data ...') 121 | pickle.dump(patient_admission, open(os.path.join(encoded_path, 'patient_admission.pkl'), 'wb')) 122 | pickle.dump(admission_codes_encoded, open(os.path.join(encoded_path, 'codes_encoded.pkl'), 'wb')) 123 | pickle.dump(code_map, open(os.path.join(encoded_path, 'code_map.pkl'), 'wb')) 124 | pickle.dump({ 125 | 'train_pids': train_pids, 126 | 'valid_pids': valid_pids, 127 | 'test_pids': test_pids 128 | }, open(os.path.join(encoded_path, 'pids.pkl'), 'wb')) 129 | 130 | print('saving standard data ...') 131 | standard_path = os.path.join(dataset_path, 'standard') 132 | train_path = os.path.join(standard_path, 'train') 133 | valid_path = os.path.join(standard_path, 'valid') 134 | test_path = os.path.join(standard_path, 'test') 135 | if not os.path.exists(standard_path): 136 | os.makedirs(standard_path) 137 | if not os.path.exists(train_path): 138 | os.makedirs(train_path) 139 | os.makedirs(valid_path) 140 | os.makedirs(test_path) 141 | 142 | print('\tsaving training data') 143 | save_data(train_path, train_code_x, train_visit_lens, train_codes_y, train_hf_y, train_divided, train_neighbors) 144 | print('\tsaving valid data') 145 | save_data(valid_path, valid_code_x, valid_visit_lens, valid_codes_y, valid_hf_y, valid_divided, valid_neighbors) 146 | print('\tsaving test data') 147 | save_data(test_path, test_code_x, test_visit_lens, test_codes_y, test_hf_y, test_divided, test_neighbors) 148 | 149 | code_adj = normalize_adj(code_adj) 150 | save_sparse(os.path.join(standard_path, 'code_adj'), code_adj) 151 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from models.model import Model 9 | from utils import load_adj, EHRDataset, format_time, MultiStepLRScheduler 10 | from metrics import evaluate_codes, evaluate_hf 11 | 12 | 13 | def historical_hot(code_x, code_num, lens): 14 | result = np.zeros((len(code_x), code_num), dtype=int) 15 | for i, (x, l) in enumerate(zip(code_x, lens)): 16 | result[i] = x[l - 1] 17 | return result 18 | 19 | 20 | if __name__ == '__main__': 21 | seed = 6669 22 | dataset = 'mimic4' # 'mimic3' or 'eicu' 23 | task = 'h' # 'm' or 'h' 24 | use_cuda = True 25 | device = torch.device('cuda' if torch.cuda.is_available() and use_cuda else 'cpu') 26 | 27 | code_size = 48 28 | graph_size = 32 29 | hidden_size = 150 # rnn hidden size 30 | t_attention_size = 32 31 | t_output_size = hidden_size 32 | batch_size = 32 33 | epochs = 200 34 | 35 | random.seed(seed) 36 | 
np.random.seed(seed) 37 | torch.manual_seed(seed) 38 | torch.cuda.manual_seed(seed) 39 | 40 | dataset_path = os.path.join('data', dataset, 'standard') 41 | train_path = os.path.join(dataset_path, 'train') 42 | valid_path = os.path.join(dataset_path, 'valid') 43 | test_path = os.path.join(dataset_path, 'test') 44 | 45 | code_adj = load_adj(dataset_path, device=device) 46 | code_num = len(code_adj) 47 | print('loading train data ...') 48 | train_data = EHRDataset(train_path, label=task, batch_size=batch_size, shuffle=True, device=device) 49 | print('loading valid data ...') 50 | valid_data = EHRDataset(valid_path, label=task, batch_size=batch_size, shuffle=False, device=device) 51 | print('loading test data ...') 52 | test_data = EHRDataset(test_path, label=task, batch_size=batch_size, shuffle=False, device=device) 53 | 54 | test_historical = historical_hot(valid_data.code_x, code_num, valid_data.visit_lens) 55 | 56 | task_conf = { 57 | 'm': { 58 | 'dropout': 0.45, 59 | 'output_size': code_num, 60 | 'evaluate_fn': evaluate_codes, 61 | 'lr': { 62 | 'init_lr': 0.01, 63 | 'milestones': [20, 30], 64 | 'lrs': [1e-3, 1e-5] 65 | } 66 | }, 67 | 'h': { 68 | 'dropout': 0.0, 69 | 'output_size': 1, 70 | 'evaluate_fn': evaluate_hf, 71 | 'lr': { 72 | 'init_lr': 0.01, 73 | 'milestones': [2, 3, 20], 74 | 'lrs': [1e-3, 1e-4, 1e-5] 75 | } 76 | } 77 | } 78 | output_size = task_conf[task]['output_size'] 79 | activation = torch.nn.Sigmoid() 80 | loss_fn = torch.nn.BCELoss() 81 | evaluate_fn = task_conf[task]['evaluate_fn'] 82 | dropout_rate = task_conf[task]['dropout'] 83 | 84 | param_path = os.path.join('data', 'params', dataset, task) 85 | if not os.path.exists(param_path): 86 | os.makedirs(param_path) 87 | 88 | model = Model(code_num=code_num, code_size=code_size, 89 | adj=code_adj, graph_size=graph_size, hidden_size=hidden_size, t_attention_size=t_attention_size, 90 | t_output_size=t_output_size, 91 | output_size=output_size, dropout_rate=dropout_rate, activation=activation).to(device) 92 | optimizer = torch.optim.Adam(model.parameters(), lr=0.01) 93 | scheduler = MultiStepLRScheduler(optimizer, epochs, task_conf[task]['lr']['init_lr'], 94 | task_conf[task]['lr']['milestones'], task_conf[task]['lr']['lrs']) 95 | 96 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 97 | print(pytorch_total_params) 98 | 99 | for epoch in range(epochs): 100 | print('Epoch %d / %d:' % (epoch + 1, epochs)) 101 | model.train() 102 | total_loss = 0.0 103 | total_num = 0 104 | steps = len(train_data) 105 | st = time.time() 106 | scheduler.step() 107 | for step in range(len(train_data)): 108 | optimizer.zero_grad() 109 | code_x, visit_lens, divided, y, neighbors = train_data[step] 110 | output = model(code_x, divided, neighbors, visit_lens).squeeze() 111 | loss = loss_fn(output, y) 112 | loss.backward() 113 | optimizer.step() 114 | total_loss += loss.item() * output_size * len(code_x) 115 | total_num += len(code_x) 116 | 117 | end_time = time.time() 118 | remaining_time = format_time((end_time - st) / (step + 1) * (steps - step - 1)) 119 | print('\r Step %d / %d, remaining time: %s, loss: %.4f' 120 | % (step + 1, steps, remaining_time, total_loss / total_num), end='') 121 | train_data.on_epoch_end() 122 | et = time.time() 123 | time_cost = format_time(et - st) 124 | print('\r Step %d / %d, time cost: %s, loss: %.4f' % (steps, steps, time_cost, total_loss / total_num)) 125 | valid_loss, f1_score = evaluate_fn(model, valid_data, loss_fn, output_size, test_historical) 126 | 
torch.save(model.state_dict(), os.path.join(param_path, '%d.pt' % epoch)) 127 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from preprocess import load_sparse 7 | 8 | 9 | def load_adj(path, device=torch.device('cpu')): 10 | filename = os.path.join(path, 'code_adj.npz') 11 | adj = torch.from_numpy(load_sparse(filename)).to(device=device, dtype=torch.float32) 12 | return adj 13 | 14 | 15 | class EHRDataset: 16 | def __init__(self, data_path, label='m', batch_size=32, shuffle=True, device=torch.device('cpu')): 17 | super().__init__() 18 | self.path = data_path 19 | self.code_x, self.visit_lens, self.y, self.divided, self.neighbors = self._load(label) 20 | 21 | self._size = self.code_x.shape[0] 22 | self.idx = np.arange(self._size) 23 | self.batch_size = batch_size 24 | self.shuffle = shuffle 25 | self.device = device 26 | 27 | def _load(self, label): 28 | code_x = load_sparse(os.path.join(self.path, 'code_x.npz')) 29 | visit_lens = np.load(os.path.join(self.path, 'visit_lens.npz'))['lens'] 30 | if label == 'm': 31 | y = load_sparse(os.path.join(self.path, 'code_y.npz')) 32 | elif label == 'h': 33 | y = np.load(os.path.join(self.path, 'hf_y.npz'))['hf_y'] 34 | else: 35 | raise KeyError('Unsupported label type') 36 | divided = load_sparse(os.path.join(self.path, 'divided.npz')) 37 | neighbors = load_sparse(os.path.join(self.path, 'neighbors.npz')) 38 | return code_x, visit_lens, y, divided, neighbors 39 | 40 | def on_epoch_end(self): 41 | if self.shuffle: 42 | np.random.shuffle(self.idx) 43 | 44 | def size(self): 45 | return self._size 46 | 47 | def label(self): 48 | return self.y 49 | 50 | def __len__(self): 51 | len_ = self._size // self.batch_size 52 | return len_ if self._size % self.batch_size == 0 else len_ + 1 53 | 54 | def __getitem__(self, index): 55 | device = self.device 56 | start = index * self.batch_size 57 | end = start + self.batch_size 58 | slices = self.idx[start:end] 59 | code_x = torch.from_numpy(self.code_x[slices]).to(device) 60 | visit_lens = torch.from_numpy(self.visit_lens[slices]).to(device=device, dtype=torch.long) 61 | y = torch.from_numpy(self.y[slices]).to(device=device, dtype=torch.float32) 62 | divided = torch.from_numpy(self.divided[slices]).to(device) 63 | neighbors = torch.from_numpy(self.neighbors[slices]).to(device) 64 | return code_x, visit_lens, divided, y, neighbors 65 | 66 | 67 | class MultiStepLRScheduler: 68 | def __init__(self, optimizer, epochs, init_lr, milestones, lrs): 69 | self.optimizer = optimizer 70 | self.epochs = epochs 71 | self.init_lr = init_lr 72 | self.lrs = self._generate_lr(milestones, lrs) 73 | self.current_epoch = 0 74 | 75 | def _generate_lr(self, milestones, lrs): 76 | milestones = [1] + milestones + [self.epochs + 1] 77 | lrs = [self.init_lr] + lrs 78 | lr_grouped = np.concatenate([np.ones((milestones[i + 1] - milestones[i], )) * lrs[i] 79 | for i in range(len(milestones) - 1)]) 80 | return lr_grouped 81 | 82 | def step(self): 83 | lr = self.lrs[self.current_epoch] 84 | for group in self.optimizer.param_groups: 85 | group['lr'] = lr 86 | self.current_epoch += 1 87 | 88 | def reset(self): 89 | self.current_epoch = 0 90 | 91 | 92 | def format_time(seconds): 93 | if seconds <= 60: 94 | time_str = '%.1fs' % seconds 95 | elif seconds <= 3600: 96 | time_str = '%dm%.1fs' % (seconds // 60, seconds % 60) 97 | else: 98 | time_str = '%dh%dm%.1fs' % 
(seconds // 3600, (seconds % 3600) // 60, seconds % 60) 99 | return time_str 100 | --------------------------------------------------------------------------------