├── EHR2Vec ├── __init__.py ├── statistic.py ├── hparams.py ├── SLE_EHR2Vec_Runner.py └── EHR2Vec_modules.py ├── example ├── __init__.py ├── dict_types.pkl └── train_data_example.pkl └── README.md /EHR2Vec/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf8 2 | #@Author:hadoop 3 | #@File:__init__.py -------------------------------------------------------------------------------- /example/__init__.py: -------------------------------------------------------------------------------- 1 | #encoding:utf8 2 | #@Time:2019/12/23 21:50 3 | #@File:__init__.py.py 4 | -------------------------------------------------------------------------------- /example/dict_types.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jingsongs/EHR2Vec/HEAD/example/dict_types.pkl -------------------------------------------------------------------------------- /example/train_data_example.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jingsongs/EHR2Vec/HEAD/example/train_data_example.pkl -------------------------------------------------------------------------------- /EHR2Vec/statistic.py: -------------------------------------------------------------------------------- 1 | #encoding:utf8 2 | #@File:statistic.py 3 | 4 | import pickle 5 | import argparse 6 | 7 | def initParamaters(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--dict_types', default="./example/dict_types.pkl", help='dict_types path') 10 | parser.add_argument('--example_train_path', default='./example/train_data_example.pkl', 11 | help="example train data path") 12 | return parser 13 | 14 | def get_types_number_maxlen(exapmple_dict_types,example_train_path): 15 | types = pickle.load(open(exapmple_dict_types, 'rb')) 16 | train_data = pickle.load(open(example_train_path, 'rb')) 17 | print(train_data) 18 | print(types) 19 | return len(types),max([len(visit) for visit in train_data]) 20 | 21 | hparams=initParamaters() 22 | hp=hparams.parse_args() 23 | types_number,maxlen=get_types_number_maxlen(hp.dict_types,hp.example_train_path) 24 | print('The number of dict_types:{0},max length of input event:{1}'.format(types_number,maxlen)) 25 | 26 | -------------------------------------------------------------------------------- /EHR2Vec/hparams.py: -------------------------------------------------------------------------------- 1 | #encoding:utf8 2 | #@File:hparams.py 3 | 4 | import argparse 5 | 6 | class Hparams: 7 | parser = argparse.ArgumentParser() 8 | #input 9 | parser.add_argument('--n_input', default=15, type=int) 10 | parser.add_argument('--maxlen_seqs', default=6, type=int,help="max length of input event") 11 | #model 12 | parser.add_argument('--d_model', default=8, type=int,help="dimension of entity embedding") 13 | parser.add_argument('--d_ff', default=32, type=int, help="number of neurons of feedword network") 14 | parser.add_argument('--num_blocks', default=1, type=int, help="number of blocks") 15 | parser.add_argument('--num_heads', default=8, type=int, help="head number of the multi-head attenion") 16 | parser.add_argument('--dropout_rate', default=0.1, type=float, help="dropout rate") 17 | #train 18 | parser.add_argument('--max_epoch', default=10, type=int) 19 | parser.add_argument('--batch_size', default=8, type=int) 20 | parser.add_argument('--display_step', default=1, type=int,help='display frequency of the training process') 21 | #path 22 | parser.add_argument('--data_path', default='./example/train_data_example.pkl', help='path of the train data') 23 | parser.add_argument('--dict_types_path', default='./example/dict_types.pkl', help='path of the dict_types') 24 | parser.add_argument('--save_model_path', default='./example/EHR2vec_model/', help='save model path') 25 | parser.add_argument('--entity_embedding_path', default='./example/entity_embedding.pkl', help='save entity_embedding path') 26 | 27 | 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EHR2Vec 2 | EHR2Vec is an embedding translation tool for medical entities based on attention mechanism. We have provided a version of tensorflow implement and it is constantly being improved. 3 | 4 | # Prerequisites 5 | 1. We use Python 3.6 and Tensorflow 1.8.0. 6 | 2. Download/clone the EHR2Vec code 7 | 8 | 9 | # Running 10 | 1. Preparing training data: 11 | Training data should be a list of list of medical entities include diagnosis, medication, labtest and symptom and the medical entities should be transformed to integers. In our model, a patient medical event includes many medical entities and this medical events of a patient constitute a patient complete medical process. In the meantime, a patient medical events are sorted in chronological order. So the out layer list denote a patient, each inner layer denote each event of this patient and we use -1 to separate every event of a patient. For example, [[22,50,33], [-1] ,[4,58,60],[20]] means there are two patients where the first patient only had one event which include three entities and the second patient had two event include [4,58,60] and [20]. And you should save the transformation vocabulary (from medical entities to integers). In our model, the length of each event is fixed and we use 0 as padding value. So 0 should also be saved in the transformation vocabulary which means padding. Both of the two files need to be pickled use python pickle. 12 | 13 | 2. Model hyper-parameters configuration: 14 | The max length of the medical event of all patients needs to be count and called maxlen_seqs in our model. The number of all your medical entities in your vocabulary is required in our model and called n_input. The default dimension of entity is 512 which you can change as you like. The attention mechanism parameters include num_heads, num_blokcs and d_ff you can set. The number of epoch and the size of batch should be configured for your own machine. 15 | 16 | 3. Running: 17 | You can train the model with the default hyper-parameters except the training data path, n_input, maxlen_seqs, dict_types_path, save_model_path and embedding save path. And you can use the simple execution command to run the model: 18 | 19 | python3 SLE_EHR2Vec_Runner.py --data_path \ --n_input \ --maxlen_seqs \ --dict_types_path \ --entity_embedding_path \ --save_model_path \ 20 | 21 | The complete execution command includes all the hyper-parameters: 22 | 23 | python3 SLE_EHR2Vec_Runner.py --data_path --n_input --maxlen_seqs --d_model --d_ff --num_blocks -- num_heads --dropout_rate < dropout rate> --dict_types_path --entity_embedding_path --save_model_path --max_epoch --batch_size --display_step 24 | 25 | 26 | ## Example of how to run EHR2Vec with the provide train_data_example.pkl and dict_types.pkl 27 | 28 | 1、 Count the total number of input entities(n_input) and the max length(max_seq_length) of all events with the command. 29 | 30 | python3 statistic.py --dict_types ./example/dict_types.pkl --example_train_path ./example/train_data_example.pkl 31 | 32 | You will get the n_input number and max_seq_length number from the output. 33 | 34 | 2、 Run the EHR2Vec model and get the entity embedding: 35 | 36 | python3 SLE_EHR2Vec_Runner.py --n_input 17 --maxlen_seqs 7 --data_path ./example/train_data_example.pkl --dict_types_path ./example/dict_types.pkl --save_model_path ./example/EHR2Vec_model/ --entity_embedding_path ./example/entity_embedding.pkl 37 | 38 | 3、The vectors result can be found in the entity_embedding_path. 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /EHR2Vec/SLE_EHR2Vec_Runner.py: -------------------------------------------------------------------------------- 1 | #encoding:utf8 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import pickle 6 | import os 7 | from collections import OrderedDict 8 | from EHR2Vec_modules import EHR2Vec 9 | from hparams import Hparams 10 | 11 | 12 | def load_data(x_file): 13 | x_seq = np.array(pickle.load(open(x_file, 'rb'))) 14 | return x_seq 15 | 16 | def pickTwo(iVector, jVector,maxlen_seqs): 17 | for first in range(maxlen_seqs): 18 | for second in range(maxlen_seqs): 19 | if first == second: continue 20 | iVector.append(first) 21 | jVector.append(second) 22 | 23 | def pickTwo_vi(Vts,vi,vj): 24 | num1_v1 = 0 25 | for first_v,v1 in enumerate(Vts): 26 | if v1==[-1]: 27 | num1_v1+=1 28 | continue 29 | seconds = [] 30 | num1_v2 = 0 31 | for second_v,v2 in enumerate(Vts): 32 | if v2!=[-1]: 33 | seconds.append(v2) 34 | if num1_v1==0: 35 | if second_v != first_v: 36 | vi.append(first_v - num1_v1) 37 | vj.append(second_v) 38 | else: 39 | if second_v>=len(seconds)-num1_v2: 40 | second_v=second_v-num1_v2 41 | if second_v!= first_v - num1_v1: 42 | vi.append(first_v - num1_v1) 43 | vj.append(second_v) 44 | else: 45 | num1_v2+=1 46 | if num1_v2==num1_v1: 47 | seconds = Vts[:second_v+1] 48 | else: 49 | seconds=[0]*100 50 | if num1_v2>num1_v1: 51 | break 52 | 53 | def pad_matrix(seqs, maxlen_seqs): 54 | i_vec = [] 55 | j_vec = [] 56 | vi_vec=[] 57 | vj_vec=[] 58 | pickTwo_vi(seqs.tolist(),vi_vec,vj_vec) 59 | sents=[] 60 | for idx,seq_id in enumerate(seqs): 61 | if not seq_id[0] == -1: 62 | seq_id_array=np.array(seq_id) 63 | #seq_id_array_1=(seq_id_array+np.ones_like(seq_id_array)) 64 | sents.append(seq_id_array) 65 | pickTwo(i_vec, j_vec,maxlen_seqs) 66 | X=np.zeros([len(sents),maxlen_seqs],np.int32) 67 | for i,x in enumerate(sents): 68 | X[i]=np.lib.pad(x,[0,maxlen_seqs-len(x)],'constant',constant_values=(0,0)) 69 | return X, i_vec, j_vec,vi_vec,vj_vec 70 | 71 | def model_train(model, saver, hp): 72 | for epoch in range(hp.max_epoch): 73 | print('epoch %d'%epoch) 74 | avg_cost = 0. 75 | x_seq= load_data(hp.data_path) 76 | total_batch = int(np.ceil(len(x_seq) / hp.batch_size)) 77 | for index in range(total_batch): 78 | x_batch = x_seq[index * hp.batch_size: (index + 1) * hp.batch_size] 79 | x, i_vec, j_vec,vi_vec,vj_vec= pad_matrix(x_batch, hp.maxlen_seqs) 80 | cost=model.partial_fit(x=x,i_vec=i_vec, j_vec=j_vec,vi=vi_vec,vj=vj_vec) 81 | avg_cost += cost / len(x_seq) * hp.batch_size 82 | if epoch % hp.display_step == 0: 83 | print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost)) 84 | save_path=hp.save_model_path 85 | if os.path.exists(save_path): 86 | path=os.path.join(save_path,'EHR2Vec') 87 | else: 88 | os.makedirs(save_path) 89 | path=os.path.join(save_path,'EHR2Vec') 90 | if epoch == hp.max_epoch - 1: 91 | saver.save(sess=model.sess, save_path=path,global_step=hp.max_epoch) 92 | 93 | def get_code_representation(model, saver,dirpath,dict_types_file,entity_embedding_path): 94 | ckpt = tf.train.get_checkpoint_state(dirpath) 95 | if ckpt and ckpt.model_checkpoint_path: 96 | saver.restore(model.sess, ckpt.model_checkpoint_path) 97 | embeddings = model.get_weights_embeddings() 98 | types = pickle.load(open(dict_types_file, 'rb')) 99 | types = OrderedDict(sorted(types.items(),key=lambda x:x[1])) 100 | file = open(entity_embedding_path, 'wb') 101 | dict = {} 102 | for w, (k, v) in zip(embeddings, types.items()): 103 | dict[k] = w 104 | pickle.dump(dict, file) 105 | file.close() 106 | else: 107 | print('ERROR') 108 | 109 | 110 | def main(_): 111 | hparams = Hparams() 112 | parser = hparams.parser 113 | hp = parser.parse_args() 114 | 115 | model = EHR2Vec(n_input=hp.n_input, d_model=hp.d_model,batch_size=hp.batch_size, 116 | maxseq_len=hp.maxlen_seqs,d_ff=hp.d_ff,num_blocks=hp.num_blocks, 117 | num_heads=hp.num_heads,dropout_rate=hp.dropout_rate) 118 | saver = tf.train.Saver() 119 | model_train(model, saver, hp) 120 | get_code_representation(model,saver,hp.save_model_path,hp.dict_types_path,hp.entity_embedding_path) 121 | 122 | if __name__ == "__main__": 123 | tf.app.run() 124 | -------------------------------------------------------------------------------- /EHR2Vec/EHR2Vec_modules.py: -------------------------------------------------------------------------------- 1 | #encoding:utf8 2 | #@File: EHR2Vec_modules.py 3 | 4 | import logging 5 | logging.basicConfig(level=logging.INFO) 6 | import tensorflow as tf 7 | 8 | def ln(inputs, epsilon=1e-8, scope="ln"): 9 | ''' 10 | LayerNormaliztion 11 | ''' 12 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 13 | inputs_shape = inputs.get_shape() 14 | params_shape = inputs_shape[-1:] 15 | mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True) 16 | beta = tf.get_variable("beta", params_shape, initializer=tf.zeros_initializer()) 17 | gamma = tf.get_variable("gamma", params_shape, initializer=tf.ones_initializer()) 18 | normalized = (inputs - mean) / ((variance + epsilon) ** (.5)) 19 | outputs = gamma * normalized + beta 20 | return outputs 21 | 22 | def get_token_embeddings(vocab_size, num_units, zero_pad=True): 23 | with tf.variable_scope("shared_weight_matrix",reuse=tf.AUTO_REUSE): 24 | embeddings = tf.get_variable(name='weight_mat', 25 | dtype=tf.float32, 26 | shape=(vocab_size, num_units), 27 | initializer=tf.contrib.layers.xavier_initializer(uniform=False)) 28 | if zero_pad: 29 | embeddings = tf.concat((tf.zeros(shape=[1, num_units]),embeddings[1:, :]), 0) 30 | return embeddings 31 | 32 | 33 | def scaled_dot_product_attention(Q, K, V,dropout_rate=0.1,training=True, 34 | scope="scaled_dot_product_attention"): 35 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 36 | d_k = Q.get_shape().as_list()[-1] 37 | outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1])) 38 | outputs /= d_k ** 0.5 39 | outputs = mask(outputs, Q, K, type="key") 40 | outputs = tf.nn.softmax(outputs) 41 | attention = tf.transpose(outputs, [0, 2, 1]) 42 | tf.summary.image("attention", tf.expand_dims(attention[:1], -1)) 43 | # query masking 44 | outputs = mask(outputs, Q, K, type="query") 45 | outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training) 46 | outputs = tf.matmul(outputs, V) 47 | return outputs 48 | 49 | 50 | def mask(inputs, queries=None, keys=None, type=None): 51 | padding_num = -2 ** 32 + 1 52 | if type =="key": 53 | # Generate masks 54 | masks = tf.sign(tf.reduce_sum(tf.abs(keys), axis=-1)) 55 | masks = tf.expand_dims(masks, 1) 56 | masks = tf.tile(masks, [1, tf.shape(queries)[1], 1]) 57 | paddings = tf.ones_like(inputs) * padding_num 58 | outputs = tf.where(tf.equal(masks, 0), paddings, inputs) 59 | elif type =="query": 60 | # Generate masks 61 | masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1)) 62 | masks = tf.expand_dims(masks, -1) 63 | masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]]) 64 | outputs = inputs * masks 65 | return outputs 66 | 67 | 68 | def multihead_attention(queries, keys, values,num_heads=4,dropout_rate=0.1,training=True, 69 | scope="multihead_attention"): 70 | d_model = queries.get_shape().as_list()[-1] 71 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 72 | Q = tf.layers.dense(queries, d_model) 73 | K = tf.layers.dense(keys, d_model) 74 | V = tf.layers.dense(values, d_model) 75 | Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) 76 | K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) 77 | V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) 78 | outputs = scaled_dot_product_attention(Q_, K_, V_, dropout_rate, training) 79 | outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2) 80 | outputs += queries 81 | outputs = ln(outputs) 82 | return outputs 83 | 84 | def ff(inputs, num_units, scope="positionwise_feedforward"): 85 | with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): 86 | outputs = tf.layers.dense(inputs, num_units[0], activation=tf.nn.relu) 87 | outputs = tf.layers.dense(outputs, num_units[1]) 88 | outputs += inputs 89 | outputs = ln(outputs) 90 | return outputs 91 | 92 | class EHR2Vec(object): 93 | def __init__(self, n_input, d_model, batch_size, maxseq_len,d_ff,num_blocks,num_heads,dropout_rate,log_eps=1e-8, 94 | optimizer=tf.train.AdadeltaOptimizer(learning_rate=0.5), init_scale=0.01): 95 | self.n_input = n_input 96 | self.d_model = d_model 97 | self.log_eps = log_eps 98 | self.init_scale = init_scale 99 | self.optimizer = optimizer 100 | self.batch_size = batch_size 101 | self.maxseq_len = maxseq_len 102 | self.d_ff=d_ff 103 | self.num_blocks=num_blocks 104 | self.num_heads=num_heads 105 | self.dropout_rate=dropout_rate 106 | self.embeddings = get_token_embeddings(self.n_input, self.d_model, zero_pad=True) 107 | 108 | self.i_vec = tf.placeholder(tf.int32) 109 | self.j_vec = tf.placeholder(tf.int32) 110 | self.idx = tf.placeholder(tf.int32, shape=[None, self.maxseq_len]) 111 | self.vi = tf.placeholder(tf.int32, shape=[None, 1]) 112 | self.vj = tf.placeholder(tf.int32, shape=[None, 1]) 113 | self.v = self.encode(self.idx, self.embeddings)*0.1 114 | self.emb_cost = self._initialize_entity_cost() 115 | self.vivlcost = self._initialize_visit_cost() 116 | self.cost = self.emb_cost + self.vivlcost 117 | self.optimizer = self.optimizer.minimize(self.cost) 118 | init = tf.global_variables_initializer() 119 | self.sess = tf.Session() 120 | self.sess.run(init) 121 | 122 | 123 | def encode(self, xs,embeddings,training=True): 124 | with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE): 125 | enc = tf.nn.embedding_lookup(embeddings, xs) 126 | enc *= self.d_model**0.5 127 | for i in range(self.num_blocks): 128 | with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE): 129 | enc = multihead_attention(queries=enc, 130 | keys=enc, 131 | values=enc, 132 | num_heads=self.num_heads, 133 | dropout_rate=self.dropout_rate, 134 | training=training) 135 | enc = ff(enc, num_units=[self.d_ff, self.d_model]) 136 | return enc 137 | 138 | def _initialize_entity_cost(self): 139 | norms = tf.reduce_sum(tf.exp(tf.matmul(self.v, tf.transpose(self.v, [0, 2, 1]))), axis=2) 140 | wi_emb = tf.gather(self.v, self.i_vec, axis=1) 141 | wj_emb = tf.gather(self.v, self.j_vec, axis=1) 142 | exp = tf.exp(tf.reduce_sum(wi_emb * wj_emb, axis=2)) 143 | norms2 = tf.gather(norms, self.i_vec, axis=1) 144 | #norms2 = tf.gather_nd(norms, tf.reshape(self.i_vec, (-1, 1))) 145 | log_sum = tf.reduce_sum(-tf.log(exp / norms2+ self.log_eps)) 146 | return log_sum 147 | 148 | def _initialize_visit_cost(self): 149 | batch_v = tf.squeeze(tf.layers.dense(self.v, 1),axis=-1)*0.1 150 | norms = tf.reduce_sum(tf.exp(tf.matmul(batch_v, tf.transpose(batch_v, [1, 0]))), axis=1) 151 | wi_emb = tf.gather_nd(batch_v, self.vi) 152 | wj_emb = tf.gather_nd(batch_v, self.vj) 153 | exp = tf.exp(tf.reduce_sum(wi_emb * wj_emb, axis=1)) 154 | norms2 = tf.gather_nd(norms, self.vj) 155 | log_sum = tf.reduce_sum(-tf.log(exp / norms2+self.log_eps)) 156 | return log_sum 157 | 158 | def partial_fit(self, x=None, i_vec=None, j_vec=None, vi=None, vj=None): 159 | cost, opt = self.sess.run((self.cost, self.optimizer), feed_dict= 160 | {self.idx: x, self.i_vec: i_vec, self.j_vec: j_vec, self.vi: vi, self.vj: vj}) 161 | return cost 162 | 163 | def get_visit_representation(self, x=None): 164 | visit_representation = self.sess.run(self.v, feed_dict={self.x: x}) 165 | return visit_representation 166 | 167 | def get_weights_embeddings(self): 168 | return self.sess.run(self.embeddings) 169 | --------------------------------------------------------------------------------