├── HELSTM.py
├── MergeEmbedding.py
├── README.md
├── dataExt.py
├── extractor.cpp
├── task.py
├── 数据生成说明_v7
│   └── data_process
│       ├── build_event.py
│       ├── config
│       │   └── blacklist.reg
│       ├── dataExt.py
│       ├── event_des.py
│       ├── extract.py
│       ├── extractor.cpp
│       ├── extractor.py
│       ├── gather_stat.py
│       ├── gather_static_data.py
│       ├── lab_process_data.py
│       ├── select_feature.py
│       ├── sort_stat.py
│       ├── stat_data.py
│       ├── stat_value.py
│       └── util.py
├── 数据生成说明.pdf
└── 数据集说明.docx
/HELSTM.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import theano
3 | import theano.tensor as T
4 | import lasagne
5 | from lasagne import nonlinearities
6 | from lasagne import init
7 | from lasagne.layers.base import MergeLayer, Layer
8 | from lasagne.layers.recurrent import Gate
9 | 
10 | class HELSTMGate(object):
11 | def __init__(self,
12 | Period=init.Uniform((10, 100)),
13 | Shift=init.Uniform((0.,1000.)),
14 | On_End=init.Constant(0.05),
15 | Event_W=init.GlorotUniform(),
16 | Event_b=init.Constant(0.),
17 | out_W=init.GlorotUniform(),
18 | out_b=init.Constant(0.)):
19 | 
20 | self.Period = Period
21 | self.Shift = Shift
22 | self.On_End = On_End
23 | self.Event_W = Event_W
24 | self.Event_b = Event_b
25 | self.out_W = out_W
26 | self.out_b = out_b
27 | 
28 | class HELSTMLayer(MergeLayer):
29 | def __init__(self, incoming, time_input, event_input,
30 | num_units, num_attention, model='HELSTM',#model options: LSTM, PLSTM or HELSTM
31 | mask_input=None,
32 | ingate=Gate(),
33 | forgetgate=Gate(),
34 | cell=Gate(W_cell=None, nonlinearity=nonlinearities.tanh),
35 | timegate=HELSTMGate(),
36 | nonlinearity=nonlinearities.tanh,
37 | cell_init=init.Constant(0.),
38 | hid_init=init.Constant(0.),
39 | outgate=Gate(),
40 | backwards=False,
41 | learn_init=False,
42 | peepholes=True,
43 | grad_clipping=0,
44 | bn=False,
45 | only_return_final=False,
46 | off_alpha=1e-3,
47 | **kwargs):
48 | incomings = [incoming, time_input, event_input]
49 | self.time_incoming_idx = 1
50 | self.event_incoming_idx = 2
51 | self.mask_incoming_index = -2
52 | self.hid_init_incoming_index = -2
53 | self.cell_init_incoming_index = -2
54 | 
55 | if mask_input is not None:
56 | incomings.append(mask_input)
57 | self.mask_incoming_index = len(incomings)-1
58 | if isinstance(hid_init, Layer):
59 | incomings.append(hid_init)
60 | self.hid_init_incoming_index = len(incomings)-1
61 | if isinstance(cell_init, Layer):
62 | incomings.append(cell_init)
63 | self.cell_init_incoming_index = len(incomings)-1
64 | 
65 | super(HELSTMLayer, self).__init__(incomings, **kwargs)
66 | 
67 | self.nonlinearity = nonlinearity
68 | self.learn_init=learn_init
69 | self.num_units = num_units
70 | self.num_attention = num_attention
71 | self.peepholes = peepholes
72 | self.grad_clipping = grad_clipping
73 | self.backwards = backwards
74 | self.off_alpha = off_alpha
75 | self.only_return_final = only_return_final
76 | self.model = model
77 | if self.model == 'LSTM':
78 | print 'using LSTM'
79 | elif self.model == 'PLSTM':
80 | print 'using PLSTM'
81 | else:
82 | assert self.model=='HELSTM'
83 | print 'using HELSTM'
84 | 
85 | input_shape = self.input_shapes[0]
86 | num_inputs = np.prod(input_shape[2:])
87 | 
88 | def add_gate_params(gate, gate_name):
89 | return (self.add_param(gate.W_in, (num_inputs, num_units),
90 | name="W_in_to_{}".format(gate_name)),
91 | self.add_param(gate.W_hid, (num_units, num_units),
92 | name="W_hid_to_{}".format(gate_name)),
93 | self.add_param(gate.b, (num_units,),
name="b_{}".format(gate_name), 95 | regularizable=False), 96 | gate.nonlinearity) 97 | 98 | # Add in parameters from the supplied Gate instances 99 | (self.W_in_to_ingate, self.W_hid_to_ingate, self.b_ingate, 100 | self.nonlinearity_ingate) = add_gate_params(ingate, 'ingate') 101 | 102 | (self.W_in_to_forgetgate, self.W_hid_to_forgetgate, self.b_forgetgate, 103 | self.nonlinearity_forgetgate) = add_gate_params(forgetgate, 104 | 'forgetgate') 105 | 106 | (self.W_in_to_cell, self.W_hid_to_cell, self.b_cell, 107 | self.nonlinearity_cell) = add_gate_params(cell, 'cell') 108 | 109 | (self.W_in_to_outgate, self.W_hid_to_outgate, self.b_outgate, 110 | self.nonlinearity_outgate) = add_gate_params(outgate, 'outgate') 111 | 112 | # If peephole (cell to gate) connections were enabled, initialize 113 | # peephole connections. These are elementwise products with the cell 114 | # state, so they are represented as vectors. 115 | if self.peepholes: 116 | self.W_cell_to_ingate = self.add_param( 117 | ingate.W_cell, (num_units, ), name="W_cell_to_ingate") 118 | 119 | self.W_cell_to_forgetgate = self.add_param( 120 | forgetgate.W_cell, (num_units, ), name="W_cell_to_forgetgate") 121 | 122 | self.W_cell_to_outgate = self.add_param( 123 | outgate.W_cell, (num_units, ), name="W_cell_to_outgate") 124 | 125 | # Setup initial values for the cell and the hidden units 126 | if isinstance(cell_init, Layer): 127 | self.cell_init = cell_init 128 | else: 129 | self.cell_init = self.add_param( 130 | cell_init, (1, num_units), name="cell_init", 131 | trainable=learn_init, regularizable=False) 132 | 133 | if isinstance(hid_init, Layer): 134 | self.hid_init = hid_init 135 | else: 136 | self.hid_init = self.add_param( 137 | hid_init, (1, self.num_units), name="hid_init", 138 | trainable=learn_init, regularizable=False) 139 | 140 | if bn: 141 | self.bn = lasagne.layers.BatchNormLayer(input_shape, axes=(0,1)) # create BN layer for correct input shape 142 | self.params.update(self.bn.params) # make BN params your params 143 | else: 144 | self.bn = False 145 | 146 | def add_timegate_params(gate, gate_name, attention=False): 147 | params = [self.add_param(gate.Period, (num_units, ), 148 | name="Period_{}".format(gate_name)), 149 | self.add_param(gate.Shift, (num_units, ), 150 | name="Shift_{}".format(gate_name)), 151 | self.add_param(gate.On_End, (num_units, ), 152 | name="On_End_{}".format(gate_name))] 153 | if attention: 154 | params += [self.add_param(gate.Event_W, (num_inputs, num_attention), 155 | name="Event_W_{}".format(gate_name)), 156 | self.add_param(gate.Event_b, (num_attention, ), 157 | name="Event_b_{}".format(gate_name)), 158 | self.add_param(gate.out_W, (num_attention, num_units), 159 | name="out_b_{}".format(gate_name)), 160 | self.add_param(gate.out_b, (num_units, ), 161 | name="out_b_{}".format(gate_name))] 162 | return params 163 | 164 | if model!='LSTM': 165 | if model=='PLSTM': 166 | (self.period_timegate, self.shift_timegate, self.on_end_timegate) = add_timegate_params(timegate, 'timegate') 167 | else: 168 | assert model == 'HELSTM' 169 | (self.period_timegate, self.shift_timegate, self.on_end_timegate, 170 | self.event_w_timegate, self.event_b_timegate, self.out_w_timegate, self.out_b_timegate) = add_timegate_params(timegate, 'timegate', attention=True) 171 | 172 | def get_gate_params(self): 173 | gates = [self.period_timegate, self.shift_timegate, self.on_end_timegate] 174 | if self.model=="PLSTM": 175 | return gates 176 | else: 177 | assert self.model=="HELSTM" 178 | gates = gates + 
[self.event_w_timegate, self.event_b_timegate, self.out_w_timegate, self.out_b_timegate] 179 | return gates 180 | 181 | def get_output_shape_for(self, input_shapes): 182 | input_shape = input_shapes[0] 183 | if self.only_return_final: 184 | return input_shape[0], self.num_units 185 | else: 186 | return input_shape[0], input_shape[1], self.num_units 187 | 188 | def get_output_for(self, inputs, deterministic=False, **kwargs): 189 | input = inputs[0] 190 | time_input = inputs[self.time_incoming_idx] 191 | event_input = inputs[self.event_incoming_idx] 192 | 193 | mask = None 194 | hid_init = None 195 | cell_init = None 196 | if self.mask_incoming_index > 0: 197 | mask = inputs[self.mask_incoming_index] 198 | if self.hid_init_incoming_index > 0: 199 | hid_init = inputs[self.hid_init_incoming_index] 200 | if self.cell_init_incoming_index > 0: 201 | cell_init = inputs[self.cell_init_incoming_index] 202 | 203 | if self.bn: 204 | input = self.bn.get_output_for(input) 205 | 206 | input = input.dimshuffle(1, 0, 2) 207 | time_input = time_input.dimshuffle(1, 0) 208 | 209 | seq_len, num_batch, _ = input.shape 210 | 211 | # Stack input weight matrices into a (num_inputs, 4*num_units) 212 | # matrix, which speeds up computation 213 | W_in_stacked = T.concatenate( 214 | [self.W_in_to_ingate, self.W_in_to_forgetgate, 215 | self.W_in_to_cell, self.W_in_to_outgate], axis=1) 216 | 217 | # Same for hidden weight matrices 218 | W_hid_stacked = T.concatenate( 219 | [self.W_hid_to_ingate, self.W_hid_to_forgetgate, 220 | self.W_hid_to_cell, self.W_hid_to_outgate], axis=1) 221 | 222 | # Stack biases into a (4*num_units) vector 223 | b_stacked = T.concatenate( 224 | [self.b_ingate, self.b_forgetgate, 225 | self.b_cell, self.b_outgate], axis=0) 226 | 227 | input = T.dot(input, W_in_stacked) + b_stacked 228 | 229 | # PHASED LSTM: If test time, off-phase means really shut. 230 | if deterministic: 231 | print('Using true off for testing.') 232 | off_slope = 0.0 233 | else: 234 | print('Using {} for off_slope.'.format(self.off_alpha)) 235 | off_slope = self.off_alpha 236 | 237 | if self.model != 'LSTM': 238 | # PHASED LSTM: Pregenerate broadcast vars. 239 | # Same neuron in different batches has same shift and period. Also, 240 | # precalculate the middle (on_mid) and end (on_end) of the open-phase 241 | # ramp. 
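# In formula form: for event time t, learned shift s, period tau and open
# ratio r_on, the phase is phi = (t + s) mod tau, and calc_time_gate below
# returns
#   phi/on_mid                  if phi <= on_mid           (rising ramp)
#   (on_end - phi)/on_mid       if on_mid < phi <= on_end  (falling ramp)
#   off_slope * phi/tau         otherwise                  (leaky off phase)
# where on_mid = 0.5*r_on*tau and on_end = r_on*tau.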
242 | shift_broadcast = self.shift_timegate.dimshuffle(['x',0]) 243 | period_broadcast = T.abs_(self.period_timegate.dimshuffle(['x',0])) 244 | on_mid_broadcast = T.abs_(self.on_end_timegate.dimshuffle(['x',0])) * 0.5 * period_broadcast 245 | on_end_broadcast = T.abs_(self.on_end_timegate.dimshuffle(['x',0])) * period_broadcast 246 | 247 | if self.model == 'HELSTM': 248 | event_W = self.event_w_timegate 249 | event_b = T.shape_padleft(self.event_b_timegate, 2) 250 | out_W = self.out_w_timegate 251 | out_b = T.shape_padleft(self.out_b_timegate, 2) 252 | hid_attention = nonlinearities.leaky_rectify(T.dot(event_input, event_W) + event_b) 253 | out_attention = nonlinearities.sigmoid(T.dot(hid_attention, out_W) + out_b) 254 | out_attention = out_attention.dimshuffle(1, 0, 2) 255 | 256 | def slice_w(x, n): 257 | return x[:, n*self.num_units:(n+1)*self.num_units] 258 | 259 | def step(input_n, cell_previous, hid_previous, *args): 260 | gates = input_n + T.dot(hid_previous, W_hid_stacked) 261 | 262 | # Clip gradients 263 | if self.grad_clipping: 264 | gates = theano.gradient.grad_clip( 265 | gates, -self.grad_clipping, self.grad_clipping) 266 | 267 | # Extract the pre-activation gate values 268 | ingate = slice_w(gates, 0) 269 | forgetgate = slice_w(gates, 1) 270 | cell_input = slice_w(gates, 2) 271 | outgate = slice_w(gates, 3) 272 | 273 | if self.peepholes: 274 | # Compute peephole connections 275 | ingate += cell_previous*self.W_cell_to_ingate 276 | forgetgate += cell_previous*self.W_cell_to_forgetgate 277 | 278 | # Apply nonlinearities 279 | ingate = self.nonlinearity_ingate(ingate) 280 | forgetgate = self.nonlinearity_forgetgate(forgetgate) 281 | cell_input = self.nonlinearity_cell(cell_input) 282 | 283 | # Mix in new stuff 284 | cell = forgetgate*cell_previous + ingate*cell_input 285 | 286 | if self.peepholes: 287 | outgate += cell*self.W_cell_to_outgate 288 | outgate = self.nonlinearity_outgate(outgate) 289 | 290 | # Compute new hidden unit activation 291 | hid = outgate*self.nonlinearity(cell) 292 | return [cell, hid] 293 | 294 | # PHASED LSTM: The actual calculation of the time gate 295 | def calc_time_gate(time_input_n): 296 | # Broadcast the time across all units 297 | t_broadcast = time_input_n.dimshuffle([0,'x']) 298 | # Get the time within the period 299 | in_cycle_time = T.mod(t_broadcast + shift_broadcast, period_broadcast) 300 | # Find the phase 301 | is_up_phase = T.le(in_cycle_time, on_mid_broadcast) 302 | is_down_phase = T.gt(in_cycle_time, on_mid_broadcast)*T.le(in_cycle_time, on_end_broadcast) 303 | 304 | # Set the mask 305 | sleep_wake_mask = T.switch(is_up_phase, in_cycle_time/on_mid_broadcast, 306 | T.switch(is_down_phase, 307 | (on_end_broadcast-in_cycle_time)/on_mid_broadcast, 308 | off_slope*(in_cycle_time/period_broadcast))) 309 | 310 | return sleep_wake_mask 311 | 312 | #HELSTM: Mask the updates based on the time phase and event attention 313 | def step_masked(input_n, time_input_n, event_input_n, mask_n, cell_previous, hid_previous, *args): 314 | cell, hid = step(input_n, cell_previous, hid_previous, *args) 315 | 316 | if self.model != 'LSTM': 317 | # Get time gate openness 318 | sleep_wake_mask = calc_time_gate(time_input_n) 319 | 320 | if self.model == 'HELSTM': 321 | sleep_wake_mask = event_input_n*sleep_wake_mask 322 | 323 | # Sleep if off, otherwise stay a bit on 324 | cell = sleep_wake_mask*cell + (1.-sleep_wake_mask)*cell_previous 325 | hid = sleep_wake_mask*hid + (1.-sleep_wake_mask)*hid_previous 326 | 327 | #Skip over any input with mask 0 by copying the 
previous 328 | #hidden state; proceed normally for any input with mask 1. 329 | cell = T.switch(mask_n, cell, cell_previous) 330 | hid = T.switch(mask_n, hid, hid_previous) 331 | 332 | return [cell, hid] 333 | 334 | if mask is not None: 335 | # mask is given as (batch_size, seq_len). Because scan iterates 336 | # over first dimension, we dimshuffle to (seq_len, batch_size) and 337 | # add a broadcastable dimension 338 | mask = mask.dimshuffle(1, 0, 'x') 339 | else: 340 | mask = T.ones_like(time_input).dimshuffle(0,1,'x') 341 | 342 | if self.model != 'HELSTM': 343 | out_attention = event_input#if not using HELSTM, out_attention is of no use but still need to assign a value to complete sequences 344 | sequences = [input, time_input, out_attention, mask] 345 | step_fun = step_masked 346 | 347 | ones = T.ones((num_batch, 1)) 348 | if not isinstance(self.cell_init, Layer): 349 | # Dot against a 1s vector to repeat to shape (num_batch, num_units) 350 | cell_init = T.dot(ones, self.cell_init) 351 | 352 | if not isinstance(self.hid_init, Layer): 353 | # Dot against a 1s vector to repeat to shape (num_batch, num_units) 354 | hid_init = T.dot(ones, self.hid_init) 355 | 356 | # Scan op iterates over first dimension of input and repeatedly 357 | # applies the step function 358 | cell_out, hid_out = theano.scan( 359 | fn=step_fun, 360 | sequences=sequences, 361 | outputs_info=[cell_init, hid_init], 362 | go_backwards=self.backwards)[0] 363 | 364 | # When it is requested that we only return the final sequence step, 365 | # we need to slice it out immediately after scan is applied 366 | if self.only_return_final: 367 | hid_out = hid_out[-1] 368 | else: 369 | # dimshuffle back to (n_batch, n_time_steps, n_features)) 370 | hid_out = hid_out.dimshuffle(1, 0, 2) 371 | 372 | # if scan is backward reverse the output 373 | if self.backwards: 374 | hid_out = hid_out[:, ::-1] 375 | 376 | return hid_out 377 | -------------------------------------------------------------------------------- /MergeEmbedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano 3 | import theano.tensor as T 4 | 5 | import lasagne 6 | 7 | from lasagne.layers.base import MergeLayer, Layer 8 | from lasagne.layers.input import InputLayer 9 | 10 | class MergeEmbeddingLayer(MergeLayer): 11 | """ 12 | output: embed_event + T.sum(embed_feature_idx * tanh(feature_trans*(feature_value + embed_feature_b)), axis=2) 13 | """ 14 | 15 | def __init__(self, embed_event, embed_feature_idx, embed_feature_b, embed_feature_trans, feature_value, **kwargs): 16 | incomings = [embed_event, embed_feature_idx, embed_feature_b, embed_feature_trans, feature_value] 17 | super(MergeEmbeddingLayer, self).__init__(incomings, **kwargs) 18 | 19 | def get_output_shape_for(self, input_shapes): 20 | input_shape = input_shapes[0] 21 | return input_shape 22 | 23 | def get_output_for(self, inputs, deterministic=False, **kwargs): 24 | event = inputs[0]#(None, 1000, embed) 25 | feature_idx = inputs[1] #(None, 1000, feature_num, embed) 26 | feature_b = inputs[2] #(None, 1000, feature_num, 1) 27 | feature_trans = inputs[3]#(None, 1000, feature_num, 1) 28 | feature_value = inputs[4]#(None, 1000, feature_num) 29 | value_up = T.shape_padright(feature_value, 1)#(None, 1000, feature_num, 1) 30 | bias_value = feature_trans*(value_up + feature_b) 31 | bias_value_broad = T.addbroadcast(bias_value, 3)#make the last axis broadcastable 32 | v_idx = T.sum(feature_idx * lasagne.nonlinearities.tanh(bias_value_broad), 
axis=2)#(None, 1000, embed)
33 | return v_idx + event
34 | 
35 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HELSTM
2 | This is the model "Heterogeneous Event LSTM" (HELSTM) for the paper "Learning the Joint Representation of Heterogeneous Temporal Events for Clinical Endpoint Prediction".
3 | 
4 | You can find the paper [here](https://arxiv.org/abs/1803.04837)
5 | 
6 | task.py shows how we used HELSTM to do end-to-end prediction. Data will be loaded from "/data/"; you can change data_path in task.py.
7 | 
8 | We divided all data into three parts: train_data, valid_data and test_data, and processed the data into h5py format, containing "time", "label", "event", "feature_id" and "feature_value". Data sequences should all have the same length (padded with 0 at the end); we set the length to 1000 in task.py.
9 | 
10 | How to use the data generator:
11 | 
12 | 1. Put extractor.cpp and dataExt.py in the same directory as the original MIMIC data.
13 | 
14 | 2. Compile and run extractor.cpp:
15 | g++ extractor.cpp -o extractor
16 | ./extractor
17 | 
18 | 3. Run dataExt.py:
19 | python dataExt.py
20 | 
--------------------------------------------------------------------------------
/dataExt.py:
--------------------------------------------------------------------------------
1 | import io
2 | import h5py
3 | import sys
4 | import time
5 | import numpy as np
6 | 
7 | def getPast(st):
8 | pos = st.find("-")
9 | pos2 = st.find("-", pos + 1)
10 | pos3 = st.find(" ", pos2 + 1)
11 | pos4 = st.find(":", pos3 + 1)
12 | pos5 = st.find(":", pos4 + 1)
13 | year = int(st[0:pos])
14 | month = int(st[pos + 1:pos2])
15 | day = int(st[pos2 + 1:pos3])
16 | hour = int(st[pos3 + 1:pos4])
17 | minute = int(st[pos4 + 1:pos5])
18 | second = int(st[pos5 + 1:len(st)])
19 | hour = hour - 12
20 | if (hour < 0):
21 | hour = hour + 24
22 | day = day - 1
23 | if (day <= 0):
24 | month = month - 1
25 | day = day + 30
26 | if (month in [0, 1, 3, 5, 7, 8, 10]):
27 | day = day + 1
28 | if (month == 2):
29 | day = day - 2
30 | if (month <= 0):
31 | month = month + 12
32 | year = year - 1
33 | return str(year) + "-" + str(month) + "-" + str(day) + " " + str(hour) + ":" + str(minute) + ":" + str(second)
34 | 
35 | f = h5py.File('Lab.h5', 'w')
36 | f2 = open('dataSeq.txt', 'r')
37 | sys.stdin = f2
38 | dt = [["" for j in range(5000)] for i in range(45563)]
39 | print 1
40 | evt = [[0 for j in range(5000)] for i in range(45563)]
41 | print 2
42 | ftr = [[[0.0 for k in range(6)] for j in range(5000)] for i in range(45563)]
43 | print 3
44 | lbl = [[[0.0 for k in range(5)] for j in range(3682)] for i in range(45563)]
45 | pt = []
46 | pNum = 0
47 | print 'read start'
48 | n = input()
49 | for samples in range(n):
50 | print str(samples) + "/" + str(n)
51 | st = f2.readline()
52 | pos = st.find("\t")
53 | if (pos == -1):
54 | break;
55 | patient = int(st[0:pos])
56 | st = st[pos + 1:len(st)]
57 | pos = st.find("\t")
58 | eventNum = int(st[0:pos])
59 | labelNum = int(st[pos + 1:-1])
60 | #print eventNum
61 | eNum = 0
62 | lNum = 0
63 | startDate = 0
64 | for i in range(eventNum):
65 | st = f2.readline()
66 | pos = st.find("\t")
67 | pos2 = st.find("\t", pos + 1)
68 | #print pos
69 | #print st
70 | labelFlag = int(st[0:pos])
71 | event = int(st[pos + 1:pos2])
72 | date = st[pos2 + 1:-1]
73 | dt[pNum][eNum] = date
74 | evt[pNum][eNum] = event
75 | ctg = 0
76 | if (labelFlag == 1):
77 | ctg = input()
78 | featureNum = input()
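# Each of the next featureNum lines holds one "feature_id<TAB>feature_value"
# pair; pairs are stored flattened into ftr (id at index 2*j, value at
# 2*j + 1), then zero-padded below until there are 3 pairs (6 slots).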
79 | for j in range(featureNum):
80 | st = f2.readline()
81 | pos = st.find("\t")
82 | ftr[pNum][eNum][j * 2] = float(st[0:pos])
83 | ftr[pNum][eNum][j * 2 + 1] = float(st[pos + 1:-1])
84 | while (featureNum < 3):
85 | ftr[pNum][eNum][featureNum * 2] = 0
86 | ftr[pNum][eNum][featureNum * 2 + 1] = 0
87 | featureNum = featureNum + 1
88 | if (labelFlag == 1):
89 | lbl[pNum][lNum][0] = evt[pNum][eNum]
90 | lbl[pNum][lNum][1] = ftr[pNum][eNum][1]
91 | lbl[pNum][lNum][2] = ctg
92 | pastTime = getPast(date)
93 | while (dt[pNum][startDate] < pastTime and startDate < eNum):
94 | startDate = startDate + 1
95 | startDate = startDate - 1
96 | lbl[pNum][lNum][3] = startDate
97 | lbl[pNum][lNum][4] = startDate - 1000 + 1
98 | if (lbl[pNum][lNum][4] < 0):
99 | lbl[pNum][lNum][4] = 0
100 | lNum = lNum + 1
101 | eNum = eNum + 1
102 | pt.append(patient)
103 | pNum = pNum + 1
104 | 
105 | print "read done"
106 | #print seqn
107 | #grp.create_dataset("row_id", data = rowid)
108 | #grp.create_dataset("subject_id", data = subid)
109 | f.create_dataset("patient", data = pt)
110 | f.create_dataset("event", data = evt)
111 | f.create_dataset("time", data = dt)
112 | #f.create_dataset("event_catAtt", data = atr)
113 | f.create_dataset("feature", data = ftr)
114 | f.create_dataset("label", data = lbl)
115 | 
116 | 
--------------------------------------------------------------------------------
/extractor.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <fstream>
3 | #include <sstream>
4 | #include <string>
5 | #include <vector>
6 | #include <map>
7 | #include <cstdlib>
8 | 
9 | using namespace std;
10 | 
11 | struct Event{
12 | int id;
13 | vector<int> featureNum;
14 | vector<double> featureValue;
15 | bool flag;
16 | };
17 | 
18 | map<int, multimap<string, Event> > patientMap;
19 | map<int, int> featureMap;
20 | 
21 | string toString(int c){
22 | string s = "";
23 | while (c > 0){
24 | s = (char)((c % 10) + '0') + s;
25 | c = c / 10;
26 | }
27 | while (s.size() < 2) s = '0' + s;
28 | return s;
29 | }
30 | 
31 | string getTime(string time){
32 | int pos = time.find(" ", 0);
33 | int pos2 = time.find(":", 0);
34 | int c = atoi((time.substr(pos + 1, pos2 - pos - 1)).c_str());
35 | if (c > 12) c -= 12;
36 | else c = 0;
37 | string h = toString(c);
38 | if (c < 10) h = "0" + h;
39 | return time.substr(0, pos) + " " + h + ":" + time.substr(pos2 + 1, time.size() - pos2 - 1);
40 | }
41 | 
42 | void getFeature(string s, Event & e){
43 | s = s + ' ';
44 | while (s.find(" ", 0) != string::npos){
45 | int pos = s.find(" ", 0);
46 | int p = s.find(":", 0);
47 | string sa = s.substr(0, p);
48 | string sb = s.substr(p + 1, pos - p - 1);
49 | s.erase(0, pos + 1);
50 | int num = atoi(sa.c_str());
51 | double value = atof(sb.c_str());
52 | e.featureNum.push_back(num);
53 | e.featureValue.push_back(value);
54 | }
55 | }
56 | 
57 | void getExt(string st, bool flag){
58 | ifstream f0(st.c_str());
59 | string s;
60 | while (getline(f0, s)){
61 | Event e;
62 | e.flag = flag;
63 | s = s.substr(0, s.size() - 1);
64 | s = s + '\t';
65 | int pos = s.find("\t", 0);
66 | string tmp = s.substr(0, pos);
67 | int num = atoi(tmp.c_str());
68 | s.erase(0, pos + 1);
69 | pos = s.find("\t", 0);
70 | tmp = s.substr(0, pos);
71 | int patient = atoi(tmp.c_str());
72 | s.erase(0, pos + 1);
73 | pos = s.find("\t", 0);
74 | string feature = s.substr(0, pos);
75 | getFeature(feature, e);
76 | s.erase(0, pos + 1);
77 | pos = s.find("\t", 0);
78 | string date = s.substr(0, pos);
79 | s.erase(0, pos + 1);
80 | if (patientMap.find(patient) == patientMap.end()){
81 | multimap<string, Event> eventMap;
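// patientMap maps a patient id to a time-keyed multimap of that patient's
// events, so iterating a patient's multimap yields events in chronological
// order of their timestamp strings; featureMap (filled by getApart below)
// tags a lab-event feature id as 1 for "abnormal", 2 for "delta", else 0.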
82 | patientMap.insert(make_pair(patient, eventMap));
83 | }
84 | e.id = num;
85 | patientMap[patient].insert(make_pair(date, e));
86 | }
87 | }
88 | 
89 | void getApart(string st){
90 | ifstream f0(st.c_str());
91 | string s;
92 | while (getline(f0, s)){
93 | s = s.substr(0, s.size() - 1);
94 | s = s.substr(1, s.size() - 2);
95 | s = s + ',';
96 | int pos = s.find(",", 0);
97 | int pos2 = s.find(":", 0);
98 | string tmp = s.substr(pos2 + 2, pos - pos2 - 2);
99 | int num = atoi(tmp.c_str());
100 | s.erase(0, pos + 2);
101 | pos = s.find(",", 0);
102 | s.erase(0, pos + 2);
103 | pos = s.find(",", 0);
104 | string feature = s.substr(0, pos);
105 | s.erase(0, pos + 2);
106 | if (s.find("labevents", 0) == string::npos) continue;
107 | if (feature.find("abnormal", 0) != string::npos){
108 | featureMap.insert(make_pair(num, 1));
109 | } else if (feature.find("delta", 0) != string::npos){
110 | featureMap.insert(make_pair(num, 2));
111 | }else{
112 | featureMap.insert(make_pair(num, 0));
113 | }
114 | }
115 | }
116 | 
117 | int main(){
118 | getExt("chartevent_1.tsv", false);
119 | getExt("chartevent_2.tsv", false);
120 | getExt("chartevent_3.tsv", false);
121 | getExt("chartevent_4.tsv", false);
122 | getExt("chartevent_5.tsv", false);
123 | getExt("chartevent_6.tsv", false);
124 | getExt("chartevent_8.tsv", false);
125 | getExt("chartevent_9.tsv", false);
126 | getExt("chartevent_10.tsv", false);
127 | getExt("chartevent_11.tsv", false);
128 | getExt("chartevent_12.tsv", false);
129 | getExt("chartevent_13.tsv", false);
130 | getExt("chartevent_14.tsv", false);
131 | getExt("admissions.admittime.tsv", false);
132 | getExt("admissions.deathtime.tsv", false);
133 | getExt("admissions.dischtime.tsv", false);
134 | getExt("datetimeevents.tsv", false);
135 | getExt("icustays.tsv", false);
136 | getExt("inputevents_cv.tsv", false);
137 | getExt("inputevents_mv.tsv", false);
138 | getExt("labevents.tsv", true);
139 | getExt("outputevents.tsv", false);
140 | getExt("procedureevents.tsv", false);
141 | getApart("feature_info.tsv");
142 | ofstream f0("dataSeq.txt");
143 | map<int, multimap<string, Event> >::iterator i;
144 | for (i = patientMap.begin(); i != patientMap.end(); i++){
145 | int num = i->first;
146 | int sum = 0;
147 | multimap<string, Event>::iterator j;
148 | multimap<string, Event>::iterator start = patientMap[num].begin();
149 | //f0 << num << endl;
150 | for (j = patientMap[num].begin(); j != patientMap[num].end(); j++){
151 | sum++;
152 | }
153 | for (j = patientMap[num].begin(); j != patientMap[num].end(); j++){
154 | if (!(j->second).flag) continue;
155 | if (sum < 100) continue;
156 | multimap<string, Event>::iterator l;
157 | string sttime = getTime(j->first);
158 | int eventsum = 0;
159 | for (l = start; l != j; l++){
160 | if (l->first <= sttime){
161 | eventsum++;
162 | }else{
163 | break;
164 | }
165 | }
166 | if (eventsum < 100) continue;
167 | f0<<num<<'\t'<<eventsum<<'\t'<<1<<endl;
168 | for (l = start; eventsum > 0; l++, eventsum--){
169 | string time = l->first;
170 | Event e = l->second;
171 | f0<<(j->second).id<<' '<<(j->second).featureValue[0]<<' '<<featureMap[(j->second).id]<<endl;
--------------------------------------------------------------------------------
/task.py:
--------------------------------------------------------------------------------
0, dtype='int8')
26 | f.close()
27 | print 'done.'
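# mask is 1 for real timesteps and 0 for the zero-padded tail of each
# sequence (event id 0 is the padding event), letting HELSTMLayer carry the
# previous cell/hidden state across padded steps.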
28 | return event, feature_id, feature_value, label, times, mask 29 | 30 | def load_data_all(): 31 | global ICU_test_data 32 | global ICU_train_data 33 | global ICU_valid_data 34 | data_path = "/data/" 35 | file_name = "ICUIn_test_1000.h5" 36 | ICU_test_data = load_data(data_path, file_name) 37 | file_name = "ICUIn_train_1000.h5" 38 | ICU_train_data = load_data(data_path, file_name) 39 | file_name = "ICUIn_valid_1000.h5" 40 | ICU_valid_data = load_data(data_path, file_name) 41 | 42 | def get_data(set_name, kind): 43 | global ICU_test_data 44 | global ICU_train_data 45 | global ICU_valid_data 46 | if(kind=="test"): 47 | return ICU_test_data 48 | elif(kind=="train"): 49 | return ICU_train_data 50 | else: 51 | assert kind=="valid" 52 | return ICU_valid_data 53 | 54 | class ExponentialUniformInit(Initializer): 55 | """ 56 | """ 57 | def __init__(self, range): 58 | self.range = range 59 | 60 | def sample(self, shape): 61 | return floatX(np.exp(get_rng().uniform(low=self.range[0], 62 | high=self.range[1], size=shape))) 63 | 64 | def get_rnn(event_var, feature_idx, feature_value, mask_var, time_var, arch_size, num_attention = 0, embed_size=40, init_period = (1, 3), 65 | seq_len=1000, GRAD_CLIP=100, bn=False, model_type='LSTM'): 66 | 67 | #input layers 68 | l_in_event = lasagne.layers.InputLayer(shape=(None, seq_len), input_var = event_var) 69 | l_in_feature_idx = lasagne.layers.InputLayer(shape=(None, seq_len, 3), input_var = feature_idx) 70 | l_in_feature_value = lasagne.layers.InputLayer(shape=(None, seq_len, 3), input_var = feature_value) 71 | l_mask = lasagne.layers.InputLayer(shape=(None, seq_len), input_var=mask_var) 72 | l_t = lasagne.layers.InputLayer(shape=(None, seq_len), input_var=time_var) 73 | 74 | #embed event 75 | embed_event = lasagne.layers.EmbeddingLayer(l_in_event, input_size=3418, output_size=embed_size) 76 | #embed feature_idx 77 | embed_feature_idx = lasagne.layers.EmbeddingLayer(l_in_feature_idx, input_size=649, output_size=embed_size) 78 | #embed feature_value bias 79 | embed_feature_b = lasagne.layers.EmbeddingLayer(l_in_feature_idx, input_size=649, output_size = 1) 80 | #embed feature_value trans 81 | embed_feature_trans = lasagne.layers.EmbeddingLayer(l_in_feature_idx, input_size=649, output_size = 1) 82 | 83 | embed_params = [embed_event.W, embed_feature_idx.W, embed_feature_b.W, embed_feature_trans.W] 84 | 85 | #get input_var 86 | l_in_merge = MergeEmbeddingLayer(embed_event, embed_feature_idx, embed_feature_b, embed_feature_trans, l_in_feature_value) 87 | 88 | if model_type=="LSTM": 89 | l_in_merge = lasagne.layers.ConcatLayer([l_in_merge, lasagne.layers.ReshapeLayer(l_t, [-1, seq_len, 1])], axis=2) 90 | 91 | l_forward = HELSTMLayer(incoming=l_in_merge, time_input=l_t, event_input=embed_event, num_units=arch_size[1], 92 | num_attention=num_attention, model=model_type, mask_input=l_mask, 93 | ingate=Gate(), 94 | forgetgate=Gate(), 95 | cell=Gate(W_cell=None, nonlinearity=lasagne.nonlinearities.tanh), 96 | outgate=Gate(), 97 | nonlinearity=lasagne.nonlinearities.tanh, 98 | grad_clipping=GRAD_CLIP, 99 | bn=bn, 100 | only_return_final=True, 101 | timegate=HELSTMGate( 102 | Period=ExponentialUniformInit(init_period), 103 | Shift=lasagne.init.Uniform((0., 1000)), 104 | On_End=lasagne.init.Constant(0.05))) 105 | 106 | gate_params = [] 107 | if model_type != 'LSTM': 108 | gate_params = l_forward.get_gate_params() 109 | 110 | # Softmax 111 | l_dense = lasagne.layers.DenseLayer(l_forward, num_units=arch_size[2],nonlinearity=lasagne.nonlinearities.leaky_rectify) 112 | l_out = 
lasagne.layers.NonlinearityLayer(l_dense, nonlinearity=lasagne.nonlinearities.softmax) 113 | return l_out, gate_params, embed_params 114 | 115 | def get_train_and_val_fn(inputs, target_var, network): 116 | # Get network output 117 | prediction = lasagne.layers.get_output(network) 118 | # Calculate training accuracy 119 | train_acc = T.mean(T.eq(T.argmax(prediction, axis=1), target_var), dtype=theano.config.floatX) 120 | # Calculate crossentropy between predictions and targets 121 | loss = lasagne.objectives.categorical_crossentropy(prediction, target_var) 122 | loss = loss.mean() 123 | 124 | # Fetch trainable parameters 125 | params = lasagne.layers.get_all_params(network, trainable=True) 126 | # Calculate updates for the parameters given the loss 127 | updates = lasagne.updates.adam(loss, params, learning_rate=1e-3) 128 | 129 | # Fetch network output, using deterministic methods 130 | test_prediction = lasagne.layers.get_output(network, deterministic=True) 131 | # Again calculate crossentropy, this time using (test-time) determinstic pass 132 | test_loss = lasagne.objectives.categorical_crossentropy(test_prediction, target_var) 133 | test_loss = test_loss.mean() 134 | # Also, create an expression for the classification accuracy: 135 | test_acc = T.mean(T.eq(T.argmax(test_prediction, axis=1), target_var), dtype=theano.config.floatX) 136 | # Add in the targets to the function inputs 137 | fn_inputs = inputs + [target_var] 138 | # Compile a train function with the updates, returning loss and accuracy 139 | train_fn = theano.function(fn_inputs, [loss, train_acc, prediction], updates=updates) 140 | # Compile a second function computing the validation loss and accuracy: 141 | val_fn = theano.function(fn_inputs, [test_loss, test_acc, test_prediction]) 142 | 143 | return train_fn, val_fn 144 | 145 | def get_minibatches_idx(n, minibatch_size, shuffle=True): 146 | """ 147 | Used to shuffle the dataset at each iteration. 
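Returns a list of (batch_number, index_array) pairs that together cover
all n examples; with shuffle=False the split is deterministic, e.g.
get_minibatches_idx(5, 2, shuffle=False) ->
[(0, array([0, 1])), (1, array([2, 3])), (2, array([4]))]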
148 | """ 149 | 150 | idx_list = np.arange(n, dtype="int32") 151 | 152 | if shuffle: 153 | np.random.shuffle(idx_list) 154 | 155 | minibatches = [] 156 | minibatch_start = 0 157 | for i in range(n // minibatch_size): 158 | minibatches.append(idx_list[minibatch_start: 159 | minibatch_start + minibatch_size]) 160 | minibatch_start += minibatch_size 161 | 162 | if (minibatch_start != n): 163 | # Make a minibatch out of what is left 164 | minibatches.append(idx_list[minibatch_start:]) 165 | 166 | return zip(range(len(minibatches)), minibatches) 167 | 168 | def valid(train_times, valid_data, test_fn, valid_file): 169 | # valid the accuracy 170 | # And a full pass over the validation data: 171 | valid_err = 0 172 | valid_acc = 0 173 | valid_auc = 0 174 | valid_batches = 0 175 | y_true_all = [] 176 | y_score_all = [] 177 | 178 | valid_event, valid_feature_idx, valid_feature_value, valid_mask, valid_time, valid_label, num_valid, batch_size = valid_data 179 | valid_kf = get_minibatches_idx(num_valid, batch_size) 180 | num_valid_batches = len(valid_kf) 181 | for _, valid_batch in valid_kf: 182 | start_time = time.clock() 183 | b_event = valid_event[valid_batch] 184 | b_feature_idx = valid_feature_idx[valid_batch] 185 | b_feature_value = valid_feature_value[valid_batch] 186 | b_mask = valid_mask[valid_batch] 187 | b_t = valid_time[valid_batch] 188 | b_label = valid_label[valid_batch] 189 | err, acc, pre = test_fn(b_event, b_feature_idx, b_feature_value, b_t, b_mask, b_label) 190 | y_true = np.asarray(b_label) 191 | y_score = np.asarray(pre)[:,1] 192 | y_true_all += list(y_true) 193 | y_score_all += list(y_score) 194 | valid_err += err 195 | valid_acc += acc 196 | valid_batches += 1 197 | 198 | print("\tBatch {} of {} : Loss: {} | Accuracy: {} ".format(valid_batches, num_valid_batches, 199 | err, acc*100.)) 200 | print("Time:", (time.clock()-start_time)) 201 | valid_err /= valid_batches 202 | valid_acc = valid_acc*100./valid_batches 203 | y_true_all = np.asarray(y_true_all) 204 | y_score_all = np.asarray(y_score_all) 205 | auc_all = roc_auc_score(y_true_all, y_score_all) 206 | ap_all = average_precision_score(y_true_all, y_score_all) 207 | print("Valid loss:", valid_err) 208 | print("Valid acc:", valid_acc) 209 | valid_file.write("Train times:{} Loss:{} Acc:{} Auc:{} Prc:{}\n".format(train_times, valid_err, valid_acc, auc_all, ap_all)) 210 | 211 | def model(embed, hidden, attention, _period, model_type, data_set, name, seed): 212 | np.random.seed(seed) 213 | if model_type!="HELSTM": 214 | attention = 0 215 | prefix = data_set+"_" 216 | num_attention = attention 217 | arch_size = [None, hidden, 2] 218 | embed_size = embed 219 | max_epoch = 8 220 | batch_size = 128 221 | valid_freq = 100 222 | 223 | input_event = T.matrix('input_event', dtype='int16') 224 | input_feature_idx = T.tensor3('input_idx', dtype='int16') 225 | input_feature_value = T.tensor3('input_value', dtype='float32') 226 | input_time = T.matrix('input_time', dtype='float32') 227 | input_mask = T.matrix('input_mask', dtype='int8') 228 | input_target = T.ivector('input_target') 229 | 230 | print 'load test data' 231 | test_event, test_feature_idx, test_feature_value, test_label, test_time, test_mask = get_data(data_set, "test") 232 | num_test = len(test_event) 233 | #pack them all for further valid use 234 | test_data = (test_event, test_feature_idx, test_feature_value, test_mask, test_time, test_label, num_test, batch_size) 235 | 236 | print 'load train data' 237 | train_event, train_feature_idx, train_feature_value, train_label, 
train_time, train_mask = get_data(data_set, "train") 238 | num_train = len(train_event) 239 | #pack them all for further valid use 240 | train_data = (train_event, train_feature_idx, train_feature_value, train_mask, train_time, train_label, num_train, batch_size) 241 | 242 | print 'load valid data' 243 | valid_event, valid_feature_idx, valid_feature_value, valid_label, valid_time, valid_mask = get_data(data_set, "valid") 244 | num_valid = len(valid_event) 245 | #pack them all for further valid use 246 | valid_data = (valid_event, valid_feature_idx, valid_feature_value, valid_mask, valid_time, valid_label, num_valid, batch_size) 247 | 248 | 249 | print 'Build network' 250 | network, gate_params, embed_params = get_rnn(input_event, input_feature_idx, input_feature_value, input_mask, 251 | input_time, arch_size, num_attention = num_attention, embed_size = embed_size, init_period = _period, model_type = model_type) 252 | 253 | print 'Compile' 254 | train_fn, test_fn = get_train_and_val_fn([input_event, input_feature_idx, input_feature_value, input_time, input_mask], input_target, network) 255 | 256 | print 'Start training' 257 | train_file = open(name+"/"+prefix+"train.txt",'w') 258 | valid_file = open(name+"/"+prefix+"valid.txt",'w') 259 | test_file = open(name+"/"+prefix + "test.txt",'w') 260 | 261 | train_times = 0 262 | for epoch in xrange(max_epoch): 263 | print epoch 264 | train_err = 0 265 | train_acc = 0 266 | train_auc = 0 267 | train_y_true_all = [] 268 | train_y_score_all = [] 269 | train_batches = 0 270 | 271 | kf = get_minibatches_idx(num_train, batch_size) 272 | num_train_batches = len(kf) 273 | for _, train_batch in kf: 274 | train_times += 1 275 | start_time = time.clock() 276 | b_event = train_event[train_batch] 277 | b_feature_idx = train_feature_idx[train_batch] 278 | b_feature_value = train_feature_value[train_batch] 279 | b_mask = train_mask[train_batch] 280 | b_t = train_time[train_batch] 281 | b_label = train_label[train_batch] 282 | err, acc, pre = train_fn(b_event, b_feature_idx, b_feature_value, b_t, b_mask, b_label) 283 | dat = np.asarray(pre) 284 | dat_shape = dat.shape 285 | train_err += err 286 | train_acc += acc 287 | train_batches += 1 288 | 289 | print("\tBatch {} of {} in epoch {}: Loss: {} | Accuracy: {} ".format(train_batches, num_train_batches, 290 | epoch, err, acc*100. 
)) 291 | print("Time:", (time.clock()-start_time)) 292 | if(train_times%valid_freq == 0): 293 | valid(train_times, valid_data, test_fn, valid_file) 294 | valid(train_times, test_data, test_fn, test_file) 295 | 296 | print('Completed.') 297 | train_file.close() 298 | valid_file.close() 299 | test_file.close() 300 | 301 | def choose_model(embed, hidden, attention, period, model_type, name, seed): 302 | name = '{}-{}'.format(name, seed) 303 | os.mkdir(name) 304 | f = open(name+"/log.txt",'w') 305 | f.write("model:{} embed:{} hidden:{} attention:{} period:{} {} seed:{}\n".format(model_type, embed, hidden, attention, period[0], period[1], seed)) 306 | f.close() 307 | model(embed, hidden, attention, period, model_type, "ICU", name, seed) 308 | 309 | if __name__ == '__main__': 310 | load_data_all() 311 | choose_model(32, 64, 32, (1., 2.), "HELSTM", "exp_HELSTM", 1) 312 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/build_event.py: -------------------------------------------------------------------------------- 1 | from util import * 2 | from select_feature import load_value_type_text, TypeFeature 3 | from extractor import parse_line 4 | import glob 5 | import itertools 6 | 7 | class Feature(): 8 | def __init__(self, index, value): 9 | self.index = index 10 | self.value = value 11 | 12 | def __cmp__(self, other): 13 | return cmp(self.index, other.index) 14 | 15 | def __str__(self): 16 | return str(self.index) + ":" + str(self.value) 17 | 18 | @staticmethod 19 | def parse_features(string): 20 | features = [] 21 | string = string.strip() 22 | if string == "": 23 | return [] 24 | for pair in string.split(" "): 25 | index, value = pair.split(":") 26 | features.append(Feature(int(index), float(value))) 27 | return features 28 | 29 | 30 | class FeatureExtractor(): 31 | def __init__(self, feature_index, value_index): 32 | self.feature_index = feature_index 33 | self.value_index = value_index 34 | 35 | 36 | class TimeFeatureExtractor(FeatureExtractor): 37 | nerror = 0 38 | def extract(self, time, values, base): 39 | value = values[self.value_index] 40 | value = parse_time(value) 41 | if value is None: 42 | return None 43 | else: 44 | duration = (value - time).total_seconds()/3600.0 45 | if duration < 0: 46 | TimeFeatureExtractor.nerror += 1 47 | return None 48 | return Feature(base + self.feature_index, duration) 49 | 50 | 51 | class NumberFeatureExtractor(FeatureExtractor): 52 | def extract(self, time, values, base): 53 | value = values[self.value_index] 54 | value = parse_number(value) 55 | if value is None: 56 | return None 57 | else: 58 | return Feature(base + self.feature_index, value) 59 | 60 | class Event(): 61 | def __init__(self, event_idx, features, pid, time): 62 | self.index = event_idx 63 | self.features = features 64 | self.pid = pid 65 | self.time = time 66 | 67 | def is_valid(self): 68 | for feature in self.features: 69 | if feature is None: 70 | return False 71 | return True 72 | 73 | def __cmp__(self, other): 74 | ret = cmp(self.time, other.time) 75 | if ret != 0: 76 | return ret 77 | else: 78 | return cmp(self.index, other.index) 79 | 80 | def __str__(self): 81 | out = [self.index, self.pid, " ".join(map(str, self.features)), self.time] 82 | out = map(str, out) 83 | return "\t".join(out) 84 | 85 | 86 | 87 | @staticmethod 88 | def load_from_line(line): 89 | parts = line.strip().split('\t') 90 | index = int(parts[0]) 91 | pid = int(parts[1]) 92 | features = Feature.parse_features(parts[2]) 93 | time = parse_time(parts[3]) 94 | 
return Event(index, features, pid, time)
95 | 
96 | 
97 | class EventBuilder():
98 | sep = "|&|"
99 | def __init__(self, type_feature, text_map):
100 | self.type_feature = type_feature
101 | self.event_builder_init(text_map)
102 | self.value_builder_init(text_map)
103 | 
104 | def builder_des(self):
105 | des_list = []
106 | for event_fac in reversed(self.feature_texts):
107 | if len(event_fac) > 0:
108 | des_list.append(event_fac)
109 | if len(des_list) == 0:
110 | return [""]
111 | else:
112 | return [x for x in itertools.product(*des_list)]
113 | 
114 | 
115 | 
116 | def event_builder_init(self, text_map):
117 | self.event_factor = []
118 | self.feature_texts = []
119 | self.orders = []
120 | self.event_dim = 1
121 | for feature in self.type_feature.features:
122 | order = feature.order
123 | self.orders.append(order)
124 | value_feature_name = self.type_feature.rtype + "#" + str(order)
125 | if feature.main_type == "text":
126 | text = sorted(text_map[value_feature_name])
127 | self.event_factor.append(self.event_dim)
128 | self.event_dim *= len(text)
129 | self.feature_texts.append(text)
130 | else:
131 | self.event_factor.append(0)
132 | self.feature_texts.append([])
133 | 
134 | def value_builder_init(self, text_map):
135 | self.feature_dim = 0
136 | self.extractors = []
137 | for value_feature in self.type_feature.features:
138 | if value_feature.main_type == "time":
139 | self.extractors.append(TimeFeatureExtractor(self.feature_dim, value_feature.order))
140 | self.feature_dim += 1
141 | elif value_feature.main_type == "number":
142 | self.extractors.append(NumberFeatureExtractor(self.feature_dim, value_feature.order))
143 | self.feature_dim += 1
144 | 
145 | def set_event_base(self, event_base):
146 | self.event_base = event_base
147 | return self.event_base + self.event_dim
148 | 
149 | def set_feature_base(self, feature_base):
150 | self.feature_base = feature_base
151 | return self.feature_base + self.feature_dim
152 | 
153 | def build_event(self, time, values):
154 | values = [values[order] for order in self.orders]
155 | event = 0
156 | for i in range(len(values)):
157 | if self.event_factor[i] > 0:
158 | value = values[i]
159 | idx = self.feature_texts[i].index(value.strip().lower())
160 | event = event + idx * self.event_factor[i]
161 | return event + self.event_base
162 | 
163 | def build_features(self, time, values):
164 | features = []
165 | for extractor in self.extractors:
166 | features.append(extractor.extract(time, values, self.feature_base))
167 | return features
168 | 
169 | def build(self, row):
170 | parts = parse_line(row)
171 | pid = int(parts[0].split("_")[0])
172 | time = parse_time(parts[1])
173 | values = parts[3].split(EventBuilder.sep)
174 | event_idx = self.build_event(time, values)
175 | features = self.build_features(time, values)
176 | event = Event(event_idx, features, pid, time)
177 | if event.is_valid():
178 | return event
179 | else:
180 | return None
181 | 
182 | def load_type_features(filepath):
183 | s = ""
184 | type_features = []
185 | for line in file(filepath):
186 | if line.startswith("\t"):
187 | s = s + line
188 | else:
189 | if s != "":
190 | s = s.rstrip('\n')
191 | type_features.append(TypeFeature.load_from_str(s))
192 | s = line
193 | if s != "":
194 | s = s.rstrip('\n')
195 | type_features.append(TypeFeature.load_from_str(s))
196 | return type_features
197 | 
198 | def gen_builders(type_features, text_map, build_des_file, feature_des_file):
199 | '''
200 | event: event 0 is padding, event 1 is interval
201 | feature: feature 0 is the duration-time 
feature 202 | ''' 203 | builders = {} 204 | event_dim = 2 205 | feature_dim = 1 206 | f = file(build_des_file, 'w') 207 | fea_f = file(feature_des_file, 'w') 208 | for type_feature in type_features: 209 | builder = EventBuilder(type_feature, text_map) 210 | l = event_dim 211 | event_dim = builder.set_event_base(event_dim) 212 | r = event_dim 213 | event_des = builder.builder_des() 214 | for i in range(l, r): 215 | f.write("%d %s %s\n" %(i, type_feature.rtype, "\t".join(event_des[i - l]))) 216 | l = feature_dim 217 | feature_dim = builder.set_feature_base(feature_dim) 218 | r = feature_dim 219 | if r > l: 220 | fea_f.write("%s\t%s\n" %(type_feature.rtype, "\t".join(map(str, range(l,r))))) 221 | builders[type_feature.rtype] = builder 222 | f.close() 223 | fea_f.close() 224 | print "event_dim =", event_dim 225 | print "feature_dim =", feature_dim 226 | return builders 227 | 228 | def build_event(filepath, builders): 229 | name = ".".join(os.path.basename(filepath).split(".")[:-1]) 230 | print "build event =", name 231 | outf = file(os.path.join(event_dir, name + ".tsv"), 'w') 232 | for line in file(filepath): 233 | line = line.strip() 234 | parts = parse_line(line) 235 | rtype = parts[2] 236 | if rtype in builders: 237 | event = builders[rtype].build(line) 238 | if event is not None: 239 | outf.write(str(event)+"\n") 240 | outf.close() 241 | 242 | def print_event(builders, filepath): 243 | f = file(filepath, 'w') 244 | builders = sorted(builders.values(), key = lambda x:x.event_base) 245 | for builder in builders: 246 | st = builder.event_base 247 | ed = builder.event_base + builder.event_dim - 1 248 | f.write(str(st) + "-" + str(ed) + '\t') 249 | f.write(builder.type_feature.rtype) 250 | f.write("\n") 251 | f.close() 252 | 253 | 254 | if __name__ == '__main__': 255 | text_map = load_value_type_text(os.path.join(result_dir, "value_type_text.tsv")) 256 | type_features = load_type_features(os.path.join(result_dir, 'selected_features.tsv')) 257 | event_des_file = os.path.join(result_dir, "event_des_text.tsv") 258 | feature_des_file = os.path.join(result_dir, "feature_des.tsv") 259 | builders = gen_builders(type_features, text_map, event_des_file, feature_des_file) 260 | if not os.path.exists(event_dir): 261 | os.mkdir(event_dir) 262 | for filepath in glob.glob(data_dir + "/*tsv"): 263 | name = os.path.basename(filepath) 264 | # if name in ["labevents.tsv", "datetimeevents.tsv"]: 265 | # if filepath.find("labevents") == -1: 266 | # continue 267 | build_event(filepath, builders) 268 | print "#TimeDuration < 0 error =", TimeFeatureExtractor.nerror 269 | 270 | # print_event(builders, "static_data/event_des.txt") 271 | # print builders['labevents.51294'].feature_texts 272 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/config/blacklist.reg: -------------------------------------------------------------------------------- 1 | # non-time feature 2 | ^diagnoses_icd* 3 | ^patients* 4 | ^procedures_icd* 5 | ^cptevents* 6 | 7 | # other 8 | ^datetimeevents.226515$ 9 | ^datetimeevents.226724$ 10 | ^microbioevents* 11 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/dataExt.py: -------------------------------------------------------------------------------- 1 | import io 2 | import h5py 3 | import sys 4 | import time 5 | import numpy as np 6 | 7 | def getPast(st): 8 | pos = st.find("-") 9 | pos2 = st.find("-", pos + 1) 10 | pos3 = st.find(" ", pos2 + 1) 11 | pos4 = st.find(":", pos3 + 1) 12 | 
pos5 = st.find(":", pos4 + 1) 13 | year = int(st[0:pos]) 14 | month = int(st[pos + 1:pos2]) 15 | day = int(st[pos2 + 1:pos3]) 16 | hour = int(st[pos3 + 1:pos4]) 17 | minute = int(st[pos4 + 1:pos5]) 18 | second = int(st[pos5 + 1:len(st)]) 19 | hour = hour - 12 20 | if (hour < 0): 21 | hour = hour + 24 22 | day = day - 1 23 | if (day <= 0): 24 | month = month - 1 25 | day = day + 30 26 | if (month in [0, 1, 3, 5, 7, 8, 10]): 27 | day = day + 1 28 | if (month == 2): 29 | day = day - 2 30 | if (month <= 0): 31 | month = month + 12 32 | year = year - 1 33 | return str(year) + "-" + str(month) + "-" + str(day) + " " + str(hour) + ":" + str(minute) + ":" + str(second) 34 | 35 | f = h5py.File('Lab.h5', 'w') 36 | f2 = open('dataSeq.txt', 'r') 37 | sys.stdin = f2 38 | print 'read start' 39 | n = input() 40 | dt = [["" for j in range(5000)] for i in range(n)] 41 | evt = [[0 for j in range(5000)] for i in range(n)] 42 | ftr = [[[0.0 for k in range(6)] for j in range(5000)] for i in range(n)] 43 | lbl = [[[0.0 for k in range(5)] for j in range(3682)] for i in range(n)] 44 | pt = [] 45 | pNum = 0 46 | for samples in range(n): 47 | print str(samples) + "/" + str(n) 48 | st = f2.readline() 49 | pos = st.find("\t") 50 | if (pos == -1): 51 | break; 52 | patient = int(st[0:pos]) 53 | st = st[pos + 1:len(st)] 54 | pos = st.find("\t") 55 | eventNum = int(st[0:pos]) 56 | labelNum = int(st[pos + 1:-1]) 57 | #print eventNum 58 | eNum = 0 59 | lNum = 0 60 | startDate = 0 61 | for i in range(eventNum): 62 | st = f2.readline() 63 | pos = st.find("\t") 64 | pos2 = st.find("\t", pos + 1) 65 | #print pos 66 | #print st 67 | labelFlag = int(st[0:pos]) 68 | event = int(st[pos + 1:pos2]) 69 | date = st[pos2 + 1:-1] 70 | dt[pNum][eNum] = date 71 | evt[pNum][eNum] = event 72 | ctg = 0 73 | if (labelFlag == 1): 74 | ctg = input() 75 | featureNum = input() 76 | for j in range(featureNum): 77 | st = f2.readline() 78 | pos = st.find("\t") 79 | ftr[pNum][eNum][j * 2] = float(st[0:pos]) 80 | ftr[pNum][eNum][j * 2 + 1] = float(st[pos + 1:-1]) 81 | while (featureNum < 3): 82 | ftr[pNum][eNum][featureNum * 2] = 0 83 | ftr[pNum][eNum][featureNum * 2 + 1] = 0 84 | featureNum = featureNum + 1 85 | if (labelFlag == 1): 86 | lbl[pNum][lNum][0] = evt[pNum][eNum] 87 | lbl[pNum][lNum][1] = ftr[pNum][eNum][1] 88 | lbl[pNum][lNum][2] = ctg 89 | pastTime = getPast(date) 90 | while (dt[pNum][startDate] < pastTime and startDate < eNum): 91 | startDate = startDate + 1 92 | startDate = startDate - 1 93 | lbl[pNum][lNum][3] = startDate 94 | lbl[pNum][lNum][4] = startDate - 1000 + 1 95 | if (lbl[pNum][lNum][4] < 0): 96 | lbl[pNum][lNum][4] = 0 97 | lNum = lNum + 1 98 | eNum = eNum + 1 99 | pt.append(patient) 100 | pNum = pNum + 1 101 | 102 | print "read done" 103 | #print seqn 104 | #grp.create_dataset("row_id", data = rowid) 105 | #grp.create_dataset("subject_id", data = subid) 106 | f.create_dataset("patient", data = pt) 107 | f.create_dataset("event", data = evt) 108 | f.create_dataset("time", data = dt) 109 | #f.create_dataset("event_catAtt", data = atr) 110 | f.create_dataset("feature", data = ftr) 111 | f.create_dataset("label", data = lbl) 112 | 113 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/event_des.py: -------------------------------------------------------------------------------- 1 | from util import * 2 | import json 3 | 4 | def load_event_des_pattern(): 5 | event_des = {} 6 | labevents_items = load_items(os.path.join(static_data_dir, 'labitem_code.tsv')) 7 | d_items = 
load_items(os.path.join(static_data_dir, 'item_code.tsv')) 8 | rtype = None 9 | for line in file(os.path.join(result_dir, "selected_features.tsv"), 'r'): 10 | line = line.rstrip() 11 | if line.startswith("\t"): 12 | line = line.lstrip('\t') 13 | p = line.split(" ") 14 | feature_type = p[1] 15 | if feature_type == "text": 16 | event_des[rtype]['feature'].append("text") 17 | else: 18 | event_des[rtype]['feature'].append('feature') 19 | else: 20 | rtype = line.split("#")[0] 21 | p = rtype.split(".") 22 | table = p[0] 23 | des = table 24 | if len(p) > 1: 25 | item_id = p[1] 26 | if is_number(p[1]): 27 | if table == "labevents": 28 | des = labevents_items[int(item_id)] 29 | else: 30 | des = d_items[int(item_id)] 31 | event_des[rtype] = {"des":des, "feature":[]} 32 | return event_des 33 | 34 | 35 | def column_name_map(): 36 | return { 37 | "admissions.admit": ["admission_type"], 38 | "admissions.disch": ["FLAG"], 39 | "admissions.death": ["FLAG"], 40 | "icustays": ["outtime"], 41 | "labevents": ["value", "flag"], 42 | "chartevents": ['value'], 43 | 'inputevents_cv': ['amount', 'rate'], 44 | "inputevents_mv": ['endtime', 'amount', 'rate'], 45 | "outputevents": ['value'], 46 | "procedureevents_mv": ['endtime', 'value'], 47 | "datetimeevents": ['Flag'] 48 | 49 | } 50 | 51 | def get_event_text_map(): 52 | ret = {} 53 | for line in file(os.path.join(result_dir, "event_des_text.tsv")): 54 | parts = line.strip("\n").split(" ") 55 | event_id = int(parts[0]) 56 | event_type = parts[1] 57 | value = " ".join(parts[2:]).split("\t") 58 | ret[event_id] = value 59 | return ret 60 | 61 | def get_id2rtype(): 62 | ret = {} 63 | for line in file(os.path.join(result_dir, "event_des_text.tsv")): 64 | parts = line.strip("\n").split(" ") 65 | event_id = int(parts[0]) 66 | event_type = parts[1] 67 | value = " ".join(parts[2:]).split("\t") 68 | ret[event_id] = event_type 69 | return ret 70 | 71 | def load_event_featureidx_map(): 72 | ret ={} 73 | for line in file(os.path.join(result_dir, 'feature_des.tsv')): 74 | parts = line.strip().split('\t') 75 | rtype = parts[0] 76 | ret[rtype] = [] 77 | for feature_idx in parts[1:]: 78 | ret[rtype].append(int(feature_idx)) 79 | return ret 80 | 81 | def get_feature(feature, idx): 82 | return feature[idx] 83 | 84 | class EventDescription: 85 | def __init__(self): 86 | self.column_map = column_name_map() 87 | self.event_text_map = get_event_text_map() 88 | self.event_des_pattern = load_event_des_pattern() 89 | self.id2rtype = get_id2rtype() 90 | self.event_featureidx_map = load_event_featureidx_map() 91 | 92 | def get_name(self, rtype): 93 | if rtype in self.column_map: 94 | return self.column_map[rtype] 95 | table = rtype.split('.')[0] 96 | return self.column_map[table] 97 | 98 | def reverse_text_feature_name(self, names, feature_types): 99 | f_names = [] 100 | text_names = [] 101 | for name, f_type in zip(names, feature_types): 102 | if f_type == "text": 103 | text_names.append(name) 104 | else: 105 | f_names.append(name) 106 | text_names.reverse() 107 | t_idx = 0 108 | f_idx = 0 109 | new_names = [] 110 | for f_type in feature_types: 111 | if f_type == "text": 112 | new_names.append(text_names[t_idx]) 113 | t_idx += 1 114 | else: 115 | new_names.append(f_names[f_idx]) 116 | f_idx += 1 117 | return new_names 118 | 119 | def get_des(self, event_id, feature_pair): 120 | text_features = self.event_text_map[event_id] 121 | rtype = self.id2rtype[event_id] 122 | names = self.get_name(rtype) 123 | 124 | num_feature_idx = self.event_featureidx_map.get(rtype, []) 125 | pattern = 
self.event_des_pattern[rtype] 126 | names = self.reverse_text_feature_name(names, pattern['feature']) 127 | text_idx = 0 128 | num_idx = 0 129 | features = [] 130 | ret = ["event = " + pattern['des'], '{'] 131 | for feature_type in pattern['feature']: 132 | if feature_type == "text": 133 | features.append(text_features[text_idx]) 134 | text_idx += 1 135 | else: 136 | idx = num_feature_idx[num_idx] 137 | num_idx += 1 138 | features.append(get_feature(feature_pair, idx)) 139 | for idx, feature in enumerate(features): 140 | name = names[idx] 141 | ret.append(name + " = " + str(feature)) 142 | ret.append('}') 143 | return ret 144 | 145 | def write_feature_info(out_path): 146 | event_des = EventDescription() 147 | outf = file(out_path, 'w') 148 | 149 | for event_id in range(2, max(event_des.id2rtype.keys()) + 1): 150 | rtype = event_des.id2rtype[event_id] 151 | 152 | names = event_des.get_name(rtype) 153 | 154 | obj = { 155 | "event_id": event_id, 156 | "rtype": rtype, 157 | 'text_feature': [], 158 | "feature": [], 159 | } 160 | 161 | num_feature_idx = event_des.event_featureidx_map.get(rtype, []) 162 | pattern = event_des.event_des_pattern[rtype] 163 | names = event_des.reverse_text_feature_name(names, pattern['feature']) 164 | text_features = event_des.event_text_map[event_id] 165 | text_idx = 0 166 | num_idx = 0 167 | for name, feature_type in zip(names, pattern['feature']): 168 | if feature_type == 'text': 169 | value = text_features[text_idx] 170 | text_idx += 1 171 | obj['text_feature'].append("%s=%s" %(name, value)) 172 | else: 173 | index = num_feature_idx[num_idx] 174 | num_idx += 1 175 | obj['feature'].append("%s at %d" %(name, index))\ 176 | 177 | outf.write('%s\n' %(json.dumps(obj)) ) 178 | 179 | outf.close() 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | if __name__ == "__main__": 188 | write_feature_info(os.path.join(result_dir, 'feature_info.tsv')) 189 | # event_des = EventDescription() 190 | # feas = np.zeros(649) 191 | # feas[233] = 50 192 | # feas[234] = 30 193 | # print "\n".join(event_des.get_des(1726, feas)) 194 | 195 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/extract.py: -------------------------------------------------------------------------------- 1 | from util import connect, date2str, time2str, data_dir 2 | import os 3 | import sys 4 | from datetime import timedelta 5 | from extractor import * 6 | 7 | true_test_extractor = ConstantExtractor(True) 8 | exist_test = lambda x: x is not None and x != '' 9 | 10 | def get_patients_extractors(data_dir, extractor_map): 11 | extractors = [] 12 | table = "patients" 13 | 14 | id_extractor = MultiExtractor(names = ['subject_id']) 15 | dob_time_extractor = TimeExtractor(name = 'dob', converter = date2str) 16 | dob_type_extractor = FmtExtractor(names = [], fmt = "patients.dob") 17 | dob_value_extractor = ConstantExtractor("FLAG") 18 | dob_test = lambda x: x.year >= 1900 19 | dob_test_extractor = TestExtractor(name = 'dob', test = dob_test) 20 | dob_extractor = ExtractorInfo(table, os.path.join(data_dir, table + '.dob.tsv'), id_extractor, 21 | dob_time_extractor, dob_type_extractor, 22 | dob_value_extractor, dob_test_extractor) 23 | extractors.append(dob_extractor) 24 | 25 | dod_out_path = os.path.join(data_dir, table + ".dod.tsv") 26 | dod_type_extractor = FmtExtractor(names = [], fmt = "patients.dod") 27 | dod_time_extractor = TimeExtractor(name = 'dod', converter = date2str) 28 | dod_value_extractor = ConstantExtractor("FLAG") 29 | dod_test = exist_test 30 | 
dod_test_extractor = TestExtractor(name = 'dod', test = dod_test) 31 | dod_extractor = ExtractorInfo(table, dod_out_path, id_extractor, 32 | dod_time_extractor, dod_type_extractor, 33 | dod_value_extractor, dod_test_extractor) 34 | extractors.append(dod_extractor) 35 | extractor_map[table] = extractors 36 | return table 37 | 38 | def get_admissions_extractors(data_dir, extractor_map): 39 | extractors = [] 40 | table = "admissions" 41 | 42 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id'], sep = "_") 43 | admit_outpath = os.path.join(data_dir, table + ".admittime.tsv") 44 | admit_time_ext = TimeExtractor(name = 'admittime', converter = time2str) 45 | admit_type_ext = FmtExtractor(names = [], fmt = "admissions.admit") 46 | admmit_value_ext = MultiExtractor(names = ['admission_type']) 47 | admit_test_ext = TestExtractor(name = 'admission_type', test = exist_test) 48 | admit_extractor = ExtractorInfo(table, admit_outpath, id_extractor, 49 | admit_time_ext, admit_type_ext, 50 | admmit_value_ext, admit_test_ext) 51 | extractors.append(admit_extractor) 52 | 53 | disch_outpath = os.path.join(data_dir, table + ".dischtime.tsv") 54 | disch_time_ext = TimeExtractor(name = "dischtime", converter = time2str) 55 | disch_type_ext = FmtExtractor(names = [], fmt = "admissions.disch") 56 | disch_value_ext = ConstantExtractor("FLAG") 57 | disch_test_ext = true_test_extractor 58 | disch_extractor = ExtractorInfo(table, disch_outpath, id_extractor, 59 | disch_time_ext, disch_type_ext, 60 | disch_value_ext, disch_test_ext) 61 | extractors.append(disch_extractor) 62 | 63 | death_outpath = os.path.join(data_dir, table + ".deathtime.tsv") 64 | death_time_ext = TimeExtractor(name = 'deathtime', converter = time2str) 65 | death_type_ext = FmtExtractor(names = [], fmt = "admissions.death") 66 | death_value_ext = ConstantExtractor("FLAG") 67 | death_test_ext = TestExtractor(name = 'deathtime', test = exist_test) 68 | death_extractor = ExtractorInfo(table, death_outpath, id_extractor, 69 | death_time_ext, death_type_ext, 70 | death_value_ext, death_test_ext) 71 | extractors.append(death_extractor) 72 | 73 | extractor_map[table] = extractors 74 | return table 75 | 76 | def get_icustays_extractors(data_dir, extractor_map): 77 | extractors = [] 78 | table = 'icustays' 79 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id', 'icustay_id'], sep = "_") 80 | 81 | intime_outpath = os.path.join(data_dir, table + '.tsv') 82 | intime_time_ext = TimeExtractor(name = 'intime', converter = time2str) 83 | intime_type_ext = FmtExtractor(names = [], fmt = 'icustays') 84 | intime_value_ext = MultiExtractor(names = ['outtime']) 85 | intime_test_ext = TestExtractor(name = 'outtime', test = exist_test) 86 | intime_extractor = ExtractorInfo(table, intime_outpath, id_extractor, 87 | intime_time_ext, intime_type_ext, 88 | intime_value_ext, intime_test_ext) 89 | extractors.append(intime_extractor) 90 | extractor_map[table] = extractors 91 | return table 92 | 93 | 94 | # def get_callout_extractors(data_dir, extractor_map): 95 | # extractors = [] 96 | # table = 'callout' 97 | # id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id'], sep = "_") 98 | 99 | # create_outpath = os.path.join(data_dir, table + '.createtime.tsv') 100 | # create_time_ext = TimeExtractor(name = 'createtime', converter = time2str) 101 | # create_type_ext = FmtExtractor(names = [], fmt = 'callout.createtime') 102 | # create_value_ext = ConstantExtractor(1) 103 | # create_test_ext = true_test_extractor 104 | # create_extractor = 
ExtractorInfo(table, create_outpath, id_extractor, 105 | # create_time_ext, create_type_ext, 106 | # create_value_ext, create_test_ext) 107 | # extractors.append(create_extractor) 108 | 109 | # update_outpath = os.path.join(data_dir, table + ".updatetime.tsv") 110 | # update_time_ext = TimeExtractor(name = 'updatetime', converter = time2str) 111 | # update_type_ext = FmtExtractor(names = [], fmt = 'callout.updatetime') 112 | # update_value_ext = ConstantExtractor(1) 113 | # update_test_ext = true_test_extractor 114 | # update_extractor = ExtractorInfo(table, update_outpath, id_extractor, 115 | # update_time_ext, update_type_ext, 116 | # update_value_ext, update_test_ext) 117 | # extractors.append(update_extractor) 118 | 119 | # ack_outpath = os.path.join(data_dir, table + ".acknowledgetime.tsv") 120 | # ack_time_ext = TimeExtractor(name = 'acknowledgetime', converter = time2str) 121 | # ack_type_ext = FmtExtractor(names = [], fmt = 'callout.acknowledge') 122 | # ack_value_ext = MultiExtractor(names = ['acknowledge_status']) 123 | # ack_test_ext = TestExtractor(name = 'acknowledgetime', test = exist_test) 124 | # # ack_test_ext = TestExtractor(name = 'acknowledge') 125 | # ack_extractor = ExtractorInfo(table, ack_outpath, id_extractor, 126 | # ack_time_ext, ack_type_ext, 127 | # ack_value_ext, ack_test_ext) 128 | # extractors.append(ack_extractor) 129 | 130 | # outcome_outpath = os.path.join(data_dir, table + '.outcome.tsv') 131 | # outcome_time_ext = TimeExtractor(name = 'outcometime', converter = time2str) 132 | # outcome_type_ext = FmtExtractor(names = [], fmt = 'callout.outcome') 133 | # outcome_value_ext = MultiExtractor(names = ['callout_outcome']) 134 | # outcome_test_ext = true_test_extractor 135 | # outcome_extractor = ExtractorInfo(table, outcome_outpath, id_extractor, 136 | # outcome_time_ext, outcome_type_ext, 137 | # outcome_value_ext, outcome_test_ext) 138 | # extractors.append(outcome_extractor) 139 | 140 | # extractor_map[table] = extractors 141 | # return table 142 | 143 | 144 | def get_labevents_extractors(data_dir, extractor_map): 145 | extractors = [] 146 | table = 'labevents' 147 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id'], sep = "_") 148 | 149 | outpath = os.path.join(data_dir, table + '.tsv') 150 | time_ext = TimeExtractor(name = 'charttime', converter = time2str) 151 | type_ext = FmtExtractor(names = ['itemid'], fmt = 'labevents.%s') 152 | # value unit flag 153 | # value_ext = MultiExtractor(names = ['value', 'valueuom', 'flag']) 154 | value_ext = MultiExtractor(names = ['value', 'flag']) 155 | test_ext = TestExtractor(name = "value", test = exist_test) 156 | extractor = ExtractorInfo(table, outpath, id_extractor, 157 | time_ext, type_ext, 158 | value_ext, test_ext) 159 | extractors.append(extractor) 160 | 161 | extractor_map[table] = extractors 162 | return table 163 | 164 | def get_microbiologyevents_extractors(data_dir, extractor_map): 165 | extractors = [] 166 | table = 'microbiologyevents' 167 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id'], sep = "_") 168 | 169 | outpath = os.path.join(data_dir, table + '.tsv') 170 | charttime_ext = TimeExtractor(name = 'charttime', converter = time2str) 171 | chartdate_ext = TimeExtractor(name = 'chartdate', converter = date2str) 172 | time_ext = SelectExtractor([charttime_ext, chartdate_ext]) 173 | # specimen organ ab 174 | type_ext = FmtExtractor(names = ['spec_itemid', 'org_itemid', 'ab_itemid'], fmt = 'microbioevents.%s&%s&%s') 175 | # text comp value inter 176 | value_ext = 
MultiExtractor(names = ['dilution_text', 'dilution_comparison', 'dilution_value', 'interpretation']) 177 | test_ext = TestExtractor(name = 'interpretation', test = exist_test) 178 | extractor = ExtractorInfo(table, outpath, id_extractor, 179 | time_ext, type_ext, 180 | value_ext, test_ext) 181 | extractors.append(extractor) 182 | 183 | extractor_map[table] = extractors 184 | return table 185 | 186 | def get_outputevents_extractors(data_dir, extractor_map): 187 | extractors = [] 188 | table = 'outputevents' 189 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id', 'icustay_id'], sep = "_") 190 | 191 | outpath = os.path.join(data_dir, table + '.tsv') 192 | time_ext = TimeExtractor(name = 'charttime', converter = time2str) 193 | type_ext = FmtExtractor(names = ['itemid'], fmt = 'outputevents.%s') 194 | # value_ext = MultiExtractor(names = ['value', 'valueuom']) 195 | value_ext = MultiExtractor(names = ['value']) 196 | test_ext = TestExtractor(name = "value", test = exist_test) 197 | extractor = ExtractorInfo(table, outpath, id_extractor, 198 | time_ext, type_ext, 199 | value_ext, test_ext) 200 | extractors.append(extractor) 201 | 202 | extractor_map[table] = extractors 203 | return table 204 | 205 | def get_diagnoses_extractors(data_dir, extractor_map): 206 | extractors = [] 207 | table = 'diagnoses_icd' 208 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id'], sep = "_") 209 | 210 | outpath = os.path.join(data_dir, table + ".tsv") 211 | time_ext = ConstantExtractor(None) 212 | type_ext = FmtExtractor(names = [], fmt = 'diagnoses_icd') 213 | value_ext = MultiExtractor(names = ['seq_num', 'icd9_code']) 214 | test_ext = true_test_extractor 215 | extractor = ExtractorInfo(table, outpath, id_extractor, 216 | time_ext, type_ext, 217 | value_ext, test_ext) 218 | extractors.append(extractor) 219 | 220 | extractor_map[table] = extractors 221 | return table 222 | 223 | def get_prescriptions_extractors(data_dir, extractor_map): 224 | extractors = [] 225 | table = 'prescriptions' 226 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id', 'icustay_id'], sep = "_") 227 | 228 | st_outpath = os.path.join(data_dir, table + '.tsv') 229 | sttime_ext = TimeExtractor(name = 'startdate', converter = date2str) 230 | st_type_ext = FmtExtractor(names = [], fmt = 'prescriptions') 231 | value_ext = MultiExtractor(names = ['enddate', 'formulary_drug_cd']) 232 | st_test_ext = TestExtractors(names = ['startdate', 'enddate', 'formulary_drug_cd'], 233 | test = exist_test) 234 | st_extractor = ExtractorInfo(table, st_outpath, id_extractor, 235 | sttime_ext, st_type_ext, 236 | value_ext, st_test_ext) 237 | extractors.append(st_extractor) 238 | 239 | extractor_map[table] = extractors 240 | return table 241 | 242 | def get_datetimeevents_extractors(data_dir, extractor_map): 243 | extractors = [] 244 | table = 'datetimeevents' 245 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id', 'icustay_id'], sep = "_") 246 | 247 | outpath = os.path.join(data_dir, table + '.tsv') 248 | time_ext = TimeExtractor(name = 'charttime', converter = time2str) 249 | type_ext = FmtExtractor(names = ['itemid'], fmt = 'datetimeevents.%s') 250 | value_ext = ConstantExtractor("Flag") 251 | test_ext = TestExtractor(name = 'charttime', test = exist_test) 252 | extractor = ExtractorInfo(table, outpath, id_extractor, 253 | time_ext, type_ext, 254 | value_ext, test_ext) 255 | extractors.append(extractor) 256 | 257 | extractor_map[table] = extractors 258 | return table 259 | 260 | def 
get_chartevents_extractors(data_dir, extractor_map): 261 | extractors = [] 262 | table = 'chartevents' 263 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id', 'icustay_id'], sep = '_') 264 | 265 | outpath = os.path.join(data_dir, table + ".tsv") 266 | time_ext = TimeExtractor(name = 'charttime', converter = time2str) 267 | type_ext = FmtExtractor(names = ['itemid'], fmt = 'chartevents.%s') 268 | value_ext = MultiExtractor(names = ['value']) 269 | test_ext = TestExtractor(name = 'value', test = exist_test) 270 | tables = [] 271 | for i in range(1, 15): 272 | sub_table = table + "_" + str(i) 273 | sub_outpath = os.path.join(data_dir, "chartevents_%d.tsv" %(i)) 274 | extractor = ExtractorInfo(sub_table, sub_outpath, id_extractor, 275 | time_ext, type_ext, 276 | value_ext, test_ext) 277 | extractor_map[sub_table] = [extractor] 278 | tables.append(sub_table) 279 | 280 | return tables 281 | 282 | def get_proceduresicd_extractors(data_dir, extractor_map): 283 | extractors = [] 284 | table = 'procedures_icd' 285 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id'], sep = "_") 286 | 287 | outpath = os.path.join(data_dir, table + ".tsv") 288 | time_ext = ConstantExtractor(None) 289 | type_ext = FmtExtractor(names = [], fmt = 'procedures_icd') 290 | value_ext = MultiExtractor(names = ['seq_num', 'icd9_code']) 291 | test_ext = true_test_extractor 292 | extractor = ExtractorInfo(table, outpath, id_extractor, 293 | time_ext, type_ext, 294 | value_ext, test_ext) 295 | extractors.append(extractor) 296 | 297 | extractor_map[table] = extractors 298 | return table 299 | 300 | def get_procedureevents_extractors(data_dir, extractor_map): 301 | extractors = [] 302 | table = "procedureevents_mv" 303 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id', 'icustay_id'], sep = "_") 304 | 305 | st_outpath = os.path.join(data_dir, table + ".tsv") 306 | st_time_ext = TimeExtractor(name = 'starttime', converter = time2str) 307 | st_type_ext = FmtExtractor(names = ['itemid'], fmt = 'procedureevents_mv.%s') 308 | # value_ext = MultiExtractor(names = ['endtime', 'value', 'valueuom']) 309 | value_ext = MultiExtractor(names = ['endtime', 'value']) 310 | test_ext = TestExtractors(names = ['endtime', 'value'], test = exist_test) 311 | 312 | st_extractor = ExtractorInfo(table, st_outpath, id_extractor, 313 | st_time_ext, st_type_ext, 314 | value_ext, test_ext) 315 | extractors.append(st_extractor) 316 | 317 | extractor_map[table] = extractors 318 | return table 319 | 320 | def get_inputevents_cv_extractors(data_dir, extractor_map): 321 | extractors = [] 322 | table = 'inputevents_cv' 323 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id', 'icustay_id'], sep = '_') 324 | 325 | outpath = os.path.join(data_dir, table + ".tsv") 326 | time_ext = TimeExtractor(name = 'charttime', converter = time2str) 327 | type_ext = FmtExtractor(names = ['itemid'], fmt = 'inputevents_cv.%s') 328 | # value_ext = MultiExtractor(names = ['amount', 'amountuom', 'rate', 'rateuom']) 329 | value_ext = MultiExtractor(names = ['amount', 'rate']) 330 | test_ext = TestExtractor(name = 'amount', test = exist_test) 331 | extractor = ExtractorInfo(table, outpath, id_extractor, 332 | time_ext, type_ext, 333 | value_ext, test_ext) 334 | extractors.append(extractor) 335 | 336 | extractor_map[table] = extractors 337 | return table 338 | 339 | def get_inputevents_mv_extractors(data_dir, extractor_map): 340 | extractors = [] 341 | table = 'inputevents_mv' 342 | id_extractor = MultiExtractor(names = ['subject_id', 
'hadm_id', 'icustay_id'], sep = '_') 343 | 344 | st_outpath = os.path.join(data_dir, table + ".tsv") 345 | st_time_ext = TimeExtractor(name = 'starttime', converter = time2str) 346 | st_type_ext = FmtExtractor(names = ['itemid'], fmt = 'inputevents_mv.%s') 347 | # value_ext = MultiExtractor(names = ['endtime', 'amount', 'amountuom', 'rate', 'rateuom']) 348 | value_ext = MultiExtractor(names = ['endtime', 'amount', 'rate']) 349 | test_ext = TestExtractors(names = ['endtime', 'amount'], test = exist_test) 350 | 351 | st_extractor = ExtractorInfo(table, st_outpath, id_extractor, 352 | st_time_ext, st_type_ext, 353 | value_ext, test_ext) 354 | extractors.append(st_extractor) 355 | 356 | extractor_map[table] = extractors; 357 | return table 358 | 359 | def get_cptevents_extractors(data_dir, extractor_map): 360 | extractors = [] 361 | table = 'cptevents' 362 | id_extractor = MultiExtractor(names = ['subject_id', 'hadm_id'], sep = "_") 363 | 364 | outpath = os.path.join(data_dir, table + '.tsv') 365 | time_ext = MultiExtractor(names = ['chartdate']) 366 | type_ext = FmtExtractor(names = [], fmt = 'cptevents') 367 | value_ext = MultiExtractor(names = ['cpt_cd', 'ticket_id_seq']) 368 | test_ext = true_test_extractor 369 | extractor = ExtractorInfo(table, outpath, id_extractor, 370 | time_ext, type_ext, 371 | value_ext, test_ext) 372 | extractors.append(extractor) 373 | extractor_map[table] = extractors 374 | return table 375 | 376 | 377 | 378 | 379 | if __name__ == '__main__': 380 | 381 | extractor_map = {} 382 | funcs = [get_patients_extractors, get_admissions_extractors, get_icustays_extractors, 383 | get_labevents_extractors, get_microbiologyevents_extractors, get_outputevents_extractors, get_diagnoses_extractors, 384 | get_prescriptions_extractors, get_datetimeevents_extractors, get_chartevents_extractors, get_proceduresicd_extractors, 385 | get_procedureevents_extractors, get_inputevents_cv_extractors, get_inputevents_mv_extractors, get_cptevents_extractors] 386 | 387 | # funcs = [get_labevents_extractors] 388 | if not os.path.exists(data_dir): 389 | os.mkdir(data_dir) 390 | 391 | tables = [] 392 | for func in funcs: 393 | table = func(data_dir, extractor_map) 394 | if type(table) == list: 395 | tables.extend(table) 396 | else: 397 | tables.append(table) 398 | for table in tables[:]: 399 | extract_from_table(table, extractor_map[table], only_test = False, limit = 1000000) 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/extractor.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jhshen95/HELSTM/4a4c99797d6d717e96dcec805d9748bd641a242d/数据生成说明_v7/data_process/extractor.cpp -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/extractor.py: -------------------------------------------------------------------------------- 1 | from util import connect, date2str, time2str, data_dir 2 | import os 3 | import sys 4 | from datetime import timedelta 5 | import datetime 6 | 7 | fmt = "%-25s %-25s %-40s %s\n" 8 | 9 | def parse_line(line): 10 | line = line.rstrip() 11 | ret = [] 12 | ret.append(line[0:26]) 13 | ret.append(line[26:26+26]) 14 | ret.append(line[52:52+41]) 15 | ret.append(line[93:]) 16 | ret = [item.rstrip() for item in ret] 17 | return ret 18 | 19 | 20 | 21 | class Extractor: 22 | def extract(self, row): 23 | pass 24 | 25 | 26 | class MultiExtractor(Extractor): 27 | def 
__init__(self, names, sep = '|&|'): 28 | self.names = names 29 | self.sep = sep 30 | 31 | def extract(self, row): 32 | values = [row[name] for name in self.names] 33 | for i in range(len(values)): 34 | if values[i] is None: 35 | values[i] = "" 36 | values = map(str, values) 37 | return self.sep.join(values) 38 | 39 | 40 | class ConstantExtractor(Extractor): 41 | def __init__(self, constant): 42 | self.constant = constant 43 | 44 | def extract(self, row): 45 | return self.constant 46 | 47 | 48 | class TimeExtractor(Extractor): 49 | def __init__(self, name, converter): 50 | self.name = name 51 | self.converter = converter 52 | 53 | def extract(self, row): 54 | value = row[self.name] 55 | return self.converter(value) 56 | 57 | 58 | class TestExtractor(Extractor): 59 | def __init__(self, name, test): 60 | self.name = name 61 | self.test = test 62 | 63 | def extract(self, row): 64 | return self.test(row[self.name]) 65 | class TestExtractors(Extractor): 66 | def __init__(self, names, test): 67 | self.tests = [] 68 | for name in names: 69 | self.tests.append(TestExtractor(name, test)) 70 | 71 | def extract(self, row): 72 | ret = True 73 | for test in self.tests: 74 | ret = ret and test.extract(row) 75 | return ret 76 | 77 | class FmtExtractor(Extractor): 78 | def __init__(self, names, fmt): 79 | self.names = names 80 | self.fmt = fmt 81 | 82 | def extract(self, row): 83 | values = tuple([row[name] for name in self.names]) 84 | return self.fmt % values 85 | 86 | class SelectExtractor(Extractor): 87 | def __init__(self, sub_extractors): 88 | self.sub_extractors = sub_extractors 89 | 90 | def extract(self, row): 91 | for ext in self.sub_extractors: 92 | try: 93 | value = ext.extract(row) 94 | except Exception, e: 95 | pass 96 | else: 97 | return value 98 | return None 99 | 100 | 101 | 102 | class ExtractorInfo: 103 | def __init__(self, table, outpath, id_extractor, time_extractor, 104 | type_extractor = None, value_extractor = ConstantExtractor(1), 105 | test_extractor = ConstantExtractor(True)): 106 | self.table = table 107 | self.outpath = outpath 108 | self.id_extractor = id_extractor 109 | self.time_extractor = time_extractor 110 | self.type_extractor = type_extractor 111 | self.value_extractor = value_extractor 112 | self.test_extractor = test_extractor 113 | 114 | def open(self): 115 | self.outf = file(self.outpath, 'w') 116 | 117 | def extract(self, row): 118 | global fmt 119 | if self.test_extractor.extract(row): 120 | ID = self.id_extractor.extract(row) 121 | time = self.time_extractor.extract(row) 122 | dtype = self.type_extractor.extract(row) 123 | value = self.value_extractor.extract(row) 124 | out = (ID, time, dtype, value) 125 | self.outf.write(fmt % out) 126 | 127 | def close(self): 128 | self.outf.close() 129 | 130 | 131 | def extract_from_table(table, extractors, only_test = False, limit = 100000): 132 | print "query from [%s]" %table 133 | db = connect() 134 | offset = 0 135 | for extractor in extractors: 136 | extractor.open() 137 | while True: 138 | query = "select * from %s order by row_id limit %d offset %d" %(table, limit, offset) 139 | print '\t%s' %query 140 | res = db.query(query) 141 | cnt = 0 142 | for row in res.dictresult(): 143 | cnt += 1 144 | for extractor in extractors: 145 | extractor.extract(row) 146 | 147 | ntuples = res.ntuples() 148 | offset += limit 149 | if ntuples < limit or only_test: 150 | break 151 | for extractor in extractors: 152 | extractor.close() 153 | 154 | 155 | 156 | if __name__ == '__main__': 157 | extractor_map = {} 158 | 
from extract import get_patients_extractors, get_admissions_extractors
159 |     get_patients_extractors(data_dir, extractor_map)
160 |     get_admissions_extractors(data_dir, extractor_map)
161 |     extract_from_table('patients', extractor_map['patients'])
--------------------------------------------------------------------------------
/数据生成说明_v7/data_process/gather_stat.py:
--------------------------------------------------------------------------------
1 | from util import stat_dir, result_dir
2 | import os
3 | import glob
4 | from stat_data import Stat
5 | 
6 | def handle(filename, outf):
7 |     base = os.path.basename(filename)
8 |     print base
9 |     get_cnt = lambda x:x.cnt
10 |     get_rate = lambda x:x.rate
11 |     stats = Stat.load_from_file(filename)
12 |     outf.write("***** %s start *****\n" %base)
13 |     for rtype in sorted(stats.keys()):
14 |         out = []
15 |         stat = stats[rtype]
16 |         nentry = stat.nentry()
17 |         out.append(nentry)
18 |         out.append(stat.get_mean(get_cnt))
19 |         out.append(stat.get_var(get_cnt))
20 |         stat.calc_rate()
21 |         out.append(stat.get_mean(get_rate))
22 |         out.append(stat.get_var(get_rate))
23 |         out = ['%-20s' %str(item) for item in out]
24 |         out.insert(0, '%-40s' %rtype)
25 |         outf.write("".join(out) + "\n")
26 |     outf.write("***** %s end *****\n\n" %base)
27 | 
28 | 
29 | if __name__ == '__main__':
30 |     if not os.path.exists(result_dir):
31 |         os.mkdir(result_dir)
32 |     outf = file(os.path.join(result_dir, 'stat.tsv'), 'w')
33 |     for filename in glob.glob(stat_dir + "/*.tsv"):
34 |         handle(filename, outf)
35 |     outf.close()
36 | 
--------------------------------------------------------------------------------
/数据生成说明_v7/data_process/gather_static_data.py:
--------------------------------------------------------------------------------
1 | from util import *
2 | import datetime
3 | 
4 | class Admission:
5 |     def __init__(self):
6 |         self.cnt = 0
7 |         self.admit_type = None
8 |         self.disch = None
9 |         self.admit = None
10 |         self.death = False
11 | 
12 |     def add_disch(self, disch):
13 |         if self.disch == None:
14 |             self.disch = disch
15 |         else:
16 |             self.disch = max(self.disch, disch)
17 |         self.cnt += 1
18 | 
19 |     def add_admit(self, admit, admit_type):
20 |         if self.admit is None:
21 |             self.admit = admit
22 |             self.admit_type = admit_type
23 |         elif self.admit > admit:
24 |             self.admit = admit
25 |             self.admit_type = admit_type
26 | 
27 |     def add_death(self, death):
28 |         self.death = True
29 | 
30 |     def range(self):
31 |         return (self.disch - self.admit).days
32 | 
33 |     def __str__(self):
34 |         out = map(str, [self.admit, self.disch, self.cnt, self.death, self.admit_type])
35 |         return "\t".join(out)
36 | 
37 |     @staticmethod
38 |     def load_from_line(line):
39 |         admission = Admission()
40 |         parts = line.strip().split("\t")
41 |         admission.cnt = int(parts[2])
42 |         admission.admit = parse_time(parts[0])
43 |         admission.disch = parse_time(parts[1])
44 |         admission.death = False if parts[3] == "False" else True
45 |         admission.admit_type = parts[-1]
46 |         return admission
47 | 
48 | def load_admission():
49 |     ad_map = {}
50 |     for line in file(static_data_dir + "/admission.tsv"):
51 |         parts = line.strip().split("\t")
52 |         pid = int(parts[0])
53 |         adm = Admission.load_from_line("\t".join(parts[1:]))
54 |         ad_map[pid] = adm
55 |     return ad_map
56 | 
57 | class SingleAdmission:
58 |     def __init__(self, pid, admit_time, disch_time, admit_type):
59 |         self.pid = pid
60 |         self.admit_time = admit_time
61 |         self.disch_time = disch_time
62 |         self.admit_type = admit_type
63 | 
64 |     def __str__(self):
65 |         out = map(str, [self.pid, self.admit_time, self.disch_time, self.admit_type])
66 |         return "\t".join(out)
67 | 
68 |     @staticmethod
69 |     def 
load_from_line(line): 70 | parts = line.strip().split("\t") 71 | pid = int(parts[0]) 72 | admit_time = parse_time(parts[1]) 73 | disch_time = parse_time(parts[2]) 74 | admit_type = parts[3] 75 | admission = SingleAdmission(pid, admit_time, disch_time, admit_type) 76 | return admission 77 | 78 | def get_d_icd_diagnoses(): 79 | diag_icd_map = {} 80 | db = connect() 81 | table = 'd_icd_diagnoses' 82 | query = 'select * from %s' %table 83 | res = db.query(query) 84 | for row in res.dictresult(): 85 | code = row['icd9_code'] 86 | value = (row['short_title'], row['long_title']) 87 | diag_icd_map[code] = value 88 | 89 | return diag_icd_map 90 | 91 | def get_d_labitems(): 92 | d_labitem_map = {} 93 | db = connect() 94 | table = "d_labitems" 95 | query = "select * from %s" %table 96 | res = db.query(query) 97 | for row in res.dictresult(): 98 | code = row['itemid'] 99 | value = row['label'] + " | " + row['fluid'] 100 | d_labitem_map[code] = value 101 | return d_labitem_map 102 | 103 | def get_d_items(): 104 | d_item_map = {} 105 | db = connect() 106 | table = "d_items" 107 | query = "select * from %s" %table 108 | res = db.query(query) 109 | for row in res.dictresult(): 110 | code = row['itemid'] 111 | value = row['label'] 112 | d_item_map[code] = value 113 | return d_item_map 114 | 115 | def get_single_admission(): 116 | admission_map = {} 117 | db = connect() 118 | table = "admissions" 119 | query = "select * from %s" %table 120 | res = db.query(query) 121 | cnt = 0 122 | for row in res.dictresult(): 123 | cnt += 1 124 | if cnt % 1000 == 0: 125 | print cnt 126 | hid = row['hadm_id'] 127 | pid = row['subject_id'] 128 | admit_time = row['admittime'] 129 | disch_time = row['dischtime'] 130 | admit_type = row['admission_type'] 131 | admission_map[hid] = SingleAdmission(pid, admit_time, disch_time, admit_type) 132 | 133 | return admission_map 134 | 135 | 136 | def write_map(data, filepath): 137 | outf = file(filepath, 'w') 138 | keys = data.keys() 139 | try: 140 | keys = map(int, keys) 141 | except Exception, e: 142 | pass 143 | for key in sorted(keys): 144 | outf.write('\t'.join(map(str, [key, data[key]])) + '\n') 145 | outf.close() 146 | 147 | def write_map_value(data, filepath): 148 | outf = file(filepath, "w") 149 | keys = data.keys() 150 | for key in sorted(keys): 151 | outf.write(str(data[key]) + "\n") 152 | outf.close() 153 | 154 | def get_admission_map(admit_path, disch_path, death_path): 155 | admission_map = {} 156 | id2event_value = load_id2event_value() 157 | for line in file(disch_path): 158 | parts = line.strip().split("\t") 159 | pid = int(parts[1]) 160 | time = parse_time(parts[3]) 161 | assert time is not None 162 | if not pid in admission_map: 163 | admission_map[pid] = Admission() 164 | admission_map[pid].add_disch(time) 165 | for line in file(admit_path): 166 | parts = line.strip().split("\t") 167 | pid = int(parts[1]) 168 | admit_type = id2event_value[int(parts[0])].split(".")[-1] 169 | time = parse_time(parts[3]) 170 | assert time is not None 171 | admission_map[pid].add_admit(time, admit_type) 172 | 173 | for line in file(death_path): 174 | parts = line.strip().split("\t") 175 | pid = int(parts[1]) 176 | time = parse_time(parts[3]) 177 | assert time is not None 178 | admission_map[pid].add_death(time) 179 | 180 | return admission_map 181 | 182 | 183 | 184 | 185 | if __name__ == '__main__': 186 | if not os.path.exists(static_data_dir): 187 | os.mkdir(static_data_dir) 188 | # diagnose code map 189 | # diag_icd_map = get_d_icd_diagnoses() 190 | # write_map(diag_icd_map, 
os.path.join(static_data_dir, 'diag_ICD9.tsv')) 191 | 192 | # item code map 193 | item_map = get_d_items() 194 | write_map(get_d_items(), os.path.join(static_data_dir, 'item_code.tsv')) 195 | 196 | # labitem code map 197 | labitem_map = get_d_labitems() 198 | write_map(labitem_map, os.path.join(static_data_dir, "labitem_code.tsv")) 199 | 200 | # gather admission from database 201 | # admission_map = get_single_admission() 202 | # write_map(admission_map, os.path.join(static_data_dir, "single_admission.tsv")) 203 | 204 | 205 | # admission map 206 | # admit_path = os.path.join(event_dir, "admissions.admittime.tsv") 207 | # disch_path = os.path.join(event_dir, 'admissions.dischtime.tsv') 208 | # death_path = os.path.join(event_dir, "admissions.deathtime.tsv") 209 | # admission_map = get_admission_map(admit_path, disch_path, death_path) 210 | # write_map(admission_map, os.path.join(static_data_dir, "admission.tsv")) 211 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/lab_process_data.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import numpy as np 3 | import time 4 | 5 | seq_len = 200 6 | 7 | def time_to_stamp(time_string, start, time_format='%Y-%m-%d %H:%M:%S'): 8 | if (time_string==""): 9 | return -10800. 10 | return time.mktime(time.strptime(time_string, time_format)) - start 11 | 12 | def load_data(path, start, end): 13 | f = h5py.File(path) 14 | labels = f['label'][start:end] 15 | events = f['event'][start:end] 16 | times = f['time'][start:end] 17 | features = f['feature'][start:end] 18 | 19 | time_shift = [] 20 | for id in xrange(times.shape[0]): 21 | event_time = times[id] 22 | event_shift = [] 23 | start_time = time_to_stamp(event_time[0], 0.) 24 | for i in xrange(event_time.shape[0]): 25 | event_shift.append(time_to_stamp(event_time[i], start_time)/10800.) 
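# The two lines above rebase each record's timestamps against its first
# event and scale by 10800 seconds, so one unit of model time is three
# hours: an event 6 h after the first becomes 2.0, while the empty-string
# sentinel ("" -> -10800. in time_to_stamp) lands exactly at -1.0.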
26 | time_shift.append(event_shift) 27 | times = np.asarray(time_shift, dtype='float32') 28 | 29 | chosen_event = [] 30 | chosen_time = [] 31 | chosen_label = [] 32 | chosen_feature_id = [] 33 | chosen_feature_value = [] 34 | tic = time.time() 35 | for id in xrange(labels.shape[0]): 36 | this_label = labels[id] 37 | this_event = events[id] 38 | this_feature_id = features[id][:,(0,2,4)] 39 | this_feature_value = features[id][:,(1,3,5)] 40 | this_time = times[id] 41 | 42 | chosen = this_label[(this_label[:,0]==34)+(this_label[:,0]==35)]#choose id 34 or 35 43 | for tmp in chosen: 44 | this_start = int(tmp[-1]) 45 | this_end = int(tmp[-2]) 46 | if this_end<=this_start+10:#ignore seq whose len<10 47 | continue 48 | if this_end>this_start+seq_len:#cut max seq len to seq_len 49 | this_start = this_end-seq_len 50 | pad_num = seq_len-this_end+this_start 51 | chosen_event.append( 52 | np.pad(this_event[this_start+1:this_end+1], ((0,pad_num),), 'constant')) 53 | chosen_time.append( 54 | np.pad(this_time[this_start+1:this_end+1], ((0,pad_num),), 'constant')) 55 | chosen_feature_id.append( 56 | np.pad(this_feature_id[this_start+1:this_end+1], ((0,pad_num),(0,0)), 'constant')) 57 | chosen_feature_value.append( 58 | np.pad(this_feature_value[this_start+1:this_end+1], ((0,pad_num),(0,0)), 'constant')) 59 | chosen_label.append(tmp[2])#use 0/1 as label 60 | 61 | chosen_event = np.asarray(chosen_event, dtype='int16') 62 | chosen_time = np.asarray(chosen_time, dtype='float32') 63 | chosen_feature_id = np.asarray(chosen_feature_id, dtype='int16') 64 | chosen_feature_value = np.asarray(chosen_feature_value, dtype='float32') 65 | chosen_label = np.asarray(chosen_label, dtype='float32') 66 | f.close() 67 | return chosen_event, chosen_time, chosen_feature_id, chosen_feature_value, chosen_label 68 | 69 | def load_data_all(name, start, end): 70 | path = '../Data/Lab.h5' 71 | f = h5py.File('../Data/{}.h5'.format(name), 'w') 72 | events, times, feature_id, feature_value, labels = load_data(path, start, end) 73 | f['events'] = events 74 | f['times'] = times 75 | f['feature_id'] = feature_id 76 | f['feature_value'] = feature_value 77 | f['labels'] = labels 78 | f.close() 79 | 80 | load_data_all('test', 0, 9112) 81 | load_data_all('train', 9112, 41006) 82 | load_data_all('valid', 41006, 45563) 83 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/select_feature.py: -------------------------------------------------------------------------------- 1 | from util import * 2 | import os 3 | import json 4 | import re 5 | 6 | class ValueCount: 7 | def __init__(self, ratio, number = 0): 8 | self.ratio = ratio 9 | self.number = number 10 | 11 | @staticmethod 12 | def load_from_str(part): 13 | l = part.split("/") 14 | number = int(l[0]) 15 | ratio = float(l[1]) 16 | return ValueCount(ratio, number) 17 | 18 | def __cmp__(self, other): 19 | return cmp(self.ratio, other.ratio) 20 | 21 | class FeatureValueCount: 22 | ''' 23 | parts:type nentry time% num% null% txt% ntxt_type 24 | ''' 25 | main_type_threshold = ValueCount(0.95) 26 | max_text_kinds = 40 27 | small_coverage = "small coverage" 28 | def __init__(self, parts): 29 | time_cnt = ValueCount.load_from_str(parts[2]) 30 | num_cnt = ValueCount.load_from_str(parts[3]) 31 | null_cnt = ValueCount.load_from_str(parts[4]) 32 | txt_cnt = ValueCount.load_from_str(parts[5]) 33 | self.cnts = { 34 | "time":time_cnt, 35 | "number":num_cnt, 36 | "null":null_cnt, 37 | "text":txt_cnt, 38 | } 39 | self.feature = parts[0] 40 | self.ntxt_type = 
int(parts[6]) 41 | 42 | def check_valid(self): 43 | main_type = self.main_type() 44 | coverage = self.get_count(main_type) 45 | if coverage < FeatureValueCount.main_type_threshold: 46 | print "**** coverage:%f" %coverage.ratio 47 | return False 48 | if main_type == "text" and self.ntxt_type > FeatureValueCount.max_text_kinds: 49 | print "**** text #types:%d" %self.ntxt_type 50 | return False 51 | return True 52 | 53 | def main_type(self): 54 | main_name = None 55 | max_value = ValueCount(0.0) 56 | for name in self.cnts: 57 | value = self.cnts[name] 58 | if value > max_value: 59 | max_value = value 60 | main_name = name 61 | return main_name 62 | 63 | def get_count(self, name): 64 | return self.cnts[name] 65 | 66 | 67 | 68 | def load_value_type_text(filepath): 69 | text_map = {} 70 | for line in file(filepath): 71 | parts = line.strip().split('\t') 72 | feature = parts[0] 73 | texts = json.loads(parts[2]) 74 | text_map[feature] = texts 75 | return text_map 76 | 77 | def load_value_type_stat(filepath): 78 | value_type_map = {} 79 | space_spliter = re.compile(r"\s+") 80 | for line in file(filepath): 81 | line = line.strip() 82 | if line.startswith("****"): 83 | continue 84 | parts = space_spliter.split(line) 85 | value_cnt = FeatureValueCount(parts) 86 | value_type_map[value_cnt.feature] = value_cnt 87 | return value_type_map 88 | 89 | class TypeFeature(): 90 | def __init__(self, rtype, nentry): 91 | self.rtype = rtype 92 | self.nentry = nentry 93 | self.features = [] 94 | 95 | def add_feature(self, value_feature): 96 | self.features.append(value_feature) 97 | 98 | def to_string(self): 99 | self.features.sort() 100 | out = [] 101 | out.append(self.rtype + "#" + str(self.nentry)) 102 | for feature in self.features: 103 | out.append("\t"+feature.to_string()) 104 | return "\n".join(out) 105 | 106 | @staticmethod 107 | def load_from_str(string): 108 | parts = string.split("\n") 109 | rtype, nentry = parts[0].split("#") 110 | type_feature = TypeFeature(rtype, int(nentry)) 111 | for part in parts[1:]: 112 | type_feature.features.append(ValueFeature.load_from_str(part.lstrip("\t"))) 113 | return type_feature 114 | 115 | 116 | class ValueFeature: 117 | def __init__(self, order, main_type, f_value_cnt = None): 118 | self.order = order 119 | self.main_type = main_type 120 | if f_value_cnt is not None: 121 | self.coverage = f_value_cnt.get_count(main_type).ratio 122 | self.ndim = 1 123 | if self.main_type == "text": 124 | self.ndim = f_value_cnt.ntxt_type 125 | 126 | 127 | def to_string(self): 128 | out = [self.order, self.main_type, self.coverage, self.ndim] 129 | out = map(str, out) 130 | return " ".join(out) 131 | 132 | @staticmethod 133 | def load_from_str(string): 134 | parts = string.split(" ") 135 | # print string 136 | order = int(parts[0]) 137 | main_type = parts[1] 138 | value_f = ValueFeature(order, main_type) 139 | value_f.coverage = float(parts[2]) 140 | value_f.ndim = int(parts[3]) 141 | return value_f 142 | 143 | def __cmp__(self, other): 144 | return cmp(self.order, other.order) 145 | 146 | def gen_type_feature(rtype, nentry, value_stat_map, text_map): 147 | Flag = True 148 | i = 0 149 | type_feature = TypeFeature(rtype, nentry) 150 | while True: 151 | order = i 152 | i += 1 153 | feature_name = rtype + "#" + str(order) 154 | if feature_name in value_stat_map: 155 | value_stat = value_stat_map[feature_name] 156 | if value_stat.check_valid(): 157 | value_feature = ValueFeature(order, value_stat.main_type(), value_stat) 158 | type_feature.add_feature(value_feature) 159 | else: 160 | if 
rtype.startswith("inputevents_mv") and order == 2: 161 | continue 162 | if rtype.startswith("inputevents_cv") and order == 1: 163 | continue 164 | return None 165 | else: 166 | break 167 | return type_feature 168 | 169 | 170 | 171 | def select_feature(stat_file, limit, text_map, value_stat_map): 172 | black_regs = load_reg(os.path.join(config_dir, 'blacklist.reg')) 173 | features = [] 174 | rtypes = [] 175 | space_spliter = re.compile(r"\s+") 176 | nfeature = 0 177 | dim = 0 178 | event_dim = 0 179 | for idx, line in enumerate(file(stat_file)): 180 | if idx >= limit: 181 | continue 182 | parts = space_spliter.split(line.strip()) 183 | rtype = parts[0] 184 | in_black = False 185 | for reg in black_regs: 186 | if reg.match(rtype): 187 | in_black = True 188 | if in_black: 189 | continue 190 | 191 | nentry = int(parts[1]) 192 | type_feature = gen_type_feature(rtype, nentry, value_stat_map, text_map) 193 | if type_feature is not None: 194 | features.append(type_feature) 195 | t_event_dim = 1 196 | cc = 0 197 | for value_feature in type_feature.features: 198 | if value_feature.main_type != "text": 199 | dim += value_feature.ndim 200 | # print value_feature.ndim 201 | if value_feature.ndim > 1: 202 | cc += 1 203 | t_event_dim *= value_feature.ndim 204 | if cc >= 2: 205 | print "**********", rtype 206 | event_dim += t_event_dim 207 | print "feature dim =", dim 208 | print "event dim =", event_dim 209 | print "nfeature =", len(features) 210 | 211 | return features 212 | 213 | def write_features(features, filepath): 214 | outf = file(filepath, 'w') 215 | 216 | for feature in features: 217 | outf.write(feature.to_string() + "\n") 218 | # main_type = feature.main_type 219 | # out = [feature.name, feature.f_coverage, main_type] 220 | # if main_type in error_main_types: 221 | # main_name, max_ratio = feature.f_value_cnt.max_ratio() 222 | # out.append(main_name) 223 | # out.append(max_ratio) 224 | # else: 225 | # value_cnt = feature.f_value_cnt.get_count(main_type) 226 | # out.append('%d/%f' %(value_cnt.number, value_cnt.ratio)) 227 | # if main_type == "text": 228 | # out.append(feature.f_value_cnt.ntxt_type) 229 | # out = map(str, out) 230 | # outf.write("\t".join(out) + "\n") 231 | outf.close() 232 | 233 | 234 | 235 | if __name__ == '__main__': 236 | text_map = load_value_type_text(os.path.join(result_dir, "value_type_text.tsv")) 237 | value_stat_map = load_value_type_stat(os.path.join(result_dir, "value_type_stat.tsv")) 238 | stat_file = os.path.join(result_dir, "sorted_stat.tsv") 239 | feature_limit = 1000 240 | features = select_feature(stat_file, feature_limit, text_map, value_stat_map) 241 | write_features(features, os.path.join(result_dir, "selected_features.tsv")) 242 | 243 | 244 | 245 | 246 | 247 | 248 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/sort_stat.py: -------------------------------------------------------------------------------- 1 | import re 2 | pattern = re.compile(r"\s{2,}") 3 | outf = file('result/sorted_stat.tsv', 'w') 4 | stats = [] 5 | for line in file('result/stat.tsv'): 6 | if line.startswith('****') or line.strip() == "": 7 | continue 8 | else: 9 | nentry = int(pattern.split(line)[1]) 10 | stats.append((nentry, line)) 11 | 12 | stats.sort(key = lambda x:x[0], reverse = True) 13 | for nentry, line in stats: 14 | outf.write(line) 15 | outf.close() 16 | 17 | 18 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/stat_data.py: 
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import extractor
4 | from util import data_dir, parse_time, stat_dir, time2str, result_dir
5 | import re
6 | import math
7 | 
8 | 
9 | class Stat:
10 |     pattern = re.compile(r'\[(?P<rtype>[\w\.&]+)] statistics: (?P<count>\d+) entries')
11 | 
12 | 
13 |     def __init__(self, rtype):
14 |         self.rtype = rtype
15 |         self.stat = {}
16 |         self.value_stat = None
17 |         self.row = 0
18 | 
19 | 
20 |     def nentry(self):
21 |         return len(self.stat)
22 | 
23 |     def nrow(self):
24 |         return self.row
25 | 
26 |     def get_value_stat(self):
27 |         return self.value_stat
28 | 
29 |     def add_whole_entry(self, entry):
30 |         self.stat[entry.ID] = entry
31 | 
32 | 
33 |     def add_entry(self, ID, time, value = None):
34 |         if not ID in self.stat:
35 |             self.stat[ID] = StatEntry(ID)
36 |         self.stat[ID].add_time(time)
37 |         if value is not None:
38 |             parts = value.split("|&|")
39 |             self.add_value(ID, parts)
40 | 
41 |     def add_value(self, ID, parts):
42 |         if self.value_stat is None:
43 |             nvalue = len(parts)
44 |             self.value_stat = [0] * nvalue
45 |         self.row += 1
46 |         assert len(parts) == len(self.value_stat)
47 | 
48 |         for i in range(len(parts)):
49 |             if parts[i] != "":
50 |                 self.value_stat[i] += 1
51 | 
52 |     def write_to_local(self, outf):
53 |         outf.write('[%s] statistics: %d entries\n' %(self.rtype, len(self.stat)))
54 |         for ID in sorted(self.stat.keys()):
55 |             outf.write(self.stat[ID].to_line())
56 |             outf.write("\n")
57 | 
58 |     def get_mean(self, valuef):
59 |         tot = 0.0
60 |         nentry = 0
61 |         for entry in self.stat.values():
62 |             value = valuef(entry)
63 |             if value > 0:
64 |                 tot += value
65 |                 nentry += 1
66 |         mean = 0
67 |         if nentry > 0:
68 |             mean = round(tot / nentry, 4)
69 |         return (nentry, mean)
70 | 
71 |     def get_var(self, valuef):
72 |         nentry, mean = self.get_mean(valuef)
73 |         tot = 0
74 |         for entry in self.stat.values():
75 |             value = valuef(entry)
76 |             if value > 0:
77 |                 tot += (value - mean) * (value - mean)
78 |         var = 0.0
79 |         if nentry > 0:
80 |             var = round(math.sqrt(tot / nentry), 4)
81 |         return (nentry, var)
82 | 
83 |     def calc_rate(self):
84 |         for entry in self.stat.values():
85 |             entry.calc_rate()
86 | 
87 |     @staticmethod
88 |     def load_from_file(filename):
89 |         stats = {}
90 |         rtype = None
91 |         for line in file(filename, 'r'):
92 | 
93 |             line = line.rstrip()
94 |             if line.startswith("["):
95 |                 res = Stat.pattern.match(line)
96 |                 rtype = res.group('rtype')
97 |                 cnt = res.group("count")
98 |                 stats[rtype] = Stat(rtype)
99 |                 stat = stats[rtype]
100 |             else:
101 |                 entry = StatEntry.load_from_line(line)
102 |                 stat.add_whole_entry(entry)
103 |         return stats
104 | 
105 | 
106 | 
107 | class StatEntry:
108 |     fmt = '%s\t%d\t%s\t%s'
109 | 
110 |     def __init__(self, ID, cnt = 0, st = None, ed = None):
111 |         self.ID = ID
112 |         self.st = st
113 |         self.ed = ed
114 |         self.cnt = cnt
115 | 
116 |     def add_time(self, time):
117 |         if self.st is None:
118 |             self.st = time
119 |             self.ed = time
120 |         elif time is not None:
121 |             self.st = min(self.st, time)
122 |             self.ed = max(self.ed, time)
123 |         self.cnt += 1
124 | 
125 | 
126 |     def to_line(self):
127 |         return StatEntry.fmt %(self.ID, self.cnt, self.st, self.ed)
128 | 
129 |     def calc_rate(self):
130 |         self.rate = 0
131 |         if self.st is not None:
132 |             nhour = (self.ed - self.st).total_seconds()/3600.00
133 |             if nhour > 0:
134 |                 self.rate = self.cnt / nhour
135 | 
136 |     @staticmethod
137 |     def load_from_line(line):
138 |         parts = line.rstrip().split("\t")
139 |         ID = int(parts[0])
140 |         cnt = int(parts[1])
141 |         st = parts[2]
142 | 
ed = parts[3] 143 | if st != "None": 144 | st_time = parse_time(st) 145 | else: 146 | st_time = None 147 | if ed != "None": 148 | ed_time = parse_time(ed) 149 | else: 150 | ed_time = None 151 | entry = StatEntry(ID, cnt, st_time, ed_time) 152 | return entry 153 | 154 | 155 | def get_patientid(IDs_str): 156 | IDs = IDs_str.split("_") 157 | return int(IDs[0]) 158 | 159 | def get_id(IDs_str): 160 | IDs = IDs_str.split("_") 161 | return int(IDs[0]) 162 | 163 | # if len(IDs) >= 2: 164 | # if IDs[1] == "": 165 | # return None 166 | # return int(IDs[1]) 167 | # else: 168 | # return int(IDs[0]) 169 | 170 | def process(filename, outfilename, value_stats): 171 | stats = {} 172 | print 'process [%s]' %(os.path.basename(filename)) 173 | cnt = 0 174 | for line in file(filename, 'r'): 175 | cnt += 1 176 | if cnt % 100000 == 0: 177 | print "\t %d lines" %cnt 178 | parts = extractor.parse_line(line) 179 | 180 | ID = get_id(parts[0]) 181 | if ID is None: 182 | continue 183 | 184 | rtype = parts[2] 185 | if not rtype in stats: 186 | stats[rtype] = Stat(rtype) 187 | stat = stats[rtype] 188 | 189 | time = parse_time(parts[1]) 190 | stat.add_entry(ID, time, parts[3].strip()) 191 | 192 | values = parts[3].split("|&|") 193 | if len(values) >= 1: 194 | time = parse_time(values[0]) 195 | if time is not None: 196 | stat.add_entry(ID, time) 197 | if outfilename is not None: 198 | outf = file(outfilename, 'w') 199 | for rtype in sorted(stats.keys()): 200 | stats[rtype].write_to_local(outf) 201 | outf.close() 202 | 203 | if not value_stats is None: 204 | for rtype in stats.keys(): 205 | stat = stats[rtype] 206 | total = stat.nrow() 207 | value_cnts = stat.get_value_stat() 208 | for i in range(len(value_cnts)): 209 | key = rtype + "#" + str(i) 210 | rate = round(value_cnts[i] / (total + 0.0), 3) 211 | value = (rate, value_cnts[i]) 212 | value_stats[key] = value 213 | 214 | def count_table_event(filepath): 215 | print "count event from [%s]" %os.path.basename(filepath) 216 | event_cnt = {} 217 | for line in file(filepath): 218 | parts = extractor.parse_line(line) 219 | pid = int(parts[0].split("_")[0]) 220 | if not pid in event_cnt: 221 | event_cnt[pid] = 0 222 | event_cnt[pid] += 1 223 | return event_cnt 224 | 225 | def count_event(filepaths, outpath): 226 | patients = set() 227 | table_names = [] 228 | table_cnt = {} 229 | for filepath in filepaths: 230 | event_cnt = count_table_event(filepath) 231 | name = os.path.basename(filepath)[:-4] 232 | table_cnt[name] = event_cnt 233 | table_names.append(name) 234 | for pid in event_cnt: 235 | patients.add(pid) 236 | 237 | outf = file(outpath, 'w') 238 | table_names = sorted(table_names) 239 | out = ["pid"] 240 | out.extend(table_names) 241 | outf.write("\t".join(out) + "\n") 242 | for pid in sorted(patients): 243 | out = [pid] 244 | for table_name in table_names: 245 | cnt = table_cnt[table_name].get(pid, 0) 246 | out.append(cnt) 247 | out = map(str, out) 248 | outf.write("\t".join(out) + "\n") 249 | outf.close() 250 | 251 | 252 | def write_value_stat(value_stats, outpath): 253 | outf = file(outpath, 'w') 254 | for key in sorted(value_stats.keys(), reverse = True, key = lambda x:value_stats[x][0]): 255 | outf.write("%-35s %5f %d\n" %(key, value_stats[key][0], value_stats[key][1])) 256 | outf.close() 257 | 258 | class SimpleStat: 259 | def __init__(self): 260 | self.pid_event_cnt = {} 261 | self.nb_event = 0 262 | self.rtype_set = set() 263 | 264 | def add_pid(self, pid): 265 | if not pid in self.pid_event_cnt: 266 | self.pid_event_cnt[pid] = 0 267 | 268 | 269 | def 
add_data(self, line): 270 | self.nb_event += 1 271 | parts = extractor.parse_line(line) 272 | pid = get_id(parts[0]) 273 | rtype = parts[2] 274 | self.add_pid(pid) 275 | self.rtype_set.add(rtype) 276 | self.pid_event_cnt[pid] += 1 277 | 278 | def print_info(self): 279 | out_format = """ 280 | # of patients = {0} 281 | # of events = {1} 282 | Avg # of events per patient = {2} 283 | Max # of events per patient = {3} 284 | Min # of events per patient = {4} 285 | # of unique events = {5} 286 | """ 287 | nb_patients = len(self.pid_event_cnt) 288 | nb_events = self.nb_event 289 | ave_events = round((nb_events + 0.0) / nb_patients, 4) 290 | max_events = reduce(max, self.pid_event_cnt.values()) 291 | min_events = reduce(min, self.pid_event_cnt.values()) 292 | nb_event_type = len(self.rtype_set) 293 | print out_format.format( 294 | nb_patients, 295 | nb_events, 296 | ave_events, 297 | max_events, 298 | min_events, 299 | nb_event_type 300 | ) 301 | 302 | 303 | 304 | 305 | def gather_statistics(filepath, stat): 306 | print "gather info from %s" %filepath 307 | for line in file(filepath): 308 | stat.add_data(line) 309 | 310 | 311 | 312 | 313 | if __name__ == '__main__': 314 | # process('data/chartevents_8.tsv', 'stat/test.stat') 315 | 316 | # stat data 317 | if not os.path.exists(stat_dir): 318 | os.mkdir(stat_dir) 319 | if not os.path.exists(result_dir): 320 | os.mkdir(result_dir) 321 | value_stats = {} 322 | for filename in glob.glob(data_dir + "/*tsv"): 323 | stat_filename = os.path.join(stat_dir, os.path.basename(filename)) 324 | process(filename, stat_filename, value_stats) 325 | if value_stats and len(value_stats) > 0: 326 | write_value_stat(value_stats, os.path.join(result_dir, 'value_coverage_stat.tsv')) 327 | 328 | 329 | # filepaths = glob.glob(data_dir + "/*tsv") 330 | # count_event(filepaths, os.path.join(result_dir, 'event_cnt.tsv')) 331 | 332 | # print statistics 333 | # stat = SimpleStat() 334 | # for filename in glob.glob(data_dir + "/*tsv"): 335 | # gather_statistics(filename, stat) 336 | # stat.print_info() 337 | -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/stat_value.py: -------------------------------------------------------------------------------- 1 | from util import * 2 | import glob 3 | import os 4 | import extractor 5 | import json 6 | 7 | class ValueStat: 8 | header = "***** type nentry time% num% null% txt% ntxt_type" 9 | fmt = "%-3s %-8s %-7s%-7s%-7s%-7s%-7s" 10 | def __init__(self, order): 11 | self.order = order 12 | self.nnum = 0 13 | self.ntime = 0 14 | self.nentry = 0 15 | self.ntxt = 0 16 | self.nnull = 0 17 | self.ntype_txt = set() 18 | 19 | def add(self, value): 20 | value = value.strip() 21 | self.nentry += 1 22 | if value == "": 23 | self.nnull += 1 24 | self.ntxt += 1 25 | elif is_time(value): 26 | self.ntime += 1 27 | elif is_number(value): 28 | self.nnum += 1 29 | else: 30 | self.ntxt += 1 31 | self.ntype_txt.add(value.lower()) 32 | 33 | def __str__(self): 34 | self.nentry += 0.0 35 | out = ["%-3d" %self.order] 36 | out.append("%-8d" %self.nentry) 37 | counts = [self.ntime, self.nnum, self.nnull, self.ntxt] 38 | for count in counts: 39 | out.append("%5d/%-.3f" %(count, round(count/self.nentry, 3))) 40 | out.append("%-7d" %len(self.ntype_txt)) 41 | return " ".join(out) 42 | 43 | # out = [self.order, self.nentry, round(self.ntime/self.nentry, 3), 44 | # round(self.nnum/self.nentry, 3), round(self.nnull/self.nentry,3), 45 | # round(self.ntxt/self.nentry, 3), len(self.ntype_txt)] 46 | # out = map(str, out) 
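# The live __str__ above emits one row per ValueStat.header: the value's
# position (order), its total entry count, a "count/ratio" pair each for
# time-, number-, null- and text-typed occurrences, and the count of
# distinct lowercased text strings seen for that slot.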
47 | # return ValueStat.fmt % tuple(out) 48 | 49 | 50 | class TypeValueStat: 51 | def __init__(self, rtype): 52 | self.nvalue = -1 53 | self.rtype = rtype 54 | self.value_stats = [] 55 | 56 | def add(self, values): 57 | if self.nvalue == -1: 58 | self.nvalue = len(values) 59 | for i in range(self.nvalue): 60 | self.value_stats.append(ValueStat(i)) 61 | assert self.nvalue == len(values) 62 | for i in range(self.nvalue): 63 | self.value_stats[i].add(values[i]) 64 | 65 | def to_string(self): 66 | ret = [] 67 | for value_stat in self.value_stats: 68 | ret.append("%15s" %self.rtype + "#" + str(value_stat)) 69 | return ret 70 | 71 | 72 | class FileValueStat: 73 | def __init__(self, filename): 74 | self.filename = filename 75 | self.type_stats = {} 76 | 77 | def add(self, rtype, values): 78 | if not rtype in self.type_stats: 79 | self.type_stats[rtype] = TypeValueStat(rtype) 80 | self.type_stats[rtype].add(values) 81 | 82 | 83 | 84 | def stat_value_types(filepath, outf, txt_outf = None): 85 | value_sep = "|&|" 86 | filename = ".".join(os.path.basename(filepath).split(".")[:-1]) 87 | print filename 88 | file_stat = FileValueStat(filename) 89 | for line in file(filepath): 90 | parts = extractor.parse_line(line) 91 | rtype = parts[2] 92 | values = parts[3].split(value_sep) 93 | file_stat.add(rtype, values) 94 | outf.write("***** %s start *****\n" %filename) 95 | for rtype in sorted(file_stat.type_stats.keys()): 96 | for out_str in file_stat.type_stats[rtype].to_string(): 97 | outf.write(out_str + "\n") 98 | outf.write("***** %s end *****\n" %filename) 99 | 100 | if txt_outf is not None: 101 | for rtype in sorted(file_stat.type_stats.keys()): 102 | type_stat = file_stat.type_stats[rtype] 103 | for value_stat in type_stat.value_stats: 104 | if len(value_stat.ntype_txt) > 0: 105 | name = type_stat.rtype + "#" + str(value_stat.order) 106 | text = json.dumps([text for text in value_stat.ntype_txt]) 107 | txt_outf.write(name + '\t' + str(len(value_stat.ntype_txt)) + '\t' + text + "\n") 108 | 109 | 110 | 111 | 112 | 113 | if __name__ == '__main__': 114 | outf = file(os.path.join(result_dir, "value_type_stat.tsv"), 'w') 115 | txt_outf = file(os.path.join(result_dir, "value_type_text.tsv"), 'w') 116 | outf.write(ValueStat.header+"\n") 117 | for filepath in glob.glob('data/' + "*.tsv"): 118 | stat_value_types(filepath, outf, txt_outf) 119 | outf.close() 120 | txt_outf.close() -------------------------------------------------------------------------------- /数据生成说明_v7/data_process/util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import datetime 4 | import time 5 | import re 6 | import numpy as np 7 | import commands 8 | try: 9 | from pg import DB 10 | except ImportError: 11 | sys.stderr.write('can\'t imprt module pg\n') 12 | import argparse 13 | 14 | 15 | def connect(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--host') 18 | parser.add_argument('--user') 19 | parser.add_argument('--passwd') 20 | parser.add_argument('--schema', default = 'mimiciii') 21 | args = parser.parse_args() 22 | 23 | # host = '162.105.146.246' 24 | # host = 'localhost' 25 | # schema = 'mimiciii' 26 | 27 | host = args.host 28 | user = args.user 29 | passwd = args.passwd 30 | Print('connect to %s, user = %s, search_path = %s' %(host, user, args.schema)) 31 | db = DB(host = host, user = user, passwd = passwd) 32 | db.query('set search_path to %s' %(args.schema)) 33 | return db 34 | 35 | class Patient(): 36 | 37 | bs_attrs = [] 38 | 39 | def 
__init__(self, row): 40 | self.values = {} 41 | self.names = [] 42 | for field in Patient.bs_attrs: 43 | self.values[field] = row[field] 44 | self.names.append(field) 45 | 46 | def to_row(self): 47 | ret = [] 48 | for name in self.names: 49 | ret.append(self.values[name]) 50 | return ret 51 | 52 | @staticmethod 53 | def set_attrs(columns): 54 | Patient.bs_attrs = [] 55 | for field in columns: 56 | Patient.bs_attrs.append(field) 57 | 58 | @staticmethod 59 | def write_to_local(patients, path): 60 | columns = None 61 | data = [] 62 | index = [] 63 | for pid, patient in patients.iteritems(): 64 | if columns is None: 65 | columns = patient.names 66 | data.append(patient.to_row()) 67 | index.append(pid) 68 | from pandas import DataFrame 69 | dt = DataFrame(data = data, index = index, columns = columns) 70 | dt.sort_index() 71 | dt.to_csv(path) 72 | 73 | def date2str(date): 74 | return date.strftime('%Y-%m-%d') 75 | 76 | def time2str(time): 77 | return time.strftime('%Y-%m-%d %H:%M:%S') 78 | 79 | def time_format_str(time): 80 | return '{0.year:4d}-{0.month:02d}-{0.day:02d} {0.hour:02d}:{0.minute:02d}:{0.second:02d}'.format(time) 81 | 82 | def parse_time(time_str): 83 | if len(time_str) in [18, 19]: 84 | try: 85 | return datetime.datetime.strptime(time_str, '%Y-%m-%d %H:%M:%S') 86 | except Exception, e: 87 | return None 88 | elif len(time_str) == 10: 89 | try: 90 | return datetime.datetime.strptime(time_str, '%Y-%m-%d') 91 | except Exception, e: 92 | return None 93 | elif len(time_str) in [12, 13, 14]: 94 | try: 95 | return datetime.datetime.strptime(time_str, '%m/%d/%y %H:%M') 96 | except Exception, e: 97 | return None 98 | elif len(time_str) == 16: 99 | try: 100 | return datetime.datetime.strptime(time_str, "%Y/%m/%d %H:%M") 101 | except Exception, e: 102 | return None 103 | return None 104 | 105 | def parse_number(number_str): 106 | try: 107 | return float(number_str) 108 | except Exception, e: 109 | return None 110 | 111 | def is_time(time_str): 112 | time = parse_time(time_str) 113 | return time is not None 114 | 115 | def is_number(number_str): 116 | number = parse_number(number_str) 117 | return number is not None 118 | 119 | def load_reg(filepath): 120 | regs = [] 121 | for line in file(filepath): 122 | line = line.strip() 123 | if line.startswith("#"): 124 | continue 125 | if line == "": 126 | continue 127 | regs.append(re.compile(line)) 128 | return regs 129 | 130 | def load_id2event_value(): 131 | ret = {} 132 | for line in file(os.path.join(result_dir, "event_des_text.tsv")): 133 | parts = line.strip("\n").split(" ") 134 | event_id = int(parts[0]) 135 | event_type = parts[1] 136 | value = " ".join(parts[2:]) 137 | ret[event_id] = event_type + '.' 
def load_reg(filepath):
    regs = []
    for line in file(filepath):
        line = line.strip()
        if line.startswith("#"):
            continue
        if line == "":
            continue
        regs.append(re.compile(line))
    return regs

def load_id2event_value():
    ret = {}
    for line in file(os.path.join(result_dir, "event_des_text.tsv")):
        parts = line.strip("\n").split(" ")
        event_id = int(parts[0])
        event_type = parts[1]
        value = " ".join(parts[2:])
        ret[event_id] = event_type + '.' + value
    return ret

def load_id2event_rtype():
    ret = {}
    for line in file(os.path.join(result_dir, "event_des_text.tsv")):
        parts = line.strip("\n").split(" ")
        event_id = int(parts[0])
        event_type = parts[1]
        value = " ".join(parts[2:])
        ret[event_id] = event_type
    return ret


# Merge per-segment probabilities that share a sample id, combining duplicates
# with func (e.g. max or mean); returns merged probabilities in sorted-id order.
def merge_prob(probs, ids, func):
    prob_map = {}
    assert len(probs) == len(ids)
    for i in range(len(probs)):
        prob = probs[i]
        sid = ids[i]
        if sid not in prob_map:
            prob_map[sid] = prob
        else:
            prob_map[sid] = func(prob_map[sid], prob)
    probs = []
    for sid in sorted(prob_map.keys()):
        probs.append(prob_map[sid])
    return np.array(probs)

def merge_label(labels, ids):
    label_map = {}
    for i in range(len(labels)):
        label = labels[i]
        sid = ids[i]
        # first label wins for duplicate ids
        if sid not in label_map:
            label_map[sid] = label
    return np.array([label_map[sid] for sid in sorted(label_map.keys())])

# Normalize along the last axis so each row sums to 1; all-zero rows are left as zeros.
def norm_to_prob(X):
    y = np.expand_dims(X.sum(-1), -1)
    y[y == 0] = 1
    return X / y

def load_numpy_array(filepath):
    return np.load(filepath)

def now():
    return datetime.datetime.now().strftime('%m-%d %H:%M:%S')

# Map old event ids to new ids so that all events of one record type share an index.
def merge_event_map(filepath):
    print "load event des from [%s]" %filepath
    new_idx_cnt = 2
    new_events_idx = {}
    old2new = {0: 0, 1: 1}
    for line in file(filepath):
        line = line.strip()
        if line == "":
            continue
        parts = line.split(" ")
        old_idx = int(parts[0])
        rtype = parts[1]
        if rtype not in new_events_idx:
            new_events_idx[rtype] = new_idx_cnt
            new_idx_cnt += 1
        old2new[old_idx] = new_events_idx[rtype]
    return old2new

def load_items(filepath):
    items = {}
    for line in file(filepath):
        line = line.strip()
        if line == "":
            continue
        p = line.split('\t')
        code = int(p[0])
        if len(p) == 1:
            des = ""
        else:
            des = p[1]
        items[code] = des
    return items

# Parse '|'-separated key=value settings from a file, or from the argument itself
# when it starts with '@'; values are coerced to bool/int/float where possible.
def load_setting(filepath, default_setting):
    setting = default_setting if default_setting else {}

    if filepath.startswith("@"):
        lines = filepath[1:].split("|")
    else:
        lines = file(filepath).readlines()
    for line in lines:
        line = line.rstrip()
        if line == "":
            continue
        if line.startswith("#"):
            continue
        parts = line.split("|")
        for key_value in parts:
            x = key_value.strip().split("=")
            if len(x) >= 2:
                key = x[0]
                if x[1] == "True":
                    value = True
                elif x[1] == "False":
                    value = False
                elif is_number(x[1]):
                    if x[1].isdigit():
                        value = int(x[1])
                    else:
                        value = float(x[1])
                else:
                    value = "=".join(x[1:])
                setting[key] = value
                print "load arg %s = %s" %(key, value)
    return setting

def get_nb_lines(filepath):
    output = commands.getoutput('wc -l %s' %filepath)
    # split() (not split(" ")) tolerates the leading padding some wc builds emit
    p = int(output.split()[0])
    return p

def get_nb_files(pattern):
    output = commands.getoutput("ls %s|wc -l" %pattern)
    return int(output)

# Timestamped print helper used throughout the data_process scripts.
def Print(*l):
    l = map(str, l)
    print now() + "\t" + " ".join(l)

def add_to_cnt_dict(d, key):
    if key not in d:
        d[key] = 0
    d[key] += 1

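# Sketch of the two ways load_setting() above can be fed (values are made up):
#
#   s = load_setting('@lr=0.001|use_gpu=True|model=HELSTM', {})
#   # -> {'lr': 0.001, 'use_gpu': True, 'model': 'HELSTM'}
#
# or, equivalently, from a settings file with one key=value pair per line,
# where blank lines and lines starting with '#' are skipped.
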
script_dir = os.path.dirname(os.path.realpath(__file__))
data_dir = os.path.join(script_dir, 'data')
stat_dir = os.path.join(script_dir, 'stat')
result_dir = os.path.join(script_dir, 'result')
static_data_dir = os.path.join(script_dir, "static_data")
config_dir = os.path.join(script_dir, 'config')
event_dir = os.path.join(script_dir, 'event')
event_stat_dir = os.path.join(script_dir, "event_stat")
# exper_dir = os.path.join(script_dir, "exper")
death_exper_dir = os.path.join(script_dir, 'death_exper')
death_seg_dir = os.path.join(death_exper_dir, 'segs')
death_merged_exper_dir = os.path.join(script_dir, 'death_merged_exper')
ICU_exper_dir = os.path.join(script_dir, "ICU_exper")
ICU_merged_exper_dir = os.path.join(script_dir, "ICU_merged_exper")
ICU_seg_dir = os.path.join(ICU_exper_dir, 'segs')
ICU_emd_dir = os.path.join(ICU_exper_dir, 'embeddings')
lab_exper_dir = os.path.join(script_dir, 'lab_exper')
event_seq_stat_dir = os.path.join(script_dir, "event_seq_stat")
graph_dir = os.path.join(script_dir, 'graph')
time_dis_graph_dir = os.path.join(graph_dir, "time_dis")

if __name__ == "__main__":
    # quick smoke test of parse_time and datetime arithmetic
    st = parse_time("2101-10-20 19:08:00")
    dob = parse_time("2025-04-11 00:00:00")
    print st - dob
--------------------------------------------------------------------------------
/数据生成说明_v7/数据生成说明.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhshen95/HELSTM/4a4c99797d6d717e96dcec805d9748bd641a242d/数据生成说明_v7/数据生成说明.pdf
--------------------------------------------------------------------------------
/数据集说明.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jhshen95/HELSTM/4a4c99797d6d717e96dcec805d9748bd641a242d/数据集说明.docx
--------------------------------------------------------------------------------