├── src ├── __init__.py ├── aux_data_generation.py ├── data_utils.py ├── metrics.py ├── write_tf.py ├── ML_Net.py ├── ML_Net_components.py ├── ML_Net_label_prediction_train.py └── ML_Net_label_count_prediction_train.py └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/aux_data_generation.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate auxiliary data for model training and evaluation, including: 3 | ancestors.pk: ancestors array of ICD codes 4 | """ 5 | 6 | import data_utils as data_utils 7 | import numpy as np 8 | 9 | output_data_dir = "/" 10 | 11 | parent_to_child_file = "../aux_data/MIMIC_parentTochild" 12 | 13 | I = 7042 14 | root = 5367 15 | # Load up the tree 16 | PtoC = [[] for i in range(I)] 17 | 18 | f = open(parent_to_child_file) 19 | for line in f: 20 | line = [int(x) for x in line.strip().split('|')] 21 | PtoC[line[0]].append(line[1]) 22 | f.close() 23 | 24 | # Create ancestors array 25 | ancestors = [[] for i in range(I)] 26 | children = PtoC[root] 27 | for child in children: 28 | ancestors[child].append(root) 29 | while len(children) > 0: 30 | new_children = [] 31 | for child in children: 32 | for gc in PtoC[child]: 33 | ancestors[gc].extend([child] + ancestors[child]) 34 | new_children.append(gc) 35 | children = new_children 36 | for i in range(len(ancestors)): 37 | ancestors[i] = np.array(ancestors[i] + [i]) 38 | 39 | data_utils.save_obj(ancestors, output_data_dir+"ancestors.pk") -------------------------------------------------------------------------------- /src/data_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | utility functions of ML_Net 3 | """ 4 | 5 | import tensorflow as tf 6 | import pickle 7 | 8 | def _lm2lp(label_map, MAX_PAIRS): 9 | pos = tf.reshape(tf.where(label_map), [-1]) 10 | neg = tf.reshape(tf.where(tf.logical_not(label_map)), [-1]) 11 | 12 | neg_pos = tf.meshgrid(neg, pos, indexing='ij') 13 | neg_pos_mat = tf.reshape(tf.transpose(tf.stack(neg_pos)), [-1, 2]) # all the pairs for neg and pos by their index 14 | neg_pos_rand = tf.random_shuffle(neg_pos_mat) 15 | neg_pos_pad = tf.pad(neg_pos_rand, [[0, MAX_PAIRS], [0, 0]]) 16 | neg_pos_res = tf.slice(neg_pos_pad, [0, 0], [MAX_PAIRS, -1]) 17 | 18 | # MAX_PAIRS x 2 19 | return neg_pos_res 20 | 21 | 22 | def batch_iter_eval(zip_text_mesh, batch_size): 23 | """ 24 | Generates a batch iterator for an evaluation dataset.
25 | """ 26 | text_list, mesh_list = zip(*zip_text_mesh) 27 | text_list = list(text_list) 28 | mesh_list = list(mesh_list) 29 | data_length = len(text_list) 30 | num_batches_per_epoch = int((data_length - 1) / batch_size) + 1 31 | 32 | for batch_num in range(num_batches_per_epoch): 33 | start_index = batch_num * batch_size 34 | if (batch_num + 1) * batch_size >= data_length: 35 | yield text_list[data_length - batch_size:], mesh_list[data_length - batch_size:] 36 | else: 37 | end_index = (batch_num + 1) * batch_size 38 | yield text_list[start_index:end_index], mesh_list[start_index:end_index] 39 | 40 | 41 | def save_obj(obj, file_address): 42 | with open(file_address, 'wb') as f: 43 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 44 | 45 | 46 | def load_obj(file_address): 47 | with open(file_address, 'rb') as f: 48 | return pickle.load(f) 49 | 50 | 51 | def __parse_example_proto_with_elmo(example_serialized, NUM_CLASS, MAX_PAIRS): 52 | feature_map = { 53 | 'raw_text': tf.FixedLenFeature([], tf.string), 54 | 'labels': tf.VarLenFeature(dtype=tf.int64) 55 | } 56 | features = tf.parse_single_example(example_serialized, features=feature_map) 57 | 58 | raw_text = features['raw_text'] 59 | 60 | labels = features['labels'] 61 | label_map = tf.sparse_to_indicator(labels, NUM_CLASS) 62 | label_pairs = _lm2lp(label_map, MAX_PAIRS) 63 | label_map = tf.cast(label_map, tf.float32) 64 | label_map.set_shape([NUM_CLASS]) 65 | 66 | label_dict = {} 67 | label_dict['label_pair'] = label_pairs 68 | label_dict['label_map'] = label_map 69 | label_dict['raw_text'] = raw_text 70 | return label_dict 71 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | evaluation metrics, with refrence to scripts from: 3 | https://physionet.org/works/ICD9CodingofDischargeSummaries 4 | """ 5 | 6 | import numpy as np 7 | import data_utils as data_utils 8 | 9 | ancestors = data_utils.load_obj("tf_data/ancestors.pk") 10 | root = 5367 11 | 12 | def only_leaves(codes): 13 | leaves = [] 14 | # codes = np.array(codes) 15 | all_ancestors = set() 16 | for code in codes: 17 | all_ancestors.update(list(ancestors[code][:-1])) 18 | for code in codes: 19 | if code not in all_ancestors: 20 | leaves.append(code) 21 | return leaves 22 | 23 | def get_p_r_f_jamia(logits, counts, labels): 24 | ranks = np.argsort(-logits, axis=1) 25 | FNs = [] 26 | FPs = [] 27 | TPs = [] 28 | 29 | predict_all_list = list() 30 | gold_all_list = list() 31 | 32 | for rank, L, k in zip(ranks, labels, counts): 33 | 34 | predict_id_list = list(rank[:k + 1]) 35 | predict_id_list.append(root) 36 | predict_id_set = set(predict_id_list) 37 | 38 | # Prune the predictions to respect the conditional classification 39 | # constraint (all ancestors must be predicted true for a child to be 40 | # predicted true) 41 | 42 | filtered_predictions = [] 43 | for prediction in predict_id_set: 44 | if prediction == root or np.all([anc in predict_id_set for anc in ancestors[prediction]]): 45 | filtered_predictions.append(prediction) 46 | predict_id_set = set(filtered_predictions) 47 | 48 | full_predict_id_set = set() 49 | 50 | for predict in predict_id_set: 51 | full_predict_id_set.update(list(ancestors[predict])) 52 | 53 | pred_only_leave = set(only_leaves(full_predict_id_set)) 54 | 55 | full_gold_set = set() 56 | for gs in L: 57 | full_gold_set.update(list(ancestors[gs])) 58 | 59 | gold_set = set(only_leaves(full_gold_set)) 60 | 61 | TP = 0 62 | FP = 0 63 
| FN = 0 64 | for code in gold_set: 65 | if len(set(list(ancestors[code])) - set(full_predict_id_set)) > 0: 66 | FN += 1 67 | ## else: 68 | ## TP += 1 69 | for code in pred_only_leave: 70 | anc_set = set(list(ancestors[code])) 71 | if len(anc_set - set(full_gold_set)) > 0 and not np.any([x in anc_set for x in gold_set]): 72 | FP += 1 73 | else: 74 | TP += 1 75 | FNs.append(FN) 76 | FPs.append(FP) 77 | TPs.append(TP) 78 | 79 | predict_all_list.append(set(predict_id_set)) 80 | gold_all_list.append(set(L)) 81 | 82 | 83 | FNs = np.array(FNs, np.float) 84 | FPs = np.array(FPs, np.float) 85 | TPs = np.array(TPs, np.float) 86 | 87 | mean_prc = np.nanmean(np.where(TPs + FPs > 0, TPs / (TPs + FPs), 0)) 88 | mean_rec = np.nanmean(np.where(TPs + FNs > 0, TPs / (TPs + FNs), 0)) 89 | f_score = 2 * (mean_prc * mean_rec) / (mean_prc + mean_rec) 90 | 91 | return mean_prc, mean_rec, f_score, predict_all_list, gold_all_list -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ML-NET 2 | Reference code repository for the paper "ML-Net: multi-label classification of biomedical texts with deep neural networks". This repository demonstrates how to train and test 3 | ML-Net on task 3 (diagnosis code assignment) in the paper. 4 | 5 | ML-Net is a novel end-to-end deep learning framework for multi-label classification of biomedical texts. 6 | ML-Net combines a label prediction network with a label count prediction network, 7 | which can determine the output labels based on both label confidence scores 8 | and document context in an end-to-end manner. 9 | 10 | ## Key features 11 | * Good performance on small- to large-scale multi-label biomedical text classification tasks 12 | * Determines the output labels based on both label confidence scores and document context in an end-to-end manner 13 | * Uses [Threading and Queues](https://www.tensorflow.org/api_guides/python/threading_and_queues) in TensorFlow to facilitate faster training 14 | 15 | 16 | ## Requirements 17 | ML-Net requires Python 3.6 and TensorFlow 1.8+. 18 | 19 | ## Scripts overview 20 | 21 | ``` 22 | aux_data_generation.py #generate auxiliary data for model training and testing 23 | write_tf.py #generate training and test data (in both TFRecords and pickle format) 24 | ML_Net.py #the model of ML_Net 25 | ML_Net_components.py #the components of ML_Net 26 | metrics.py #the evaluation metrics 27 | data_utils.py #utility functions of ML_Net 28 | ML_Net_label_prediction_train.py #the training of the label prediction network 29 | ML_Net_label_count_prediction_train.py #the training of the label count prediction network 30 | ``` 31 | 32 | ## Data preparation 33 | ### Download dataset 34 | The dataset is available at: https://archive.physionet.org/works/ICD9CodingofDischargeSummaries/. 35 | Please note that you have to acquire access to the MIMIC II Clinical Database project first.
Please cite the 36 | following paper when using the dataset: 37 | ``` 38 | @article{perotte2013diagnosis, 39 | title={Diagnosis code assignment: models and evaluation metrics}, 40 | author={Perotte, Adler and Pivovarov, Rimma and Natarajan, Karthik and Weiskopf, Nicole and Wood, Frank and Elhadad, No{\'e}mie}, 41 | journal={Journal of the American Medical Informatics Association}, 42 | volume={21}, 43 | number={2}, 44 | pages={231--237}, 45 | year={2013}, 46 | publisher={BMJ Publishing Group} 47 | } 48 | ``` 49 | 50 | ### Clean and construct dataset 51 | Please follow the readme file in the dataset folder and run construct_datasets.py. (Note: please run the script using Python 2.) 52 | The following files will be used in the training and evaluation of ML-Net: 53 | ``` 54 | MIMIC_FILTERED_DSUMS #raw text of discharge summaries 55 | testing_codes.data #the labels (after augmentation) of the test set 56 | training_codes.data #the labels (after augmentation) of the training set 57 | ``` 58 | 59 | ### Generate auxiliary data 60 | The training and testing of ML-Net rely on auxiliary data. Please 61 | run aux_data_generation.py to generate the auxiliary data. 62 | 63 | ### Generate training and test data 64 | To generate the training and test datasets in TFRecords and pickle format, please run write_tf.py. 65 | 66 | ## Training 67 | There are two training steps. We first train the label prediction network. 68 | During training, the label prediction network as well as the hierarchical attention network are updated through backpropagation. 69 | Then, we train the label count prediction network. Unlike the training of the label prediction network, 70 | only the MLP part is updated, as the gradient is stopped at the document vector layer (a minimal sketch of how the two networks are combined at inference time is given at the end of this document). 71 | 72 | * First, run "ML_Net_label_prediction_train.py" 73 | * Then, run "ML_Net_label_count_prediction_train.py" 74 | 75 | ## Contact 76 | 77 | Please contact Jingcheng Du (Jingcheng.du@uth.tmc.edu) if you have any questions. 78 | 79 | ## Cite 80 | Please cite the following [article](https://academic.oup.com/jamia/advance-article/doi/10.1093/jamia/ocz085/5522430) if the code is useful for your project.
81 | ``` 82 | @article{10.1093/jamia/ocz085, 83 | author = {Du, Jingcheng and Chen, Qingyu and Peng, Yifan and Xiang, Yang and Tao, Cui and Lu, Zhiyong}, 84 | title = "{ML-Net: multi-label classification of biomedical texts with deep neural networks}", 85 | journal = {Journal of the American Medical Informatics Association}, 86 | year = {2019}, 87 | month = {06}, 88 | issn = {1527-974X}, 89 | doi = {10.1093/jamia/ocz085}, 90 | url = {https://doi.org/10.1093/jamia/ocz085}, 91 | eprint = {http://oup.prod.sis.lan/jamia/advance-article-pdf/doi/10.1093/jamia/ocz085/28858839/ocz085.pdf}, 92 | } 93 | ``` 94 | 95 | -------------------------------------------------------------------------------- /src/write_tf.py: -------------------------------------------------------------------------------- 1 | """ 2 | write tfrecords and pickle fies for model training and test 3 | """ 4 | 5 | import data_utils as data_utils 6 | import tensorflow as tf 7 | import string 8 | import re 9 | 10 | term_pattern = re.compile('[A-Za-z]+') 11 | 12 | stopwords = set(list(string.punctuation)) 13 | 14 | def concatenateToken(tokens): 15 | sent = "" 16 | for token in tokens: 17 | sent = sent + " " + token 18 | return sent 19 | 20 | def clean_notes(text): 21 | lines = text.split("[NEWLINE]") 22 | clean_lines = list() 23 | for line in lines[1:]: 24 | raw_dsum = re.sub(r'\[[^\]]+\]', ' ', line) 25 | raw_dsum = re.sub(r'admission date:', ' ', raw_dsum, flags=re.I) 26 | raw_dsum = re.sub(r'discharge date:', ' ', raw_dsum, flags=re.I) 27 | raw_dsum = re.sub(r'date of birth:', ' ', raw_dsum, flags=re.I) 28 | raw_dsum = re.sub(r'sex:', ' ', raw_dsum, flags=re.I) 29 | raw_dsum = re.sub(r'service:', ' ', raw_dsum, flags=re.I) 30 | raw_dsum = re.sub(r'dictated by:.*$', ' ', raw_dsum, flags=re.I) 31 | raw_dsum = re.sub(r'completed by:.*$', ' ', raw_dsum, flags=re.I) 32 | raw_dsum = re.sub(r'signed electronically by:.*$', ' ', raw_dsum, flags=re.I) 33 | tokens = [token.lower() for token in re.findall(term_pattern, raw_dsum)] 34 | tokens = [token for token in tokens if token not in stopwords and len(token) > 1] 35 | if len(tokens) == 0: 36 | continue 37 | clean_lines.append(tokens) 38 | 39 | return clean_lines 40 | 41 | 42 | def _bytes_feature(value): 43 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 44 | 45 | 46 | def _int64_feature(value): 47 | if not isinstance(value, list): 48 | value = [value] 49 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 50 | 51 | 52 | MAX_ABS_LENGTH = 1500 53 | 54 | icd_data_dir = "/" 55 | output_data_dir = "/" 56 | 57 | test_tf = output_data_dir + 'icd_test.tfrecords' # address to save the TFRecords file 58 | train_tf = output_data_dir + '0_icd_train.tfrecords' # address to save the TFRecords file 59 | 60 | test_writer = tf.python_io.TFRecordWriter(test_tf) 61 | train_writer = tf.python_io.TFRecordWriter(train_tf) 62 | 63 | raw_text_file = open(icd_data_dir + "/MIMIC_FILTERED_DSUMS", "r").readlines() 64 | train_label_file = open(icd_data_dir + "/training_codes.data", "r").readlines() 65 | test_label_file = open(icd_data_dir + "/testing_codes.data", "r").readlines() 66 | 67 | train_len = len(train_label_file) 68 | test_len = len(test_label_file) 69 | 70 | label_list = list() 71 | abs_dict_list = list() 72 | 73 | for line_num, line in enumerate(raw_text_file[:train_len]): 74 | if line_num != 0 and line_num % 1280 == 0: 75 | train_writer.close() 76 | train_tf = output_data_dir + str(line_num) + "_icd_train.tfrecords" 77 | train_writer = 
tf.python_io.TFRecordWriter(train_tf) 78 | print('\r', line_num, end='', flush=True) 79 | 80 | abs_dict = dict() 81 | sents = clean_notes(line) 82 | 83 | raw_text = [] 84 | 85 | labels = train_label_file[line_num].split("|")[1:] 86 | label_set = set(labels) 87 | for sent in sents: 88 | raw_text.extend(sent) 89 | 90 | raw_text = concatenateToken(raw_text[:MAX_ABS_LENGTH]).strip() 91 | 92 | labels_list = [int(x) for x in label_set.copy()] 93 | feature = {'labels': _int64_feature(labels_list.copy()), 94 | 'raw_text': _bytes_feature(tf.compat.as_bytes(raw_text)) 95 | } 96 | 97 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 98 | abs_dict["raw_text"] = raw_text 99 | abs_dict_list.append(abs_dict) 100 | label_list.append(labels_list) 101 | train_writer.write(example.SerializeToString()) 102 | label_set.clear() 103 | data_utils.save_obj(zip(abs_dict_list, label_list), output_data_dir + "train_abs_label_zip.pickle") 104 | train_writer.close() 105 | 106 | label_list.clear() 107 | abs_dict_list.clear() 108 | 109 | for line_num, line in enumerate(raw_text_file[train_len:]): 110 | if line_num % 10 == 0: 111 | print('\r', "test: ", line_num, end='', flush=True) 112 | abs_dict = dict() 113 | sents = clean_notes(line) 114 | raw_text = [] 115 | 116 | labels = test_label_file[line_num].split("|")[1:] 117 | label_set = set(labels) 118 | for sent in sents: 119 | raw_text.extend(sent) 120 | 121 | raw_text = concatenateToken(raw_text[:MAX_ABS_LENGTH]).strip() 122 | 123 | labels_list = [int(x) for x in label_set.copy()] 124 | feature = {'labels': _int64_feature(labels_list.copy()), 125 | 'raw_text': _bytes_feature(tf.compat.as_bytes(raw_text)) 126 | } 127 | 128 | example = tf.train.Example(features=tf.train.Features(feature=feature)) 129 | abs_dict["raw_text"] = raw_text 130 | abs_dict_list.append(abs_dict) 131 | label_list.append(labels_list) 132 | test_writer.write(example.SerializeToString()) 133 | label_set.clear() 134 | data_utils.save_obj(zip(abs_dict_list, label_list), output_data_dir + "test_abs_label_zip.pickle") 135 | test_writer.close() 136 | 137 | label_list.clear() 138 | abs_dict_list.clear() 139 | -------------------------------------------------------------------------------- /src/ML_Net.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import ML_Net_components as ML_component 3 | from tensorflow.contrib import rnn 4 | import tensorflow_hub as hub 5 | 6 | class ML_Net(object): 7 | 8 | def __init__(self, NUM_CLASSES, hidden_size, MAX_LABELS_PERMITTED, batch_size, MAX_PAIRS): 9 | """init all hyperparameter here""" 10 | # set hyperparamter 11 | 12 | self.NUM_CLASSES = NUM_CLASSES 13 | self.MAX_PAIRS = MAX_PAIRS 14 | self.MAX_LABELS_PERMITTED = MAX_LABELS_PERMITTED 15 | self.batch_size = batch_size 16 | self.input_x = tf.placeholder(tf.string, shape=(None), name='inputs') 17 | self.input_y_label_pairs = tf.placeholder(tf.int32, [None, self.MAX_PAIRS, 2], 18 | name="input_y_label_pairs") # convert to 19 | self.input_y_label_map = tf.placeholder(tf.float32, [None, self.NUM_CLASSES], 20 | name="input_y_label_map") # convert to 21 | 22 | self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob") 23 | 24 | with tf.variable_scope("encoder"): 25 | module_url = "https://tfhub.dev/google/elmo/2" 26 | elmo_embed = hub.Module(module_url, trainable=True) 27 | embeddings = elmo_embed(self.input_x, signature="default", as_dict=True)["elmo"] 28 | 29 | with tf.name_scope("bidirectional_LSTM"): 30 | 
lstm_fw_cell = rnn.BasicLSTMCell(hidden_size) # forward direction cell 31 | lstm_bw_cell = rnn.BasicLSTMCell(hidden_size) # backward direction cell 32 | if self.dropout_keep_prob is not None: 33 | lstm_fw_cell = rnn.DropoutWrapper(lstm_fw_cell, output_keep_prob=self.dropout_keep_prob) 34 | lstm_bw_cell = rnn.DropoutWrapper(lstm_bw_cell, output_keep_prob=self.dropout_keep_prob) 35 | rnn_outputs, rnn_states = tf.nn.bidirectional_dynamic_rnn(lstm_fw_cell, lstm_bw_cell, embeddings, 36 | dtype=tf.float32) 37 | 38 | self.output_rnn_bi = tf.concat(rnn_outputs, axis=2) # [batch_size,sequence_length,hidden_size*2] 39 | 40 | with tf.name_scope("attention_layer"): 41 | 42 | hidden_size = self.output_rnn_bi.shape[2].value # D value - hidden size of the RNN layer 43 | 44 | # Trainable parameters 45 | w_omega = tf.Variable(tf.random_normal([hidden_size, hidden_size], stddev=0.1)) 46 | b_omega = tf.Variable(tf.random_normal([hidden_size], stddev=0.1)) 47 | u_omega = tf.Variable(tf.random_normal([hidden_size], stddev=0.1)) 48 | 49 | with tf.name_scope('v'): 50 | # Applying fully connected layer with non-linear activation to each of the B*T timestamps; 51 | # the shape of `v` is (B,T,D)*(D,A)=(B,T,A), where A=attention_size 52 | v = tf.tanh(tf.tensordot(self.output_rnn_bi, w_omega, axes=1) + b_omega) 53 | 54 | # For each of the timestamps its vector of size A from `v` is reduced with `u` vector 55 | vu = tf.tensordot(v, u_omega, axes=1, name='vu') # (B,T) shape 56 | alphas = tf.nn.softmax(vu, name='alphas') # (B,T) shape 57 | 58 | # Output of (Bi-)RNN is reduced with attention vector; the result has (B,D) shape 59 | self.sentence_level_output = tf.reduce_sum(self.output_rnn_bi * tf.expand_dims(alphas, -1), 1) 60 | 61 | with tf.name_scope("output"): 62 | sentence_level_output_norm = tf.contrib.layers.batch_norm(self.sentence_level_output, center=True, scale=True, is_training=True, epsilon=1e-12) 63 | self.logits = tf.contrib.layers.fully_connected(sentence_level_output_norm, self.NUM_CLASSES,activation_fn=tf.nn.relu) # [batch_size,num_classes] 64 | # self.final_output = tf.nn.sigmoid(self.logits, name = "attention_final_output") 65 | 66 | with tf.name_scope("loss"): 67 | self.loss = ML_component.mll_exp(self.logits, self.input_y_label_pairs, self.batch_size,NUM_CLASSES) 68 | ##L2 normalization 69 | 70 | with tf.name_scope("count_loss"): 71 | num_bins = self.MAX_LABELS_PERMITTED 72 | sentence_level_output = tf.stop_gradient(sentence_level_output_norm) 73 | cnt_h1 = tf.contrib.layers.fully_connected(sentence_level_output, num_outputs = NUM_CLASSES, activation_fn=tf.nn.relu) 74 | # cnt_h1 = tf.contrib.layers.batch_norm(cnt_h1,center=True, scale=True, is_training=True, epsilon=1e-12) 75 | cnt_h2 = tf.contrib.layers.fully_connected(cnt_h1, num_outputs = NUM_CLASSES, activation_fn=tf.nn.relu) 76 | # cnt_h2 = tf.contrib.layers.batch_norm(cnt_h2, center=True, scale=True, is_training=True, epsilon=1e-12) 77 | cnt_h3 = tf.contrib.layers.fully_connected(cnt_h2, num_outputs=128, activation_fn=tf.nn.relu) 78 | # cnt_h3 = tf.contrib.layers.batch_norm(cnt_h3, center=True, scale=True, is_training=True, epsilon=1e-12) 79 | self.lcnt = tf.contrib.layers.fully_connected(cnt_h3,num_bins, activation_fn=None) 80 | self.predictions_count = tf.argmax(self.lcnt, axis=1, name="predictions_count") 81 | label_count = tf.reduce_sum(self.input_y_label_map, 1) 82 | tails = num_bins * tf.ones_like(label_count) 83 | bins = tf.where(label_count > num_bins, tails, label_count) 84 | labels = bins - 1 85 | labels = tf.cast(labels, tf.int64) 86 
| 87 | xent = tf.nn.sparse_softmax_cross_entropy_with_logits( 88 | logits=self.lcnt, labels=labels) 89 | self.lcnt_loss = tf.reduce_mean(xent, name= 'count_loss') 90 | 91 | 92 | # The training scripts import this model under the name ML_elmo ("from ML_Net import ML_elmo"); 93 | # keep an alias so that import resolves to the ML_Net class defined above. 94 | ML_elmo = ML_Net -------------------------------------------------------------------------------- /src/ML_Net_components.py: -------------------------------------------------------------------------------- 1 | """ 2 | contains the components of ML-Net, including the loss functions 3 | This script contains code from: 4 | https://bitbucket.org/raingo-ur/mll-tf/src/master/ 5 | https://github.com/ematvey/hierarchical-attention-networks/ 6 | """ 7 | 8 | import tensorflow as tf 9 | import tensorflow.contrib.layers as layers 10 | import numpy as np 11 | 12 | try: 13 | from tensorflow.contrib.rnn import LSTMStateTuple 14 | except ImportError: 15 | LSTMStateTuple = tf.nn.rnn_cell.LSTMStateTuple 16 | 17 | 18 | def bidirectional_rnn(cell_fw, cell_bw, inputs_embedded, input_lengths, 19 | scope=None): 20 | """Bidirectional RNN with concatenated outputs and states""" 21 | with tf.variable_scope(scope or "birnn") as scope: 22 | ((fw_outputs, 23 | bw_outputs), 24 | (fw_state, 25 | bw_state)) = ( 26 | tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fw, 27 | cell_bw=cell_bw, 28 | inputs=inputs_embedded, 29 | sequence_length=input_lengths, 30 | dtype=tf.float32, 31 | swap_memory=True, 32 | scope=scope)) 33 | outputs = tf.concat((fw_outputs, bw_outputs), 2) 34 | 35 | def concatenate_state(fw_state, bw_state): 36 | if isinstance(fw_state, LSTMStateTuple): 37 | state_c = tf.concat( 38 | (fw_state.c, bw_state.c), 1, name='bidirectional_concat_c') 39 | state_h = tf.concat( 40 | (fw_state.h, bw_state.h), 1, name='bidirectional_concat_h') 41 | state = LSTMStateTuple(c=state_c, h=state_h) 42 | return state 43 | elif isinstance(fw_state, tf.Tensor): 44 | state = tf.concat((fw_state, bw_state), 1, 45 | name='bidirectional_concat') 46 | return state 47 | elif (isinstance(fw_state, tuple) and 48 | isinstance(bw_state, tuple) and 49 | len(fw_state) == len(bw_state)): 50 | # multilayer 51 | state = tuple(concatenate_state(fw, bw) 52 | for fw, bw in zip(fw_state, bw_state)) 53 | return state 54 | 55 | else: 56 | raise ValueError( 57 | 'unknown state type: {}'.format((fw_state, bw_state))) 58 | 59 | state = concatenate_state(fw_state, bw_state) 60 | return outputs, state 61 | 62 | 63 | def task_specific_attention(inputs, output_size, 64 | initializer=layers.xavier_initializer(), 65 | activation_fn=tf.tanh, scope=None): 66 | """ 67 | Performs task-specific attention reduction, using learned 68 | attention context vector (constant within task of interest). 69 | Args: 70 | inputs: Tensor of shape [batch_size, units, input_size] 71 | `input_size` must be static (known) 72 | `units` axis will be attended over (reduced from output) 73 | `batch_size` will be preserved 74 | output_size: Size of output's inner (feature) dimension 75 | Returns: 76 | outputs: Tensor of shape [batch_size, output_dim].
77 | """ 78 | assert len(inputs.get_shape()) == 3 and inputs.get_shape()[-1].value is not None 79 | 80 | with tf.variable_scope(scope or 'attention') as scope: 81 | attention_context_vector = tf.get_variable(name='attention_context_vector', 82 | shape=[output_size], 83 | initializer=initializer, 84 | dtype=tf.float32) 85 | input_projection = layers.fully_connected(inputs, output_size, 86 | activation_fn=activation_fn, 87 | scope=scope) 88 | 89 | vector_attn = tf.reduce_sum(tf.multiply(input_projection, attention_context_vector), axis=2, keepdims=True) 90 | attention_weights = tf.nn.softmax(vector_attn, axis=1) 91 | weighted_projection = tf.multiply(input_projection, attention_weights) 92 | 93 | outputs = tf.reduce_sum(weighted_projection, axis=1) 94 | 95 | return outputs 96 | 97 | 98 | """ 99 | definition of the loss functions 100 | """ 101 | 102 | 103 | def _batch_gather(input, indices, batch_size): 104 | """ 105 | output[i, ..., j] = input[i, indices[i, ..., j]] 106 | """ 107 | shape_output = indices.get_shape().as_list() 108 | 109 | shape_input = input.get_shape().as_list() 110 | shape_input[0] = batch_size 111 | shape_output[0] = batch_size 112 | 113 | assert len(shape_input) == 2 114 | batch_base = shape_input[1] * np.arange(shape_input[0]) 115 | batch_base_shape = [1] * len(shape_output) 116 | batch_base_shape[0] = shape_input[0] 117 | 118 | batch_base = batch_base.reshape(batch_base_shape) 119 | indices = batch_base + indices 120 | 121 | input = tf.reshape(input, [-1]) 122 | return tf.gather(input, indices) 123 | 124 | 125 | def _pairwise(label_pairs, logits, batch_size, NUM_CLASSES): 126 | mapped = _batch_gather(logits, label_pairs, batch_size) 127 | neg, pos = tf.split(mapped, 2, 2) 128 | delta = neg - pos 129 | 130 | neg_idx, pos_idx = tf.split(label_pairs, 2, 2) 131 | _, indices = tf.nn.top_k(tf.stop_gradient(logits), NUM_CLASSES) 132 | _, ranks = tf.nn.top_k(-indices, NUM_CLASSES) 133 | 134 | delta_nnz = tf.cast(tf.not_equal(neg_idx, pos_idx), tf.float32) 135 | return delta, delta_nnz 136 | 137 | 138 | def mll_exp(logits, label_pairs, batch_size, NUM_CLASSES): 139 | # compute label pairs 140 | # # batch_size x num_pairs x 2 141 | # print(logits.get_shape(), "logit shape") 142 | # print(label_pairs.get_shape(), "label_pairs shape") 143 | delta, delta_nnz = _pairwise(label_pairs, logits, batch_size, NUM_CLASSES) 144 | 145 | delta_max = tf.reduce_max(delta, 1, keepdims=True) 146 | delta_max_nnz = tf.nn.relu(delta_max) 147 | 148 | inner_exp_diff = tf.exp(delta - delta_max_nnz) 149 | inner_exp_diff *= delta_nnz 150 | 151 | inner_sum = tf.reduce_sum(inner_exp_diff, 1, keepdims=True) 152 | 153 | loss = delta_max_nnz + tf.log(tf.exp(-delta_max_nnz) + inner_sum) 154 | return tf.reduce_mean(loss) 155 | -------------------------------------------------------------------------------- /src/ML_Net_label_prediction_train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import data_utils as data_utils 3 | import time 4 | import datetime 5 | import os 6 | import os.path as osp 7 | import glob 8 | from ML_Net import ML_elmo 9 | 10 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 11 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 12 | 13 | logs_path = "tensorflow_logs/label_predict/" 14 | data_dir = "tf_data/" 15 | 16 | # Parameters 17 | # ================================================== # 18 | 19 | # Data loading params 20 | tf.app.flags.DEFINE_integer("NUM_CLASSES", 7042, "NUM_CLASSES") 21 | 
tf.app.flags.DEFINE_integer("MAX_LABELS_PERMITTED", 70, "maximum of labels permitted by label decision network ") 22 | tf.app.flags.DEFINE_integer("MAX_PAIRS", 2000, "sample at most MAX_PAIRS from the Cartesian product (negative sampling)") 23 | 24 | # Model Hyperparameters 25 | tf.app.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability") 26 | tf.app.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda") 27 | tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning_rate") 28 | tf.app.flags.DEFINE_integer("hidden_size", 50, "the hidden size of rnn unit") 29 | 30 | # Training parameters 31 | tf.app.flags.DEFINE_integer("batch_size", 16, "Batch Size") 32 | tf.app.flags.DEFINE_integer("num_epochs", 50, "Number of training epochs ") 33 | 34 | # Misc Parameters 35 | tf.app.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") 36 | tf.app.flags.DEFINE_integer("save_model_step", 12000, "save every xx step") 37 | 38 | FLAGS = tf.app.flags.FLAGS 39 | print("\nParameters:") 40 | for attr, value in sorted(FLAGS.__flags.items()): 41 | print("{}={}".format(attr.upper(), value.value)) 42 | print("") 43 | with tf.device('/device:CPU:0'): 44 | num_preprocess_threads = 2 45 | min_after_dequeue = 50 # 1000 per file 46 | examples_queue = tf.RandomShuffleQueue( 47 | capacity=min_after_dequeue + 16 * FLAGS.batch_size, 48 | min_after_dequeue=min_after_dequeue, 49 | dtypes=[tf.string]) 50 | files = glob.glob(osp.join(data_dir, '*_train.tfrecords')) 51 | filename_queue = tf.train.string_input_producer(files,num_epochs=FLAGS.num_epochs,shuffle=True, capacity=2) 52 | reader = tf.TFRecordReader() 53 | _, serialized_example = reader.read(filename_queue) 54 | enqueue_ops = [] 55 | enqueue_ops.append(examples_queue.enqueue([serialized_example])) 56 | 57 | tf.train.queue_runner.add_queue_runner( 58 | tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops)) 59 | example_serialized = examples_queue.dequeue() 60 | 61 | outputs = [] 62 | keys = {}.keys() 63 | for _ in range(num_preprocess_threads): 64 | data = data_utils.__parse_example_proto_with_elmo(example_serialized,FLAGS.NUM_CLASSES, FLAGS.MAX_PAIRS) 65 | keys = data.keys() 66 | outputs.append(list(data.values())) 67 | 68 | res = tf.train.batch_join(outputs, batch_size=FLAGS.batch_size, capacity=2 * num_preprocess_threads * FLAGS.batch_size) 69 | res_d = {} 70 | for key, value in zip(keys, res): 71 | res_d[key] = value 72 | 73 | session_conf = tf.ConfigProto(allow_soft_placement=True) 74 | session_conf.gpu_options.allow_growth = True 75 | sess = tf.Session(config=session_conf) 76 | with sess.as_default(): 77 | 78 | ML_net = ML_elmo( 79 | NUM_CLASSES=FLAGS.NUM_CLASSES, 80 | MAX_LABELS_PERMITTED=FLAGS.MAX_LABELS_PERMITTED, 81 | batch_size=FLAGS.batch_size, 82 | hidden_size=FLAGS.hidden_size, 83 | MAX_PAIRS=FLAGS.MAX_PAIRS 84 | ) 85 | saver = tf.train.Saver(max_to_keep=15) 86 | current_loss = ML_net.loss 87 | # Define Training procedure 88 | global_step = tf.Variable(0, name="global_step", trainable=False) 89 | optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate) 90 | grads_and_vars = optimizer.compute_gradients(current_loss) 91 | train_op = tf.contrib.layers.optimize_loss(current_loss, global_step=global_step, learning_rate=FLAGS.learning_rate, 92 | optimizer="Adam") 93 | 94 | timestamp = str(int(time.time())) 95 | out_dir = os.path.abspath(os.path.join(logs_path, "elmo_" + timestamp)) 96 | 97 | print("Writing to {}\n".format(out_dir)) 98 | if not os.path.exists(out_dir): 99 | 
os.makedirs(out_dir) 100 | 101 | para_writer = open(out_dir + "/parameters.txt", "w") 102 | for attr, value in sorted(FLAGS.__flags.items()): 103 | para_writer.write(data_dir + "\n") 104 | para_writer.write("{}={}".format(attr.upper(), value.value)) 105 | para_writer.write("\n") 106 | para_writer.close() 107 | 108 | def train_step(feed): 109 | """ 110 | A single training step 111 | """ 112 | x_feed = feed[0] 113 | y_feed = feed[1] 114 | feed_dict = {ML_net.input_x: x_feed[0], 115 | ML_net.input_y_label_pairs: y_feed[0], 116 | ML_net.input_y_label_map: y_feed[1], 117 | ML_net.dropout_keep_prob: FLAGS.dropout_keep_prob 118 | } 119 | 120 | _, step, loss = sess.run( 121 | [train_op, global_step, current_loss], 122 | feed_dict) 123 | time_str = datetime.datetime.now().isoformat() 124 | print("{}: step {}, loss {:g}".format(time_str, step, loss)) 125 | 126 | tf.global_variables_initializer().run() 127 | tf.local_variables_initializer().run() 128 | tf.tables_initializer().run() 129 | 130 | coord = tf.train.Coordinator() 131 | threads = tf.train.start_queue_runners(coord=coord) 132 | current_step = 0 133 | try: 134 | while not coord.should_stop(): 135 | train_data = sess.run(res_d) 136 | text = tuple(train_data['raw_text'].flatten().tolist()) 137 | text = [s.decode("utf-8").strip() for s in text] 138 | label_pairs = train_data['label_pair'] 139 | label_map = train_data['label_map'] 140 | 141 | x_feed = [text] 142 | y_feed = [label_pairs, label_map] 143 | feed = [x_feed, y_feed] 144 | train_step(feed) 145 | 146 | current_step = tf.train.global_step(sess, global_step) 147 | if current_step % FLAGS.save_model_step == 0: 148 | saver.save(sess, out_dir + "/" + str(current_step) + "-model.ckpt") 149 | 150 | except tf.errors.OutOfRangeError: 151 | saver.save(sess, out_dir + "/final-model.ckpt") 152 | print("Done training") 153 | finally: 154 | coord.request_stop() 155 | coord.join(threads) 156 | 157 | print("\nend!") -------------------------------------------------------------------------------- /src/ML_Net_label_count_prediction_train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import data_utils as data_utils 3 | import time 4 | import datetime 5 | import os 6 | import os.path as osp 7 | import glob 8 | from ML_Net import ML_elmo 9 | import metrics as metric 10 | import numpy as np 11 | 12 | os.environ["CUDA_VISIBLE_DEVICES"]="0" 13 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 14 | 15 | logs_path = "tensorflow_logs/label_count/" 16 | data_dir = "tf_data/" 17 | trained_model_path = "" 18 | # Parameters 19 | # ================================================== # 20 | 21 | # Data loading params 22 | tf.app.flags.DEFINE_integer("NUM_CLASSES", 7042, "NUM_CLASSES") 23 | tf.app.flags.DEFINE_integer("MAX_LABELS_PERMITTED", 70, "maximum of labels permitted by label decision network ") 24 | tf.app.flags.DEFINE_integer("MAX_PAIRS", 2000, "sample at most MAX_PAIRS from the Cartesian product (negative sampling)") 25 | 26 | # Model Hyperparameters 27 | tf.app.flags.DEFINE_float("dropout_keep_prob", 0.5, "Dropout keep probability") 28 | tf.app.flags.DEFINE_float("l2_reg_lambda", 0.0, "L2 regularization lambda") 29 | tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning_rate") 30 | tf.app.flags.DEFINE_integer("hidden_size", 50, "the hidden size of rnn unit") 31 | 32 | # Training parameters 33 | tf.app.flags.DEFINE_integer("batch_size", 16, "Batch Size ") 34 | tf.app.flags.DEFINE_integer("num_epochs", 20, "Number of training epochs") 35 | 36 | 
# Misc Parameters 37 | tf.app.flags.DEFINE_integer("eval_model_step", 200, "evaluate every xx step") 38 | 39 | FLAGS = tf.app.flags.FLAGS 40 | print("\nParameters:") 41 | for attr, value in sorted(FLAGS.__flags.items()): 42 | print("{}={}".format(attr.upper(), value.value)) 43 | print("") 44 | 45 | num_preprocess_threads = 4 46 | min_after_dequeue = 100 # 1000 per file 47 | examples_queue = tf.RandomShuffleQueue( 48 | capacity=min_after_dequeue + 16 * FLAGS.batch_size, 49 | min_after_dequeue=min_after_dequeue, 50 | dtypes=[tf.string]) 51 | files = glob.glob(osp.join(data_dir, '*_train.tfrecords')) 52 | filename_queue = tf.train.string_input_producer(files,num_epochs=FLAGS.num_epochs,shuffle=True, capacity=10) 53 | reader = tf.TFRecordReader() 54 | _, serialized_example = reader.read(filename_queue) 55 | enqueue_ops = [] 56 | enqueue_ops.append(examples_queue.enqueue([serialized_example])) 57 | 58 | tf.train.queue_runner.add_queue_runner( 59 | tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops)) 60 | example_serialized = examples_queue.dequeue() 61 | 62 | outputs = [] 63 | keys = {}.keys() 64 | for _ in range(num_preprocess_threads): 65 | data = data_utils.__parse_example_proto_with_elmo(example_serialized, FLAGS.NUM_CLASSES, FLAGS.MAX_PAIRS) 66 | keys = data.keys() 67 | outputs.append(list(data.values())) 68 | 69 | res = tf.train.batch_join(outputs, batch_size=FLAGS.batch_size, capacity=2 * num_preprocess_threads * FLAGS.batch_size) 70 | res_d = {} 71 | for key, value in zip(keys, res): 72 | res_d[key] = value 73 | 74 | session_conf = tf.ConfigProto(allow_soft_placement=True) 75 | session_conf.gpu_options.allow_growth = True 76 | sess = tf.Session(config=session_conf) 77 | 78 | with sess.as_default(): 79 | 80 | ML_net = ML_elmo( 81 | NUM_CLASSES=FLAGS.NUM_CLASSES, 82 | MAX_LABELS_PERMITTED=FLAGS.MAX_LABELS_PERMITTED, 83 | batch_size=FLAGS.batch_size, 84 | hidden_size=FLAGS.hidden_size, 85 | MAX_PAIRS=FLAGS.MAX_PAIRS 86 | ) 87 | saver = tf.train.Saver(max_to_keep=15) 88 | current_loss = ML_net.lcnt_loss 89 | # Define Training procedure 90 | global_step = tf.Variable(0, name="global_step", trainable=False) 91 | optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate) 92 | grads_and_vars = optimizer.compute_gradients(current_loss) 93 | train_op = tf.contrib.layers.optimize_loss(current_loss, global_step=global_step, learning_rate=FLAGS.learning_rate, 94 | optimizer="Adam") 95 | 96 | timestamp = str(int(time.time())) 97 | out_dir = os.path.abspath(os.path.join(logs_path, "runs", timestamp)) 98 | 99 | print("Writing to {}\n".format(out_dir)) 100 | if not os.path.exists(out_dir): 101 | os.makedirs(out_dir) 102 | 103 | para_writer = open(out_dir + "/parameters.txt", "w") 104 | for attr, value in sorted(FLAGS.__flags.items()): 105 | para_writer.write(data_dir + "\n") 106 | para_writer.write("{}={}".format(attr.upper(), value.value)) 107 | para_writer.write("\n") 108 | para_writer.close() 109 | 110 | def eval_step(): 111 | print("----\nprediction begin!\n----") 112 | # Tensors we want to evaluate 113 | prediction_logits_list = list() 114 | prediction_counts_list = list() 115 | eval_data = data_utils.load_obj("tf_data/test_abs_label_zip.pickle") 116 | 117 | train_x, train_y = zip(*eval_data) 118 | total_test_size = len(train_x) 119 | # print(total_test_size, "total_test_size") 120 | remain_test_size = total_test_size % FLAGS.batch_size 121 | total_epoch = int(total_test_size / FLAGS.batch_size) 122 | batches = data_utils.batch_iter_eval(zip(train_x, train_y), batch_size=FLAGS.batch_size) 
123 | 124 | gold_labels_list = list() 125 | for x_batch, y_batch in batches: 126 | gold_labels_list.extend(y_batch) 127 | # print(y_batch, "gold_label") 128 | # print(np.where(y_batch == 1)[0],"gold label index") 129 | text_list = list() 130 | for abs_dict in x_batch: 131 | text_list.append(abs_dict["raw_text"]) 132 | 133 | feed_dict = { 134 | ML_net.input_x: text_list, 135 | ML_net.dropout_keep_prob: 1 136 | } 137 | batch_logits, batch_count = sess.run([ML_net.logits, ML_net.predictions_count], feed_dict) 138 | prediction_logits_list.extend(batch_logits) 139 | prediction_counts_list.extend(batch_count) 140 | 141 | prediction_logits_list_update = prediction_logits_list[:total_epoch * FLAGS.batch_size].copy() 142 | prediction_logits_list_update.extend(prediction_logits_list[-remain_test_size:]) 143 | prediction_counts_list_update = prediction_counts_list[:total_epoch * FLAGS.batch_size].copy() 144 | prediction_counts_list_update.extend(prediction_counts_list[-remain_test_size:]) 145 | gold_labels_list_update = gold_labels_list[:total_epoch * FLAGS.batch_size].copy() 146 | gold_labels_list_update.extend(gold_labels_list[-remain_test_size:]) 147 | 148 | prediction_logits_arr = np.asarray(prediction_logits_list_update) 149 | prediction_counts_list = np.asarray(prediction_counts_list_update) 150 | 151 | p, r, f, p_list, g_list = metric.get_p_r_f_jamia(logits=prediction_logits_arr, counts=prediction_counts_list, 152 | labels=gold_labels_list_update) 153 | print(global_step.eval(), "step") 154 | print(p, "\nprecision") 155 | print(r, "recall") 156 | print(f, "f-measure") 157 | 158 | 159 | def train_step(feed): 160 | """ 161 | A single training step 162 | """ 163 | x_feed = feed[0] 164 | y_feed = feed[1] 165 | feed_dict = {ML_net.input_x: x_feed[0], 166 | ML_net.input_y_label_pairs: y_feed[0], 167 | ML_net.input_y_label_map: y_feed[1], 168 | ML_net.dropout_keep_prob: FLAGS.dropout_keep_prob 169 | } 170 | 171 | _, step, loss = sess.run( 172 | [train_op, global_step, current_loss], 173 | feed_dict) 174 | time_str = datetime.datetime.now().isoformat() 175 | print("{}: step {}, loss {:g}".format(time_str, step, loss)) 176 | 177 | tf.global_variables_initializer().run() 178 | tf.local_variables_initializer().run() 179 | tf.tables_initializer().run() 180 | 181 | saver.restore(sess, trained_model_path) 182 | coord = tf.train.Coordinator() 183 | threads = tf.train.start_queue_runners(coord=coord) 184 | current_step = 0 185 | try: 186 | while not coord.should_stop(): 187 | train_data = sess.run(res_d) 188 | text = tuple(train_data['raw_text'].flatten().tolist()) 189 | text = [s.decode("utf-8").strip() for s in text] 190 | label_pairs = train_data['label_pair'] 191 | label_map = train_data['label_map'] 192 | 193 | x_feed = [text] 194 | y_feed = [label_pairs, label_map] 195 | feed = [x_feed, y_feed] 196 | train_step(feed) 197 | 198 | current_step = tf.train.global_step(sess, global_step) 199 | if current_step % FLAGS.eval_model_step == 0: 200 | eval_step() 201 | 202 | except tf.errors.OutOfRangeError: 203 | print("Done training") 204 | finally: 205 | coord.request_stop() 206 | coord.join(threads) 207 | 208 | # print("Model saved in path: %s" % save_path) 209 | print("\nend!") 210 | --------------------------------------------------------------------------------
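Usage note: at inference time the two networks are combined to produce the final label set. The label prediction network supplies per-label confidence scores (`ML_net.logits`) and the label count prediction network supplies the number of labels to keep (`ML_net.predictions_count`). Below is a minimal sketch of that decision step, assuming a restored session `sess` and an `ML_Net` graph built as in the training scripts above; it is not part of the original source files.

```python
import numpy as np

def predict_labels(sess, ML_net, text_batch):
    """Select the final label set from label confidence scores and the predicted label count."""
    logits, counts = sess.run(
        [ML_net.logits, ML_net.predictions_count],
        feed_dict={ML_net.input_x: text_batch, ML_net.dropout_keep_prob: 1.0})
    predictions = []
    for logit_row, k in zip(logits, counts):
        ranked = np.argsort(-logit_row)          # label indices sorted by decreasing confidence
        predictions.append(set(ranked[:k + 1]))  # k is a zero-based count bin, so keep k + 1 labels
    return predictions
```

This mirrors the top-(k + 1) selection used in metrics.get_p_r_f_jamia, where the predicted count bin k is zero-based.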