├── BAG-pytorch.py
├── BAG.png
├── BAG.py
├── README.md
├── data
│   ├── ner_dict.pickle
│   └── pos_dict.pickle
├── prepro.py
└── utils
    ├── ConfigLogger.py
    ├── Dataset.py
    ├── __init__.py
    ├── pytorch_dataset.py
    └── str2bool.py
/BAG-pytorch.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import os 6 | import argparse 7 | import sys 8 | import json 9 | 10 | from utils.ConfigLogger import config_logger 11 | from utils.str2bool import str2bool 12 | from utils.pytorch_dataset import get_pytorch_dataloader 13 | from ignite.engine import Engine, Events 14 | from ignite.handlers import EarlyStopping 15 | from ignite.metrics import Accuracy, RunningAverage 16 | from ignite.contrib.handlers import ProgressBar, LRScheduler 17 | from torch.optim.lr_scheduler import StepLR 18 | from torch.optim import Adam 19 | 20 | '''Feature layer that takes raw features as input''' 21 | class FeatureLayer(nn.Module): 22 | def __init__(self, args): 23 | super(FeatureLayer, self).__init__() 24 | self.use_elmo = args.use_elmo 25 | self.use_glove = args.use_glove 26 | self.use_extra_feature = args.use_extra_feature 27 | self.max_query_size = args.max_query_size 28 | self.max_nodes = args.max_nodes 29 | self.query_encoding_type = args.query_encoding_type 30 | self.encoding_size = args.encoding_size 31 | self.lstm_hidden_size = args.lstm_hidden_size 32 | self.encoder_input_size = 0 33 | if self.use_elmo: 34 | self.encoder_input_size += 3 * 1024 35 | if self.use_glove: 36 | self.encoder_input_size += 300 37 | if self.query_encoding_type == 'lstm': 38 | self.query_encoder = nn.LSTM(self.encoder_input_size, int(self.encoding_size / 2), 2, bidirectional=True, 39 | batch_first=True) 40 | else: 41 | self.query_encoder = nn.Linear(self.encoder_input_size, self.encoding_size) 42 | self.node_encoder = nn.Linear(self.encoder_input_size, self.encoding_size) 43 | args.hidden_size = self.encoding_size 44 | if self.use_extra_feature: 45 | self.ner_embedding = nn.Embedding(args.ner_dict_size, args.ner_emb_size) 46 | self.pos_embedding = nn.Embedding(args.pos_dict_size, args.pos_emb_size) 47 | args.hidden_size += (args.ner_emb_size + args.pos_emb_size) 48 | 49 | def forward(self, nodes_elmo, query_elmo, nodes_glove, query_glove, nodes_ner, nodes_pos, query_ner, query_pos, 50 | query_lengths): 51 | query_flat, nodes_flat = None, None 52 | if self.use_elmo: 53 | query_flat = query_elmo.view(-1, self.max_query_size, 3 * 1024) 54 | nodes_flat = nodes_elmo.view(-1, self.max_nodes, 3 * 1024) 55 | if self.use_glove: 56 | if query_flat is None: 57 | query_flat, nodes_flat = query_glove, nodes_glove 58 | else: 59 | query_flat = torch.cat((query_flat, query_glove), dim=-1) 60 | nodes_flat = torch.cat((nodes_flat, nodes_glove), dim=-1) 61 | if self.query_encoding_type == 'lstm': 62 | query_flat = nn.utils.rnn.pack_padded_sequence(query_flat, query_lengths, batch_first=True, 63 | enforce_sorted=False) 64 | query_compress = F.tanh( 65 | torch.nn.utils.rnn.pad_packed_sequence(self.query_encoder(query_flat)[0], batch_first=True)[0]) 66 | else: 67 | query_compress = F.tanh(self.query_encoder(query_flat)) 68 | nodes_compress = F.tanh(self.node_encoder(nodes_flat)) 69 | if self.use_extra_feature: 70 | query_ner_emb = self.ner_embedding(query_ner) 71 | query_pos_emb = self.pos_embedding(query_pos) 72 | nodes_ner_emb = self.ner_embedding(nodes_ner) 73 | nodes_pos_emb = self.pos_embedding(nodes_pos) 74 | nodes_compress = torch.cat([nodes_compress, nodes_ner_emb, nodes_pos_emb], dim=-1) 75 | new_query_length = query_compress.shape[1] 76 | query_compress = torch.cat([query_compress, query_ner_emb[:, :new_query_length, :], query_pos_emb[:, :new_query_length, :]], dim=-1) 77 | return nodes_compress, query_compress 78 | 
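# --- Notes on FeatureLayer above ---
# Each token is represented by the concatenation of its three ELMo layers
# (3 * 1024 = 3072 dims) and/or a 300-d GloVe vector. Queries go through a
# 2-layer BiLSTM (hidden size encoding_size / 2 per direction) or a linear
# layer, nodes through a linear layer, both squashed by tanh to encoding_size.
# When extra features are enabled, NER and POS embeddings (8-d by default) are
# appended, so args.hidden_size becomes encoding_size + ner_emb_size + pos_emb_size.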
79 | '''A single gated GCN layer''' 80 | class GCNLayer(nn.Module): 81 | def __init__(self, args): 82 | super(GCNLayer, self).__init__() 83 | self.use_edge = args.use_edge 84 | if args.use_edge: 85 | self.gcns = nn.ModuleList([nn.Linear(args.hidden_size, args.hidden_size) for _ in range(3)]) 86 | else: 87 | self.gcns = nn.ModuleList([nn.Linear(args.hidden_size, args.hidden_size)]) 88 | self.update_gate = nn.Linear(args.hidden_size, args.hidden_size) 89 | self.att_gate = nn.Linear(args.hidden_size * 2, args.hidden_size) 90 | 91 | def forward(self, adj, nodes_hidden, nodes_mask): 92 | accumulated_nodes_hidden = torch.stack([gcn(nodes_hidden) for gcn in self.gcns], dim=1) * \ 93 | nodes_mask.unsqueeze(1).unsqueeze(-1) 94 | update = torch.sum(torch.matmul(adj, accumulated_nodes_hidden), dim=1) + \ 95 | self.update_gate(nodes_hidden) * nodes_mask.unsqueeze(-1) 96 | att = F.sigmoid(self.att_gate(torch.cat([update, nodes_hidden], -1))) 97 | output = att * F.tanh(update) + (1 - att) * nodes_hidden 98 | return output 99 | 
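# --- Notes on GCNLayer above (gated GCN propagation; masking omitted) ---
# With use_edge enabled there are 3 relation-specific linear transforms and
# adj has shape (batch, 3, max_nodes, max_nodes); torch.matmul aggregates each
# relation's neighbourhood separately, and the sum over dim=1 plus the
# update_gate self-loop gives the message u. A sigmoid gate then blends it
# with the previous node state:
#     u  = sum_r A_r (h W_r) + h W_self
#     a  = sigmoid([u ; h] W_att)
#     h' = a * tanh(u) + (1 - a) * h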
100 | '''Bidirectional attention layer''' 101 | class BiAttention(nn.Module): 102 | def __init__(self, args): 103 | super(BiAttention, self).__init__() 104 | self.max_nodes = args.max_nodes 105 | self.max_query_size = args.max_query_size 106 | self.attention_linear = nn.Linear(args.hidden_size * 3, 1, bias=False) 107 | 108 | def forward(self, nodes_compress, query_compress, nodes_hidden): 109 | query_size = query_compress.shape[1] 110 | expanded_query = query_compress.unsqueeze(1).repeat((1, self.max_nodes, 1, 1)) 111 | expanded_nodes = nodes_hidden.unsqueeze(2).repeat((1, 1, query_size, 1)) 112 | nodes_query_similarity = expanded_nodes * expanded_query 113 | concatenated_data = torch.cat((expanded_nodes, expanded_query, nodes_query_similarity), -1) 114 | similarity = self.attention_linear(concatenated_data).squeeze(-1) 115 | nodes2query = torch.matmul(F.softmax(similarity, dim=-1), query_compress) 116 | b = F.softmax(torch.max(similarity, dim=-1)[0], dim=-1) 117 | query2nodes = torch.matmul(b.unsqueeze(1), nodes_compress).repeat(1, self.max_nodes, 1) 118 | attention_output = torch.cat( 119 | (nodes_compress, nodes2query, nodes_compress * nodes2query, nodes_compress * query2nodes), dim=-1) 120 | return attention_output 121 | 122 | '''Output layer that takes the attention output and the node output mask to generate the predictions''' 123 | class OutputLayer(nn.Module): 124 | def __init__(self, args): 125 | super(OutputLayer, self).__init__() 126 | self.linear1 = nn.Linear(args.hidden_size * 4, 128) 127 | self.linear2 = nn.Linear(128, 1) 128 | 129 | def forward(self, attention_output, output_mask): 130 | raw_preds = self.linear2(F.tanh(self.linear1(attention_output))).squeeze(-1) 131 | preds = output_mask.float() * raw_preds.unsqueeze(1) 132 | preds = preds.masked_fill(preds == 0.0, -float("inf")) 133 | preds = torch.max(preds, dim=-1)[0] 134 | return preds 135 | 136 | '''BAG Model class''' 137 | class Model(nn.Module): 138 | def __init__(self, args): 139 | super(Model, self).__init__() 140 | self.feature_layer = FeatureLayer(args) 141 | self.gcn_layers = nn.ModuleList([GCNLayer(args) for _ in range(args.hop_num)]) 142 | self.bi_attention = BiAttention(args) 143 | self.output_layer = OutputLayer(args) 144 | 145 | def forward(self, adj, nodes_elmo, query_elmo, nodes_glove, query_glove, nodes_ner, nodes_pos, query_ner, query_pos, 146 | query_lengths, nodes_mask, output_mask): 147 | nodes_compress, query_compress = self.feature_layer(nodes_elmo, query_elmo, nodes_glove, query_glove, nodes_ner, 148 | nodes_pos, query_ner, query_pos, query_lengths) 149 | nodes_hidden = nodes_compress 150 | for gcn_layer in self.gcn_layers: 151 | nodes_hidden = gcn_layer(adj, nodes_hidden, nodes_mask) 152 | attention_output = self.bi_attention(nodes_compress, query_compress, nodes_hidden) 153 | preds = self.output_layer(attention_output, output_mask) 154 | return preds 155 | 
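# --- Notes on Model above (the full BAG pipeline) ---
# FeatureLayer encodes nodes and the query (ELMo/GloVe, optionally NER/POS),
# hop_num stacked GCNLayers propagate information over the entity graph,
# BiAttention attends between the GCN-propagated node states and the query,
# and OutputLayer scores each answer candidate by max-pooling the scores of
# its mention nodes selected by output_mask.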
156 | """ Check whether the preprocessed file exists in the current directory 157 | """ 158 | def checkPreprocessFile(file_name, add_query_node): 159 | preprocess_file_name = file_name 160 | if add_query_node: 161 | preprocess_file_name = preprocess_file_name + '.add_query_node' 162 | if not os.path.isfile('{}.preprocessed.pickle'.format(preprocess_file_name)): 163 | return preprocess_file_name, False 164 | return preprocess_file_name, True 165 | 166 | def write_best_accuray_txt(model_path, accuracy): 167 | file_name = model_path + '/best.txt' 168 | if os.path.isfile(file_name): 169 | os.remove(file_name) 170 | with open(file_name, 'w') as f: 171 | f.write(str(accuracy) + '\n') 172 | f.close() 173 | 174 | def generate_answer_json(answer_dict, in_file): 175 | with open(in_file, 'r') as f: 176 | data = json.load(f) 177 | final_answer_dict = {} 178 | for d in data: 179 | final_answer_dict[d['id']] = d['candidates'][answer_dict[d['id']]] 180 | with open('predictions.json', 'w') as f: 181 | json.dump(final_answer_dict, f) 182 | 183 | if __name__ == '__main__': 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument('in_file', type=str) 186 | parser.add_argument('dev_file', type=str) 187 | parser.add_argument('--model_checkpoint', type=str, default=None, help='The initial model checkpoint') 188 | parser.add_argument('--is_masked', type=str2bool, default=False, help='using masked data or not') 189 | parser.add_argument('--continue_train', type=str2bool, default=False, help='continue the last training using a saved model or' + 190 | ' start a new training') 191 | parser.add_argument('--cpu_number', type=int, default=4, help='The number of CPUs used in current script') 192 | parser.add_argument('--add_query_node', type=str2bool, default=False, help='Whether the entity in query is involved '+ 193 | 'in the construction of graph') 194 | parser.add_argument('--evaluation_mode', type=str2bool, default=False, help='Whether to use evaluation mode') 195 | parser.add_argument('--use_elmo', type=str2bool, default=True, help='Whether to use ELMo as a feature for each node') 196 | parser.add_argument('--use_glove', type=str2bool, default=True, help='Whether to use GloVe as a feature for each node') 197 | parser.add_argument('--use_extra_feature', type=str2bool, default=True, help='Whether to use extra features for ' + 198 | 'each node, e.g. NER and POS') 199 | parser.add_argument('--use_edge', type=str2bool, default=True, help='Whether to use edges in the graph') 200 | parser.add_argument('--use_full_query_token', type=str2bool, default=False, help='Tokens in query will be split by underscores') 201 | parser.add_argument('--dynamic_change_learning_rate', type=str2bool, default=True, help='Whether the learning rate will change during training') 202 | parser.add_argument('--lr', type=float, default=2e-4, help='Initial learning rate') 203 | parser.add_argument('--hop_num', type=int, default=5, help='Hop num in GCN layer, in other words the depth of GCN') 204 | parser.add_argument('--epochs', type=int, default=50, help='Epoch number for the training') 205 | parser.add_argument('--batch_size', type=int, default=32, help='Batch size') 206 | parser.add_argument('--info_interval', type=int, default=1000, help='The interval to display training loss info') 207 | parser.add_argument('--dropout', type=float, default=0.8, help="Keep rate for dropout in model") 208 | parser.add_argument('--encoding_size', type=int, default=512, help='The encoding output size for both query and nodes') 209 | parser.add_argument('--lstm_hidden_size', type=int, default=256, help='The hidden size for lstm intermediate layer') 210 | parser.add_argument('--pos_emb_size', type=int, default=8, help='The size of POS embedding') 211 | parser.add_argument('--ner_emb_size', type=int, default=8, help='The size of NER embedding') 212 | parser.add_argument('--query_encoding_type', type=str, default='lstm', help='The function to encode query') 213 | parser.add_argument('--max_nodes', type=int, default=500, help='Maximum node number in graph') 214 | parser.add_argument('--max_query_size', type=int, default=25, help='Maximum length for query') 215 | parser.add_argument('--max_candidates', type=int, default=80, help='Maximum number of answer candidates') 216 | parser.add_argument('--max_candidates_len', type=int, default=10, help='Maximum length for a candidate') 217 | parser.add_argument('--loss_log_interval', type=int, default=1000, help='Iteration interval to print loss') 218 | parser.add_argument('--patience', type=int, default=5, help='Epoch early stopping patience') 219 | 220 | args = parser.parse_args() 221 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 222 | 223 | encoding_type_map = {'lstm':'lstm','linear':'linear'} 224 | best_res = {'acc': 0.0} 225 | pred_res = [] 226 | 227 | model_name = 'BAG-pytorch' 228 | if args.evaluation_mode: 229 | logger = config_logger('evaluation-pytorch/' + model_name) 230 | else: 231 | logger = config_logger('BAG-pytorch') 232 | 233 | model_path = os.getcwd() + '/models-pytorch/' + model_name + '/' 234 | if not os.path.exists(model_path): 235 | os.makedirs(model_path) 236 | model_path = os.path.join(model_path, 'best_model.bin') 237 | 238 | for item in vars(args).items(): 239 | logger.info('%s : %s', item[0], str(item[1])) 240 | 241 | '''Check whether preprocessed files exist''' 242 | train_file_name_prefix, fileExist = checkPreprocessFile(args.in_file, args.add_query_node) 243 | if not fileExist: 244 | logger.info('Cannot find preprocessed data %s, program will shut down.', 245 | '{}.preprocessed.pickle'.format(train_file_name_prefix)) 246 | sys.exit() 247 | dev_file_name_prefix, fileExist = checkPreprocessFile(args.dev_file, args.add_query_node) 248 | if not fileExist: 249 | logger.info('Cannot find preprocessed data %s, program will shut down.', 250 | 
'{}.preprocessed.pickle'.format(dev_file_name_prefix)) 251 | sys.exit() 252 | args.pos_dict_size, args.ner_dict_size = 0, 0 253 | dev_data_loader, args.ner_dict_size, args.pos_dict_size, id_candidate_list = \ 254 | get_pytorch_dataloader(args, dev_file_name_prefix, for_evaluation=True) 255 | 256 | '''Initialize the BAG model''' 257 | model = Model(args) 258 | if args.model_checkpoint is not None: 259 | model.load_state_dict(torch.load(args.model_checkpoint)) 260 | logger.info('Load previous model from %s', args.model_checkpoint) 261 | 262 | '''Handler function for training''' 263 | def train(engine, batch): 264 | model.train() 265 | batch = tuple(input_tensor.to(device) for input_tensor in batch) 266 | adj, nodes_elmo, query_elmo, nodes_glove, query_glove, nodes_ner, nodes_pos, query_ner, query_pos, nodes_mask, \ 267 | query_lengths, output_mask, labels = batch 268 | preds = model(adj, nodes_elmo, query_elmo, nodes_glove, query_glove, nodes_ner, nodes_pos, query_ner, query_pos, 269 | query_lengths, nodes_mask, output_mask) 270 | loss_fct = nn.CrossEntropyLoss() 271 | loss = loss_fct(preds, labels) 272 | if loss.item() == float("inf"): 273 | tmp_labels, tmp_preds = [], [] 274 | for i in range(labels.shape[0]): 275 | if preds[i][labels[i]] != -float("inf"): 276 | tmp_labels.append(labels[i].unsqueeze(0)) 277 | tmp_preds.append(preds[i].unsqueeze(0)) 278 | loss = loss_fct(torch.cat(tmp_preds, dim=0), torch.cat(tmp_labels, dim=0)) 279 | optimizer.zero_grad() 280 | loss.backward() 281 | optimizer.step() 282 | return loss.item() 283 | 284 | '''Handler function for evaluation''' 285 | def evaluation(engine, batch): 286 | model.eval() 287 | with torch.no_grad(): 288 | batch = tuple(input_tensor.to(device) for input_tensor in batch) 289 | adj, nodes_elmo, query_elmo, nodes_glove, query_glove, nodes_ner, nodes_pos, query_ner, query_pos, \ 290 | nodes_mask, query_lengths, output_mask, labels = batch 291 | preds = model(adj, nodes_elmo, query_elmo, nodes_glove, query_glove, nodes_ner, nodes_pos, query_ner, 292 | query_pos, query_lengths, nodes_mask, output_mask) 293 | if args.evaluation_mode: 294 | pred_res.extend(torch.argmax(preds, dim=1).tolist()) 295 | return preds, labels 296 | evaluator = Engine(evaluation) 297 | Accuracy(output_transform=lambda x: (x[0], x[1])).attach(evaluator, 'accuracy') 298 | dev_pbar = ProgressBar(persist=True, desc='Validation') 299 | dev_pbar.attach(evaluator) 300 | 301 | '''Handler function to save best models after each evaluation''' 302 | def after_evaluation(engine): 303 | acc = engine.state.metrics['accuracy'] 304 | logger.info('Evaluation accuracy on Epoch %d is %.3f', engine.state.epoch, acc * 100) 305 | if acc > best_res['acc']: 306 | logger.info('Current model BEATS the previous best model, previous best accuracy is %.5f', best_res['acc']) 307 | best_res['acc'] = acc 308 | torch.save(model.state_dict(), model_path) 309 | logger.info('Best model has been saved') 310 | else: 311 | logger.info('Current model CANNOT BEAT the previous best model, previous best accuracy is %.5f', 312 | best_res['acc']) 313 | 314 | def score_function(engine): 315 | return engine.state.metrics['accuracy'] 316 | 317 | if not args.evaluation_mode: 318 | '''If current run is training''' 319 | train_data_loader, _, _, _ = get_pytorch_dataloader(args, train_file_name_prefix, shuffle=True) 320 | optimizer = Adam(model.parameters(), lr=args.lr) 321 | '''Learning rate decays every 5 epochs''' 322 | optimizer_scheduler = StepLR(optimizer, step_size=5, gamma=0.5) 323 | scheduler = LRScheduler(optimizer_scheduler) 324 | 
trainer = Engine(train) 325 | trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler) 326 | trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(dev_data_loader)) 327 | 328 | pbar = ProgressBar(persist=True, desc='Training') 329 | pbar.attach(trainer, metric_names=["loss"]) 330 | RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") 331 | 332 | trainer.add_event_handler(Events.ITERATION_COMPLETED(every=args.loss_log_interval), lambda engine: \ 333 | logger.info('Loss at iteration %d is %.5f', engine.state.iteration, engine.state.metrics['loss'])) 334 | early_stop_handler = EarlyStopping(patience=args.patience, score_function=score_function, trainer=trainer) 335 | evaluator.add_event_handler(Events.COMPLETED, lambda engine: after_evaluation(engine)) 336 | evaluator.add_event_handler(Events.COMPLETED, early_stop_handler) 337 | 338 | trainer.run(train_data_loader, max_epochs=args.epochs) 339 | else: 340 | '''If current run is evaluation, it will generate a prediction json file''' 341 | evaluator.add_event_handler(Events.COMPLETED, 342 | lambda engine: logger.info('Current evaluation accuracy is %.3f', engine.state.metrics['accuracy'] * 100)) 343 | evaluator.run(dev_data_loader) 344 | pred_dict = {} 345 | for i, pred_label in enumerate(pred_res): 346 | pred_dict[id_candidate_list[i][0]] = id_candidate_list[i][1][pred_label] 347 | with open('predictions.json', 'w') as f: 348 | json.dump(pred_dict, f) 349 | -------------------------------------------------------------------------------- /BAG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoyu-noob/BAG/180538a8e0de3a6a5465802a1e10feee4d564dd2/BAG.png -------------------------------------------------------------------------------- /BAG.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | import os 6 | import math 7 | import argparse 8 | import sys 9 | import json 10 | 11 | from nltk.tokenize import TweetTokenizer 12 | 13 | from utils.ConfigLogger import config_logger 14 | from utils.str2bool import str2bool 15 | from utils.Dataset import Dataset 16 | 17 | """The model class 18 | """ 19 | class Model: 20 | 21 | def __init__(self, use_elmo, use_glove, use_extra_feature, encoding_size, pos_emb_size, ner_emb_size, pos_dict_size, 22 | ner_dict_size, max_nodes=500, max_query_size=25, glove_dim=300, query_encoding_type='lstm'): 23 | # set placeholders for glove features when the glove feature is used 24 | self.nodes_glove, self.query_glove, self.nodes_ner, self.nodes_pos = None, None, None, None 25 | self.query_ner, self.query_pos, self.ner_dict_size, self.pos_dict_size = None, None, None, None 26 | self.use_glove, self.query_encoding_type, self.use_extra_feature = use_glove, query_encoding_type, use_extra_feature 27 | self.use_elmo = use_elmo 28 | self.ner_emb_size, self.pos_emb_size = ner_emb_size, pos_emb_size 29 | self.max_nodes, self.max_query_size = max_nodes, max_query_size 30 | self.encoding_size = encoding_size 31 | self.pos_dict_size, self.ner_dict_size = pos_dict_size, ner_dict_size 32 | 33 | """ Main function to get the model prediction""" 34 | def modelProcessing(self, query_length, adj, nodes_mask, bmask, nodes_elmo, query_elmo, nodes_glove, query_glove, 35 | nodes_ner, nodes_pos, query_ner, query_pos, dropout): 36 | 37 | ## obtain the multi-level feature for both nodes and query
38 | nodes_compress, query_compress = self.featureLayer(query_length, nodes_elmo, query_elmo, nodes_glove, query_glove, 39 | nodes_ner, nodes_pos, query_ner, query_pos) 40 | # create nodes 41 | nodes = nodes_compress * tf.expand_dims(nodes_mask, -1) 42 | 43 | ## using GCN to handle the features of nodes and get the transformed nodes representation 44 | nodes = tf.nn.dropout(nodes, dropout) 45 | last_hop = nodes 46 | for _ in range(hops): 47 | last_hop = self.GCNLayer(adj, last_hop, nodes_mask) # last_hop=(batch_size, max_nodes, node_feature_dim) 48 | 49 | ## Bi-directional attention flow is applied to calculate the attention result between nodes and query 50 | attentionFlowOutput = self.biAttentionLayer(query_compress, nodes_compress, last_hop) 51 | 52 | ## obtain the predictions 53 | predictions = self.outputLayer(attentionFlowOutput, bmask) 54 | return predictions 55 | 56 | """ Multi-level feature layer 57 | """ 58 | def featureLayer(self, query_length, nodes_elmo, query_elmo, nodes_glove, query_glove, nodes_ner, nodes_pos, 59 | query_ner, query_pos): 60 | # compress and flatten query 61 | with tf.variable_scope('feature_layer', reuse=tf.AUTO_REUSE): 62 | query_flat, nodes_flat = None, None 63 | if self.use_elmo: 64 | query_flat = tf.reshape(query_elmo, (-1, self.max_query_size, 3 * 1024)) 65 | nodes_flat = tf.reshape(nodes_elmo, (-1, self.max_nodes, 3 * 1024)) 66 | if self.use_glove: 67 | if query_flat is None: 68 | query_flat, nodes_flat = query_glove, nodes_glove 69 | else: 70 | query_flat = tf.concat((query_flat, query_glove), -1) 71 | nodes_flat = tf.concat((nodes_flat, nodes_glove), -1) 72 | query_compress = None 73 | if self.query_encoding_type == 'lstm': 74 | lstm_size = self.encoding_size // 2 75 | query_compress, output_state_fw, output_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn( 76 | cells_fw=[tf.nn.rnn_cell.LSTMCell(256), tf.nn.rnn_cell.LSTMCell(lstm_size)], 77 | cells_bw=[tf.nn.rnn_cell.LSTMCell(256), tf.nn.rnn_cell.LSTMCell(lstm_size)], 78 | inputs=query_flat, 79 | dtype=tf.float32, 80 | sequence_length=query_length 81 | ) 82 | elif self.query_encoding_type == 'linear': 83 | query_compress = tf.layers.dense(query_flat, units=self.encoding_size, activation=tf.nn.tanh) 84 | # query_compress = (batch_size, query_feature_dim) 85 | # query_compress = tf.concat((output_state_fw[-1].h, output_state_bw[-1].h), -1) 86 | nodes_compress = tf.layers.dense(nodes_flat, units=self.encoding_size, activation=tf.nn.tanh) 87 | 88 | ## concatenate POS and NER feature with encoded feature 89 | if self.use_extra_feature: 90 | ner_embeddings = tf.get_variable('ner_embeddings', [self.ner_dict_size, self.ner_emb_size]) 91 | pos_embeddings = tf.get_variable('pos_embeddings', [self.pos_dict_size, self.pos_emb_size]) 92 | query_ner_emb = tf.nn.embedding_lookup(ner_embeddings, query_ner) 93 | query_pos_emb = tf.nn.embedding_lookup(pos_embeddings, query_pos) 94 | nodes_ner_emb = tf.nn.embedding_lookup(ner_embeddings, nodes_ner) 95 | nodes_pos_emb = tf.nn.embedding_lookup(pos_embeddings, nodes_pos) 96 | # (batch_size, max_query_length, hidden_size + ner_emb_size + pos_emb_size) 97 | query_compress = tf.concat((query_compress, query_ner_emb, query_pos_emb), -1) 98 | # (batch_size, max_nodes, hidden_size + ner_emb_size + pos_emb_size) 99 | nodes_compress = tf.concat((nodes_compress, nodes_ner_emb, nodes_pos_emb), -1) 100 | return nodes_compress, query_compress 101 | 
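# --- Notes on the output layer below ---
# bmask has shape (batch_size, max_candidates, max_nodes) and is 1 where a node
# is a mention of that candidate. A 2-layer FFN gives each node a scalar score;
# a candidate's score is the max over its mention nodes, and entries where
# bmask * score == 0 are set to -inf so padded candidates cannot win the argmax
# or contribute to the softmax cross-entropy.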
102 | """ Output layer in BAG 103 | """ 104 | def outputLayer(self, attentionFlowOutput, bmask): 105 | with tf.variable_scope('output_layer', reuse=tf.AUTO_REUSE): 106 | ## two layer FFN 107 | ## The dimension of intermediate layer in following FFN is 128 for pre-trained model 108 | ## You can try to use 256 because I found it has a better performance on dev set. 109 | rawPredictions = tf.squeeze(tf.layers.dense(tf.layers.dense( 110 | attentionFlowOutput, units=128, activation=tf.nn.tanh), units=1), -1) 111 | 112 | predictions2 = bmask * tf.expand_dims(rawPredictions, 1) 113 | predictions2 = tf.where(tf.equal(predictions2, 0), 114 | tf.fill(tf.shape(predictions2), -np.inf), predictions2) 115 | predictions2 = tf.reduce_max(predictions2, -1) 116 | return predictions2 117 | 118 | """ Bi-directional attention layer in BAG 119 | """ 120 | def biAttentionLayer(self, query_compress, nodes_compress, last_hop): 121 | with tf.variable_scope('attention_flow', reuse=tf.AUTO_REUSE): 122 | # expanded tensors = (batch_size, max_nodes, max_query, feature_dim) 123 | expanded_query = tf.tile(tf.expand_dims(query_compress, -3), (1, self.max_nodes, 1, 1)) 124 | expanded_nodes = tf.tile(tf.expand_dims(last_hop, -2), (1, 1, self.max_query_size, 1)) 125 | context_query_similarity = expanded_nodes * expanded_query 126 | # concated_attention_data = (batch_size, max_nodes, max_query, feature_dim) 127 | concated_attention_data = tf.concat((expanded_nodes, expanded_query, context_query_similarity), -1) 128 | similarity = tf.reduce_mean(tf.layers.dense(concated_attention_data, units=1, use_bias=False), 129 | -1) # (batch_size, max_nodes, max_query) 130 | 131 | ## nodes to query = (batch_size, max_nodes, feature_dim) 132 | nodes2query = tf.matmul(tf.nn.softmax(similarity, -1), query_compress) 133 | ## query to nodes = (batch_size, 1, feature_dim), tiled to (batch_size, max_nodes, feature_dim) 134 | b = tf.nn.softmax(tf.reduce_max(similarity, -1), -1) # b = (batch_size, max_nodes) 135 | query2nodes = tf.matmul(tf.expand_dims(b, 1), nodes_compress) 136 | query2nodes = tf.tile(query2nodes, (1, self.max_nodes, 1)) 137 | G = tf.concat((nodes_compress, nodes2query, nodes_compress * nodes2query, nodes_compress * query2nodes), -1) 138 | # G = tf.concat((nodes_compress, nodes_compress * nodes2query, nodes_compress * query2nodes), -1) 139 | return G 140 | 141 | """ The GCN layer in BAG 142 | """ 143 | def GCNLayer(self, adj, hidden_tensor, hidden_mask): 144 | with tf.variable_scope('hop_layer', reuse=tf.AUTO_REUSE): 145 | 146 | adjacency_tensor = adj 147 | hidden_tensors = tf.stack([tf.layers.dense(inputs=hidden_tensor, units=hidden_tensor.shape[-1]) 148 | for _ in range(adj.shape[1])], 1) * \ 149 | tf.expand_dims(tf.expand_dims(hidden_mask, -1), 1) 150 | 151 | update = tf.reduce_sum(tf.matmul(adjacency_tensor, hidden_tensors), 1) + tf.layers.dense( 152 | hidden_tensor, units=hidden_tensor.shape[-1]) * tf.expand_dims(hidden_mask, -1) 153 | 154 | att = tf.layers.dense(tf.concat((update, hidden_tensor), -1), units=hidden_tensor.shape[-1], 155 | activation=tf.nn.sigmoid) * tf.expand_dims(hidden_mask, -1) 156 | 157 | output = att * tf.nn.tanh(update) + (1 - att) * hidden_tensor 158 | return output 159 | 160 | """The optimizer class 161 | """ 162 | class Optimizer: 163 | 164 | def __init__(self, model, use_elmo, use_glove, use_extra_feature, pos_dict_size, ner_dict_size, max_nodes=500, 165 | max_query_size=25, max_candidates=80, glove_dim=300, 166 | query_encoding_type='lstm', dynamic_change_lr=True, use_multi_gpu=False): 167 | self.original_learning_rate = tf.placeholder(dtype=tf.float64) 168 | self.epoch = tf.placeholder(dtype=tf.int32) 169 | 170 | self.nodes_length = tf.placeholder(shape=(None,), dtype=tf.int32) 171 | 
self.query_length = tf.placeholder(shape=(None,), dtype=tf.int32) 172 | 173 | # self.answer_node_mask = tf.placeholder(shape=(None, max_nodes), dtype=tf.float32) 174 | self.answer_candidates_id = tf.placeholder(shape=(None,), dtype=tf.int64) 175 | 176 | self.adj = tf.placeholder(shape=(None, 3, max_nodes, max_nodes), dtype=tf.float32) 177 | self.bmask = tf.placeholder(shape=(None, max_candidates, max_nodes), dtype=tf.float32) 178 | self.dropout = tf.placeholder(dtype=tf.float32) 179 | 180 | self.use_elmo, self.use_glove, self.use_extra_feature = use_elmo, use_glove, use_extra_feature 181 | 182 | # masks 183 | nodes_mask = tf.tile(tf.expand_dims(tf.range(max_nodes, dtype=tf.int32), 0), 184 | (tf.shape(self.nodes_length)[0], 1)) < tf.expand_dims(self.nodes_length, -1) 185 | self.nodes_mask = tf.cast(nodes_mask, tf.float32) 186 | 187 | if use_elmo: 188 | self.nodes_elmo = tf.placeholder(shape=(None, max_nodes, 3, 1024), dtype=tf.float32) 189 | self.query_elmo = tf.placeholder(shape=(None, max_query_size, 3, 1024), dtype=tf.float32) 190 | if use_glove: 191 | self.nodes_glove = tf.placeholder(shape=(None, max_nodes, glove_dim), dtype=tf.float32) 192 | self.query_glove = tf.placeholder(shape=(None, max_query_size, glove_dim), dtype=tf.float32) 193 | if use_extra_feature: 194 | self.nodes_ner = tf.placeholder(shape=(None, max_nodes,), dtype=tf.int32) 195 | self.nodes_pos = tf.placeholder(shape=(None, max_nodes,), dtype=tf.int32) 196 | self.query_ner = tf.placeholder(shape=(None, max_query_size,), dtype=tf.int32) 197 | self.query_pos = tf.placeholder(shape=(None, max_query_size,), dtype=tf.int32) 198 | self.ner_dict_size = ner_dict_size 199 | self.pos_dict_size = pos_dict_size 200 | self.predictions = model.modelProcessing(self.query_length, self.adj, self.nodes_mask, self.bmask, 201 | self.nodes_elmo, self.query_elmo, self.nodes_glove, self.query_glove, self.nodes_ner, self.nodes_pos, 202 | self.query_ner, self.query_pos, self.dropout) 203 | 204 | current_lr = self.original_learning_rate 205 | if dynamic_change_lr: 206 | current_lr = self.original_learning_rate / (1 + tf.floor(self.epoch / 5)) 207 | if not use_multi_gpu: 208 | cross_entropy = tf.losses.sparse_softmax_cross_entropy( 209 | self.answer_candidates_id, 210 | self.predictions, reduction=tf.losses.Reduction.NONE) 211 | self.loss = tf.reduce_mean(cross_entropy) 212 | self.train_step = tf.train.AdamOptimizer(learning_rate=current_lr).minimize(self.loss) 213 | else: 214 | gpu_indices = [0,1] 215 | gpu_num = 2 216 | split_feature = self.splitForMultiGpu(gpu_num) 217 | count = 0 218 | opt = tf.train.AdamOptimizer(learning_rate=current_lr) 219 | tower_grads, losses = [], [] 220 | for i in gpu_indices: 221 | with tf.device('/gpu:' + str(i)): 222 | with tf.name_scope('gpu_' + str(i)): 223 | nodes_elmo, query_elmo = None, None 224 | if self.use_elmo: 225 | nodes_elmo, query_elmo = split_feature['nodes_elmo'][count], split_feature['query_elmo'][count] 226 | nodes_glove, query_glove = None, None 227 | if self.use_glove: 228 | nodes_glove, query_glove = split_feature['nodes_glove'][count], \ 229 | split_feature['query_glove'][count] 230 | nodes_ner, nodes_pos, query_ner, query_pos = None, None, None, None 231 | if self.use_extra_feature: 232 | nodes_ner, nodes_pos = split_feature['nodes_ner'][count], split_feature['nodes_pos'][count] 233 | query_ner, query_pos = split_feature['query_ner'][count], split_feature['query_pos'][count] 234 | predictions = model.modelProcessing(split_feature['query_length'][count], 235 | split_feature['adj'][count], 
split_feature['nodes_mask'][count], split_feature['bmask'][count], 236 | nodes_elmo, query_elmo, nodes_glove, query_glove, nodes_ner, nodes_pos, query_ner, 237 | query_pos, self.dropout) 238 | cross_entropy = tf.losses.sparse_softmax_cross_entropy( 239 | split_feature['answer_candidates_id'][count], 240 | predictions, reduction=tf.losses.Reduction.NONE) 241 | loss = tf.reduce_mean(cross_entropy) 242 | losses.append(loss) 243 | tower_grads.append(opt.compute_gradients(loss)) 244 | count += 1 245 | self.loss = tf.reduce_mean(losses) 246 | grad = self.averageGradients(tower_grads) 247 | self.train_step = opt.apply_gradients(grad) 248 | 249 | """ Average the gradients calculated by different GPUs 250 | """ 251 | def averageGradients(self, tower_grads): 252 | average_grad = [] 253 | for grad_and_vars in zip(*tower_grads): 254 | grads = [] 255 | for g, _ in grad_and_vars: 256 | expend_g = tf.expand_dims(g, 0) 257 | grads.append(expend_g) 258 | grad = tf.concat(grads, 0) 259 | grad = tf.reduce_mean(grad, 0) 260 | v = grad_and_vars[0][1] 261 | grad_and_var = (grad, v) 262 | average_grad.append(grad_and_var) 263 | return average_grad 264 | 265 | """ Split training data for multiple GPUs (only 2 GPUs are supported for now) 266 | """ 267 | def splitForMultiGpu(self, gpu_num): 268 | 269 | query_length = tf.split(self.query_length, gpu_num) 270 | answer_candidates_id = tf.split(self.answer_candidates_id, gpu_num) 271 | adj = tf.split(self.adj, gpu_num) 272 | bmask = tf.split(self.bmask, gpu_num) 273 | nodes_mask = tf.split(self.nodes_mask, gpu_num) 274 | nodes_elmo, query_elmo = None, None 275 | nodes_glove, query_glove = None, None 276 | nodes_ner, nodes_pos, query_ner, query_pos = None, None, None, None 277 | if self.use_elmo: 278 | nodes_elmo = tf.split(self.nodes_elmo, gpu_num) 279 | query_elmo = tf.split(self.query_elmo, gpu_num) 280 | if self.use_glove: 281 | nodes_glove = tf.split(self.nodes_glove, gpu_num) 282 | query_glove = tf.split(self.query_glove, gpu_num) 283 | if self.use_extra_feature: 284 | nodes_ner = tf.split(self.nodes_ner, gpu_num) 285 | nodes_pos = tf.split(self.nodes_pos, gpu_num) 286 | query_ner = tf.split(self.query_ner, gpu_num) 287 | query_pos = tf.split(self.query_pos, gpu_num) 288 | return {'query_length': query_length, 'answer_candidates_id': answer_candidates_id, 'adj': adj, 'bmask': bmask, 289 | 'nodes_elmo': nodes_elmo, 'query_elmo': query_elmo, 'nodes_mask': nodes_mask, 'nodes_glove': nodes_glove, 290 | 'query_glove': query_glove, 'nodes_ner': nodes_ner, 'nodes_pos': nodes_pos, 'query_ner': query_ner, 291 | 'query_pos': query_pos} 292 | 
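# --- Notes on the multi-GPU path above ---
# splitForMultiGpu splits every feed tensor along the batch dimension, one
# tower per GPU; each tower computes its own loss and gradients, and
# averageGradients then averages the per-tower gradients variable by variable
# before a single apply_gradients step (synchronous data parallelism). The
# batch size therefore has to be divisible by the number of GPUs (2 here).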
293 | """ Check whether the preprocessed file exists in the current directory 294 | """ 295 | def checkPreprocessFile(file_name, add_query_node): 296 | preprocess_file_name = file_name 297 | if add_query_node: 298 | preprocess_file_name = preprocess_file_name + '.add_query_node' 299 | if not os.path.isfile('{}.preprocessed.pickle'.format(preprocess_file_name)): 300 | return preprocess_file_name, False 301 | return preprocess_file_name, True 302 | 303 | 304 | def runEvaluationStage(dev_dataset, session, use_elmo, use_glove, use_extra_feature, model=None, 305 | batch_size=16, save_json=False): 306 | finished = False 307 | eval_correct_count = 0 308 | eval_sample_count = 0 309 | eval_interval_count = 0 310 | answer_dict = {} 311 | while not finished: 312 | finished, batch = dev_dataset.next_batch(batch_size) 313 | feed_dict = {optimizer.nodes_length: batch['nodes_length_mb'], 314 | optimizer.query_length: batch['query_length_mb'], 315 | optimizer.adj: batch['adj_mb'], 316 | optimizer.bmask: batch['bmask_mb'], 317 | optimizer.dropout: 1} 318 | if use_elmo: 319 | feed_dict[optimizer.nodes_elmo] = batch['nodes_elmo_mb'] 320 | feed_dict[optimizer.query_elmo] = batch['query_elmo_mb'] 321 | if use_glove: 322 | feed_dict[optimizer.nodes_glove] = batch['nodes_glove_mb'] 323 | feed_dict[optimizer.query_glove] = batch['query_glove_mb'] 324 | if use_extra_feature: 325 | feed_dict[optimizer.nodes_pos] = batch['nodes_pos_mb'] 326 | feed_dict[optimizer.nodes_ner] = batch['nodes_ner_mb'] 327 | feed_dict[optimizer.query_ner] = batch['query_ner_mb'] 328 | feed_dict[optimizer.query_pos] = batch['query_pos_mb'] 329 | preds = np.argmax(session.run(optimizer.predictions, feed_dict), -1) 330 | eval_correct_count += (preds == batch['answer_candidates_id_mb']).sum() 331 | eval_sample_count += len(batch['query_length_mb']) 332 | eval_interval_count += len(batch['query_length_mb']) 333 | if eval_interval_count >= training_info_interval: 334 | logger.info('%s dev samples have been done, accuracy = %.3f', eval_sample_count, 335 | eval_correct_count / eval_sample_count) 336 | eval_interval_count -= training_info_interval 337 | if save_json: 338 | for index, id in enumerate(batch['id_mb']): 339 | answer_dict[id] = preds[index] 340 | return eval_correct_count / eval_sample_count, answer_dict 341 | 342 | def add_attribute_to_collection(model, optimizer): 343 | tf.add_to_collection('nodes_length', optimizer.nodes_length) 344 | tf.add_to_collection('query_length', optimizer.query_length) 345 | tf.add_to_collection('answer_candidates_id', optimizer.answer_candidates_id) 346 | tf.add_to_collection('adj', optimizer.adj) 347 | tf.add_to_collection('bmask', optimizer.bmask) 348 | tf.add_to_collection('train_step', optimizer.train_step) 349 | tf.add_to_collection('loss', optimizer.loss) 350 | tf.add_to_collection('predictions', optimizer.predictions) 351 | # tf.add_to_collection('predictions2', model.predictions2) 352 | tf.add_to_collection('original_learning_rate', optimizer.original_learning_rate) 353 | tf.add_to_collection('epoch', optimizer.epoch) 354 | tf.add_to_collection('use_elmo', optimizer.use_elmo) 355 | tf.add_to_collection('use_glove', optimizer.use_glove) 356 | tf.add_to_collection('use_extra_feature', optimizer.use_extra_feature) 357 | if model.use_elmo: 358 | tf.add_to_collection('nodes_elmo', optimizer.nodes_elmo) 359 | tf.add_to_collection('query_elmo', optimizer.query_elmo) 360 | if model.use_glove: 361 | tf.add_to_collection('nodes_glove', optimizer.nodes_glove) 362 | tf.add_to_collection('query_glove', optimizer.query_glove) 363 | if model.use_extra_feature: 364 | tf.add_to_collection('nodes_pos', optimizer.nodes_pos) 365 | tf.add_to_collection('nodes_ner', optimizer.nodes_ner) 366 | tf.add_to_collection('query_pos', optimizer.query_pos) 367 | tf.add_to_collection('query_ner', optimizer.query_ner) 368 | tf.add_to_collection('ner_dict_size', optimizer.ner_dict_size) 369 | tf.add_to_collection('pos_dict_size', optimizer.pos_dict_size) 370 | 371 | def write_best_accuray_txt(model_path, accuracy): 372 | file_name = model_path + '/best.txt' 373 | if os.path.isfile(file_name): 374 | os.remove(file_name) 375 | with open(file_name, 'w') as f: 376 | f.write(str(accuracy) + '\n') 377 | f.close() 378 | 379 | def generate_answer_json(answer_dict, in_file): 380 | with open(in_file, 'r') as f: 381 | data = json.load(f) 382 | final_answer_dict = {} 383 | for d in data: 384 | final_answer_dict[d['id']] = d['candidates'][answer_dict[d['id']]] 385 | with open('predictions.json', 'w') as f: 386 | json.dump(final_answer_dict, f) 387 | 
388 | class LoadedOptimizer: 389 | 390 | def __init__(self): 391 | self.train_step = tf.get_collection('train_step')[0] 392 | self.loss = tf.get_collection('loss')[0] 393 | self.original_learning_rate = tf.get_collection('original_learning_rate')[0] 394 | self.epoch = tf.get_collection('epoch')[0] 395 | self.nodes_length = tf.get_collection('nodes_length')[0] 396 | self.query_length = tf.get_collection('query_length')[0] 397 | self.answer_candidates_id = tf.get_collection('answer_candidates_id')[0] 398 | self.adj = tf.get_collection('adj')[0] 399 | self.bmask = tf.get_collection('bmask')[0] 400 | self.use_elmo = tf.get_collection('use_elmo')[0] 401 | self.use_glove = tf.get_collection('use_glove')[0] 402 | self.use_extra_feature = tf.get_collection('use_extra_feature')[0] 403 | self.predictions = tf.get_collection('predictions')[0] 404 | if self.use_elmo: 405 | self.nodes_elmo = tf.get_collection('nodes_elmo')[0] 406 | self.query_elmo = tf.get_collection('query_elmo')[0] 407 | if self.use_glove: 408 | self.nodes_glove = tf.get_collection('nodes_glove')[0] 409 | self.query_glove = tf.get_collection('query_glove')[0] 410 | if self.use_extra_feature: 411 | self.nodes_pos = tf.get_collection('nodes_pos')[0] 412 | self.nodes_ner = tf.get_collection('nodes_ner')[0] 413 | self.query_pos = tf.get_collection('query_pos')[0] 414 | self.query_ner = tf.get_collection('query_ner')[0] 415 | self.pos_dict_size = tf.get_collection('pos_dict_size')[0] 416 | self.ner_dict_size = tf.get_collection('ner_dict_size')[0] 417 | 
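# --- Notes on LoadedOptimizer above ---
# add_attribute_to_collection stores every placeholder and op in named tf
# collections when the graph is first built; LoadedOptimizer is its mirror
# image, pulling the same tensors back out of a restored graph with
# tf.get_collection so a saved checkpoint can be reused without rebuilding
# the model.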
418 | if __name__ == '__main__': 419 | parser = argparse.ArgumentParser() 420 | parser.add_argument('in_file', type=str) 421 | parser.add_argument('dev_file', type=str) 422 | parser.add_argument('--is_masked', type=str2bool, default=False, help='using masked data or not') 423 | parser.add_argument('--continue_train', type=str2bool, default=False, help='continue the last training using a saved model or' + 424 | ' start a new training') 425 | parser.add_argument('--cpu_number', type=int, default=4, help='The number of CPUs used in current script') 426 | parser.add_argument('--add_query_node', type=str2bool, default=False, help='Whether the entity in query is involved '+ 427 | 'in the construction of graph') 428 | parser.add_argument('--evaluation_mode', type=str2bool, default=False, help='Whether to use evaluation mode') 429 | parser.add_argument('--use_elmo', type=str2bool, default=True, help='Whether to use ELMo as a feature for each node') 430 | parser.add_argument('--use_glove', type=str2bool, default=True, help='Whether to use GloVe as a feature for each node') 431 | parser.add_argument('--use_extra_feature', type=str2bool, default=True, help='Whether to use extra features for ' + 432 | 'each node, e.g. NER and POS') 433 | parser.add_argument('--use_multi_gpu', type=str2bool, default=False, help='Whether to use multiple GPUs') 434 | parser.add_argument('--lr', type=float, default=2e-4, help='Initial learning rate') 435 | parser.add_argument('--hop_num', type=int, default=5, help='Hop num in GCN layer, in other words the depth of GCN') 436 | parser.add_argument('--epochs', type=int, default=50, help='Epoch number for the training') 437 | parser.add_argument('--batch_size', type=int, default=32, help='Batch size') 438 | parser.add_argument('--info_interval', type=int, default=1000, help='The interval to display training loss info') 439 | parser.add_argument('--dropout', type=float, default=0.8, help="Keep rate for dropout in model") 440 | parser.add_argument('--encoding_size', type=int, default=512, help='The encoding output size for both query and nodes') 441 | parser.add_argument('--pos_emb_size', type=int, default=8, help='The size of POS embedding') 442 | parser.add_argument('--ner_emb_size', type=int, default=8, help='The size of NER embedding') 443 | 444 | args = parser.parse_args() 445 | in_file = args.in_file 446 | dev_file = args.dev_file 447 | is_masked = args.is_masked 448 | continue_train = args.continue_train 449 | evaluation_mode = args.evaluation_mode 450 | cpu_number = args.cpu_number 451 | add_query_node = args.add_query_node 452 | use_elmo = args.use_elmo 453 | use_glove = args.use_glove 454 | use_extra_feature = args.use_extra_feature 455 | use_multi_gpu = args.use_multi_gpu 456 | learning_rate = args.lr 457 | hops = args.hop_num 458 | epochs = args.epochs 459 | batch_size = args.batch_size 460 | training_info_interval = args.info_interval 461 | dropout = args.dropout 462 | encoding_size = args.encoding_size 463 | pos_emb_size = args.pos_emb_size 464 | ner_emb_size = args.ner_emb_size 465 | 466 | options_file = 'data/elmo_2x4096_512_2048cnn_2xhighway_options.json' 467 | weight_file = 'data/elmo_2x4096_512_2048cnn_2xhighway_weights' 468 | encoding_type_map = {'lstm':'lstm','linear':'linear'} 469 | 470 | model_name = 'BAG' 471 | if evaluation_mode: 472 | logger = config_logger('evaluation/' + model_name) 473 | else: 474 | logger = config_logger('BAG') 475 | 476 | model_path = os.getcwd() + '/models/' + model_name + '/' 477 | if not os.path.exists(model_path): 478 | os.makedirs(model_path) 479 | 480 | tokenize = TweetTokenizer().tokenize 481 | logger.info('Hop number is %s', hops) 482 | logger.info('Learning rate is %s', learning_rate) 483 | logger.info('Training epoch is %s', epochs) 484 | logger.info('Batch size is %s', batch_size) 485 | logger.info('Dropout rate is %f', dropout) 486 | logger.info('Encoding size for nodes and query feature is %s', encoding_size) 487 | query_feature_type = encoding_type_map['lstm'] 488 | logger.info('Encoding type for query feature is %s', query_feature_type) 489 | dynamic_change_learning_rate = True 490 | logger.info('Is learning rate changing along with epoch count: %s', dynamic_change_learning_rate) 491 | 492 | tf.reset_default_graph() 493 | 494 | train_file_name_prefix, fileExist = checkPreprocessFile(in_file, add_query_node) 495 | if not fileExist: 496 | logger.info('Cannot find preprocessed data %s, program will shut down.', 497 | '{}.preprocessed.pickle'.format(train_file_name_prefix)) 498 | sys.exit() 499 | dev_file_name_prefix, fileExist = checkPreprocessFile(dev_file, add_query_node) 500 | if not fileExist: 501 | logger.info('Cannot find preprocessed data %s, program will shut down.', 502 | 
'{}.preprocessed.pickle'.format(dev_file_name_prefix)) 503 | sys.exit() 504 | if not evaluation_mode: 505 | logger.info('Loading preprocessed training data file %s', '{}.preprocessed.pickle'.format(train_file_name_prefix)) 506 | dataset = Dataset(train_file_name_prefix, use_elmo, use_glove, use_extra_feature, max_nodes=500, 507 | max_query_size=25, max_candidates=80, max_candidates_len=10) 508 | logger.info('Loading preprocessed development data file %s', '{}.preprocessed.pickle'.format(dev_file_name_prefix)) 509 | dev_dataset = Dataset(dev_file_name_prefix, use_elmo, use_glove, use_extra_feature, max_nodes=500, 510 | max_query_size=25, max_candidates=80, max_candidates_len=10) 511 | else: 512 | logger.info('Loading preprocessed evaluation data file %s', 513 | '{}.preprocessed.pickle'.format(dev_file_name_prefix)) 514 | dataset = Dataset(dev_file_name_prefix, use_elmo, use_glove, use_extra_feature, max_nodes=500, 515 | max_query_size=25, max_candidates=80, max_candidates_len=10) 516 | 517 | pos_dict_size, ner_dict_size = 0, 0 518 | if use_extra_feature: 519 | pos_dict_size, ner_dict_size = dataset.getPosAndNerDictSize() 520 | 521 | config = tf.ConfigProto(device_count={'GPU': 2, 'CPU': cpu_number}, allow_soft_placement=True) 522 | # config.gpu_options.per_process_gpu_memory_fraction = memory_fraction 523 | session = tf.Session(config=config) 524 | 525 | if (not continue_train) and (not evaluation_mode): 526 | model = Model(use_elmo, use_glove, use_extra_feature, encoding_size, pos_emb_size, ner_emb_size, pos_dict_size, 527 | ner_dict_size) 528 | optimizer = Optimizer(model, use_elmo, use_glove, use_extra_feature, pos_dict_size, ner_dict_size, 529 | dynamic_change_lr=dynamic_change_learning_rate, use_multi_gpu=use_multi_gpu) 530 | logger.info('Start a new training') 531 | saver = tf.train.Saver() 532 | session.run(tf.global_variables_initializer()) 533 | session_op = [optimizer.train_step, optimizer.loss] 534 | add_attribute_to_collection(model, optimizer) 535 | else: 536 | logger.info('Using previously trained model in /models') 537 | model = Model(use_elmo, use_glove, use_extra_feature, encoding_size, pos_emb_size, ner_emb_size, pos_dict_size, 538 | ner_dict_size) 539 | optimizer = Optimizer(model, use_elmo, use_glove, use_extra_feature, pos_dict_size, ner_dict_size, 540 | dynamic_change_lr=dynamic_change_learning_rate, use_multi_gpu=use_multi_gpu) 541 | saver = tf.train.Saver() 542 | saver.restore(session, './models/' + model_name + '/model.ckpt') 543 | session_op = [optimizer.train_step, optimizer.loss] 544 | 545 | if not evaluation_mode: 546 | logger.info('=============================') 547 | logger.info('Starting Training.....') 548 | best_accuracy = 0 549 | best_test_accuracy = 0 550 | for i in range(epochs): 551 | logger.info('=============================') 552 | logger.info('Starting Training Epoch %s', i + 1) 553 | sample_count = 0 554 | epoch_finished = False 555 | interval_count = 0 556 | loss_count = 0 557 | loss_sum = 0 558 | problem_index = [] 559 | while not epoch_finished: 560 | epoch_finished, batch = dataset.next_batch(batch_dim=batch_size, use_multi_gpu=use_multi_gpu) 561 | feed_dict = {optimizer.nodes_length: batch['nodes_length_mb'], 562 | optimizer.query_length: batch['query_length_mb'], 563 | # model.answer_node_mask: batch['answer_node_mask_mb'], 564 | optimizer.answer_candidates_id: batch['answer_candidates_id_mb'], 565 | optimizer.adj: batch['adj_mb'], 566 | optimizer.bmask: batch['bmask_mb'], 567 | optimizer.original_learning_rate: learning_rate, 568 | 
optimizer.epoch: i, optimizer.dropout: dropout} 569 | if use_elmo: 570 | feed_dict[optimizer.nodes_elmo] = batch['nodes_elmo_mb'] 571 | feed_dict[optimizer.query_elmo] = batch['query_elmo_mb'] 572 | if use_glove: 573 | feed_dict[optimizer.nodes_glove] = batch['nodes_glove_mb'] 574 | feed_dict[optimizer.query_glove] = batch['query_glove_mb'] 575 | if use_extra_feature: 576 | feed_dict[optimizer.nodes_pos] = batch['nodes_pos_mb'] 577 | feed_dict[optimizer.nodes_ner] = batch['nodes_ner_mb'] 578 | feed_dict[optimizer.query_ner] = batch['query_ner_mb'] 579 | feed_dict[optimizer.query_pos] = batch['query_pos_mb'] 580 | 581 | _, loss = session.run(session_op, feed_dict=feed_dict) 582 | sample_count += len(batch['query_length_mb']) 583 | interval_count += len(batch['query_length_mb']) 584 | if not math.isinf(loss): 585 | loss_sum += loss 586 | loss_count += 1 587 | if interval_count >= training_info_interval: 588 | avg_loss = 100 589 | if loss_count != 0: 590 | avg_loss = loss_sum / loss_count 591 | logger.info('%s training samples have been done, loss = %.5f', sample_count, avg_loss) 592 | interval_count -= training_info_interval 593 | loss_sum = 0 594 | loss_count = 0 595 | 596 | logger.info('Epoch %s has finished', i + 1) 597 | logger.info('-----------------------------') 598 | logger.info('Running the evaluation stage') 599 | accuracy, _ = runEvaluationStage(dev_dataset, session, use_elmo, use_glove, use_extra_feature, 600 | model=model) 601 | if accuracy > best_accuracy: 602 | logger.info('Evaluation stage finished') 603 | logger.info('Current model beats the previous best accuracy on dev, previous=%.4f, current=%.4f', 604 | best_accuracy, accuracy) 605 | best_accuracy = accuracy 606 | saver.save(session, model_path + 'model.ckpt') 607 | write_best_accuray_txt(model_path, best_accuracy) 608 | logger.info('Current best model has been saved.') 609 | else: 610 | logger.info('Evaluation stage finished') 611 | logger.info( 612 | 'Current model did not beat the previous best accuracy on dev, previous=%.3f, current=%.3f', 613 | best_accuracy, accuracy) 614 | logger.info('=============================') 615 | else: 616 | logger.info('=============================') 617 | logger.info('Starting Evaluation.....') 618 | accuracy, answer_dict = runEvaluationStage(dataset, session, use_elmo, use_glove, use_extra_feature, 619 | model=model, save_json=True) 620 | generate_answer_json(answer_dict, in_file) 621 | logger.info('Current accuracy=%.5f', accuracy) 622 | logger.info('Evaluation stage finished') 623 | logger.info('=============================') 624 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BAG 2 | Implementation for NAACL-2019 paper 3 | 4 | **BAG: Bi-directional Attention Entity Graph Convolutional Network for Multi-hop Reasoning Question Answering** 5 | 6 | [paper link](http://arxiv.org/abs/1904.04969) 7 | 8 | ![BAG Framework](https://github.com/caoyu1991/BAG/blob/master/BAG.png) 9 | 10 | ## Requirement 11 | We provide entry points in both the **_TensorFlow version_** (BAG.py) and the **_PyTorch version_** (BAG-pytorch.py) 12 | 1. Python 3.6 13 | 2. TensorFlow == 1.11.0 (required for the TF script; we are not sure whether it works with higher versions) 14 | 3. Pytorch >= 1.1.0 15 | 4. SpaCy >= 2.0.12 (You need to install the "en" model via "python -m spacy download en") 16 | 5. allennlp >= 0.7.1 17 | 6. nltk >= 3.3 18 | 7. pytorch-ignite (if you need to run the PyTorch version of the script) 19 | 20 | And some other packages 21 | 22 | We ran it using two NVIDIA GTX 1080Ti GPUs, each with 11GB of memory. To run it with the 23 | default batch size of 32, at least 16GB of GPU memory is needed. To run the preprocessing 24 | procedure on the whole dataset, at least 50GB of system memory is needed. 25 | 26 | ## How to run 27 | - Before running 28 | 29 | You need to download the pretrained 840B 300d [GloVe embeddings](http://nlp.stanford.edu/data/glove.840B.300d.zip), 30 | and the pretrained original-size ELMo embedding [weights](https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5) 31 | and [options](https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json) and put them under the directory _/data_. 32 | 33 | - Preprocessing dataset 34 | 35 | You need to download the [QAngaroo WIKIHOP dataset](https://drive.google.com/file/d/1ytVZ4AhubFDOEL7o7XrIRIyhU8g9wvKA/view) 36 | , unzip it and put the json files under the root directory. Then run the preprocessing script 37 | 38 | `python prepro.py {json_file_name}` 39 | 40 | It will generate four preprocessed pickle files in the root directory which will be used 41 | in training or prediction. 42 | 43 | - Train the model 44 | 45 | Train the model using the following command, which follows the configuration 46 | in the original paper 47 | 48 | `python BAG.py {train_json_file_name} {dev_json_file_name} --use_multi_gpu=true` 49 | 50 | or in the PyTorch version (we do not provide multi-GPU support for PyTorch yet; the simplest way is to wrap the model with 51 | nn.DataParallel, as in the sketch below) 52 | 53 | `python BAG-pytorch.py {train_json_file_name} {dev_json_file_name}` 54 | 
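A minimal, untested sketch of that wrapping (a hypothetical patch around the existing `model = Model(args)` line in BAG-pytorch.py; `args` and `device` are the objects the script already builds):

```python
model = Model(args)
if torch.cuda.device_count() > 1:
    # DataParallel replicates the model on every visible GPU, splits each
    # input batch along dim 0 and gathers the predictions back afterwards.
    model = nn.DataParallel(model)
model.to(device)
```

With this wrapping the per-GPU batch becomes batch_size / n_gpus, so activation memory is roughly split across cards.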
55 | Please make sure you have run preprocessing for both the train file and the dev 56 | file before training, and that you have CUDA devices 0 and 1 available. 57 | If you have a single GPU with more than 16GB of memory, you can remove the parameter 58 | _--use_multi_gpu_. 59 | 60 | - Predict 61 | 62 | After training, the trained model will be saved under the directory _/models_. 63 | You can predict the answers for a json file using the following command 64 | 65 | `python BAG.py {predict_json_file_name} {predict_json_file_name} --use_multi_gpu=true --evaluation_mode=true` 66 | 67 | - Trained model 68 | 69 | Anyone who needs the trained model in our submission can find it on [Codalab](https://worksheets.codalab.org/bundles/0x26949d12bc5845c2a341b2ede40986f1) (only the TF version is available) 70 | 71 | ## Acknowledgement 72 | 73 | We would like to thank [Nicola De Cao](https://nicola-decao.github.io/) for his assistance in implementing this project. 
74 | 75 | ## Reference 76 | ``` 77 | @inproceedings{cao2019bag, 78 | title={BAG: Bi-directional Attention Entity Graph Convolutional Network for Multi-hop Reasoning Question Answering}, 79 | author={Cao, Yu and Fang, Meng and Tao, Dacheng}, 80 | booktitle={Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers)}, 81 | pages={357--362}, 82 | year={2019} 83 | } 84 | ``` 85 | -------------------------------------------------------------------------------- /data/ner_dict.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoyu-noob/BAG/180538a8e0de3a6a5465802a1e10feee4d564dd2/data/ner_dict.pickle -------------------------------------------------------------------------------- /data/pos_dict.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caoyu-noob/BAG/180538a8e0de3a6a5465802a1e10feee4d564dd2/data/pos_dict.pickle -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | import pickle 5 | import numpy as np 6 | import nltk 7 | import spacy 8 | import unicodedata 9 | 10 | from allennlp.commands.elmo import ElmoEmbedder 11 | from progressbar import ProgressBar, Percentage, Bar, Timer, ETA 12 | from utils.str2bool import str2bool 13 | from utils.ConfigLogger import config_logger 14 | from nltk.tokenize import TweetTokenizer 15 | 16 | nltk.internals.config_java(options='-Xmx4g') 17 | 18 | class Preprocesser: 19 | def __init__(self, file_name, logger, is_masked=False, use_elmo=True, use_glove=True, use_extra_feature=True, 20 | use_bert=True, options_file=None, weight_file=None, glove_file=None): 21 | self.file_name = file_name 22 | self.is_masked = is_masked 23 | self.use_elmo = use_elmo 24 | self.use_glove = use_glove 25 | self.use_extra_feature = use_extra_feature 26 | self.use_bert = use_bert 27 | self.options_file = options_file 28 | self.weight_file = weight_file 29 | self.glove_file = glove_file 30 | self.tokenizer = TweetTokenizer() 31 | self.nlp = spacy.load('en', disable=['vectors', 'textcat', 'parser']) 32 | self.tag_dict = {'': 0, '': 1, '': 2, '': 3} 33 | self.ner_dict = {'': 0, '': 1, '': 2, '': 3} 34 | self.logger = logger 35 | self.elmo_split_interval = 16 36 | self.bert_split_interval = 4 37 | self.max_support_length = 512 38 | 39 | """ check whether string of current span matches the candidate 40 | """ 41 | def check(self, support, word_index, candidate, for_unmarked=False): 42 | if for_unmarked: 43 | return sum( 44 | [self.is_contain_special_symbol(c_, support[word_index + j].lower()) for j, c_ in enumerate(candidate) if 45 | word_index + j < len(support)]) == len(candidate) 46 | else: 47 | return sum([support[word_index + j].lower() == c_ for j, c_ in enumerate(candidate) if 48 | word_index + j < len(support)]) == len(candidate) 49 | 50 | 51 | def is_contain_special_symbol(self, candidate_tok, support_tok): 52 | if candidate_tok.isdigit(): 53 | return support_tok.find(candidate_tok) >= 0 54 | else: 55 | return support_tok == candidate_tok or candidate_tok + 's' == support_tok or \ 56 | (support_tok.find(candidate_tok) >= 0 and ( 57 | support_tok.find('-') > 0 or support_tok.find('\'s') > 0 or 58 | support_tok.find(',') > 0)) 59 | 60 | """ Check 
60 |     """ Check whether the mask is valid via its length
61 |     """
62 |     def check_masked(self, support, word_index, candidate):
63 |         return sum([support[word_index + j] == c_ for j, c_ in enumerate(candidate) if
64 |                     word_index + j < len(support)]) == len(candidate)
65 | 
66 |     """ generating the index for candidates in the original document
67 |     """
68 |     def ind(self, support_index, word_index, candidate_index, candidate, marked_candidate):
69 |         marked_candidate[candidate_index] = True
70 |         return [[support_index, word_index + i, candidate_index] for i in range(len(candidate))]
71 | 
72 |     """ some candidates may not be found in the original document, so we have to merge their masks with the node masks
73 |         that were found in the original document
74 |     """
75 |     def merge_two_masks(self, mask, unmarked_mask):
76 |         for i in range(len(mask)):
77 |             if len(unmarked_mask[i]) != 0:
78 |                 if len(mask[i]) == 0:
79 |                     mask[i] = unmarked_mask[i]
80 |                 else:
81 |                     for unmarked_index in range(len(unmarked_mask[i])):
82 |                         mask[i].append(unmarked_mask[i][unmarked_index])
83 |                     mask[i].sort(key=lambda x: x[0][1])
84 |         return mask
85 | 
86 |     """ if some new POS or NER tags are found in the data, we need to merge them into the previous POS or NER dict
87 |     """
88 |     def mergeTwoDictFile(self, file_name, new_dict):
89 |         with open(file_name, 'rb') as f:
90 |             prev_dict = pickle.load(f)
91 |         for name in new_dict:
92 |             if not prev_dict.__contains__(name):
93 |                 prev_dict[name] = len(prev_dict)
94 |         with open(file_name, 'wb') as f:
95 |             pickle.dump(prev_dict, f)
96 | 
97 |     """ The main function to preprocess the WIKIHOP dataset and save it as several pickle files
98 |     """
99 |     def preprocess(self):
100 |         preprocess_graph_file_name = self.file_name
101 |         preprocess_elmo_file_name = self.file_name + '.elmo'
102 |         preprocess_glove_file_name = self.file_name + '.glove'
103 |         preprocess_extra_file_name = self.file_name + '.extra'
104 |         supports = self.doPreprocessForGraph(preprocess_graph_file_name)
105 |         with open('{}.preprocessed.pickle'.format(preprocess_graph_file_name), 'rb') as f:
106 |             data_graph = [d for d in pickle.load(f)]
107 |         # text data including supporting documents, queries and node masks
108 |         text_data = []
109 |         for index, graph_d in enumerate(data_graph):
110 |             tmp = {}
111 |             tmp['query'] = graph_d['query']
112 |             tmp['query_full_token'] = graph_d['query_full_token']
113 |             tmp['nodes_mask'] = graph_d['nodes_mask']
114 |             tmp['candidates'] = graph_d['candidates']
115 |             tmp['nodes_candidates_id'] = graph_d['nodes_candidates_id']
116 |             tmp['supports'] = supports[index]['supports']
117 |             text_data.append(tmp)
118 |         if self.use_elmo:
119 |             self.doPreprocessForElmo(text_data, preprocess_elmo_file_name)
120 |         if self.use_glove:
121 |             self.doPreprocessForGlove(text_data, preprocess_glove_file_name)
122 |         if self.use_extra_feature:
123 |             self.doPreprocessForExtraFeature(text_data, preprocess_extra_file_name)
124 | 
125 |     """ Build the entity graph based on the input json data and save the graph as a pickle
126 |     """
127 |     def doPreprocessForGraph(self, preprocess_graph_file_name):
128 |         with open(self.file_name, 'r') as f:
129 |             data = json.load(f)
130 |         self.logger.info('Load json file:' + self.file_name)
131 |         supports = self.doPreprocess(data, mode='supports')
132 |         if not os.path.isfile('{}.preprocessed.pickle'.format(preprocess_graph_file_name)):
133 |             self.logger.info('Preprocessing Json data for Graph....')
134 |             data = self.doPreprocess(data, mode='graph', supports=supports)
135 |             self.logger.info('Preprocessing Graph data finished')
136 |             with open('{}.preprocessed.pickle'.format(preprocess_graph_file_name), 'wb') as f:
137 |                 pickle.dump(data, f)
138 |             self.logger.info('Successfully save preprocessed Graph data file %s',
139 |                 '{}.preprocessed.pickle'.format(preprocess_graph_file_name))
140 |         else:
141 |             self.logger.info('Preprocessed Graph data already exists, no preprocessing will be executed.')
142 |         return supports
143 | 
144 |     """ Generate the pickle file for ELMo embeddings of queries and nodes in the graph
145 |     """
146 |     def doPreprocessForElmo(self, text_data, preprocess_elmo_file_name):
147 |         if not os.path.isfile('{}.preprocessed.pickle'.format(preprocess_elmo_file_name)):
148 |             elmoEmbedder = ElmoEmbedder(cuda_device=0, options_file=self.options_file, weight_file=self.weight_file)
149 |             self.logger.info('Preprocessing Json data for Elmo....')
150 |             data = self.doPreprocess(text_data, mode='elmo', ee=elmoEmbedder)
151 |             self.logger.info('Preprocessing Elmo data finished')
152 |             with open('{}.preprocessed.pickle'.format(preprocess_elmo_file_name), 'wb') as f:
153 |                 pickle.dump(data, f)
154 |             self.logger.info('Successfully save preprocessed Elmo data file %s',
155 |                 '{}.preprocessed.pickle'.format(preprocess_elmo_file_name))
156 |         else:
157 |             self.logger.info('Preprocessed Elmo data already exists, no preprocessing will be executed.')
158 | 
159 |     """ Generate the pickle file for GLoVe embeddings of queries and nodes in the graph
160 |     """
161 |     def doPreprocessForGlove(self, text_data, preprocess_glove_file_name):
162 |         if self.use_glove:
163 |             if not os.path.isfile('{}.preprocessed.pickle'.format(preprocess_glove_file_name)):
164 |                 self.logger.info('Building vocabulary dictionary....')
165 |                 vocab2index, index2vocab = self.buildVocabMap(text_data)
166 |                 self.logger.info('Building GloVe Embedding Map....')
167 |                 gloveEmbMap = self.buildGloveEmbMap(vocab2index)
168 |                 self.logger.info('Preprocessing Json data for Glove....')
169 |                 data_glove = self.doPreprocess(text_data, mode='glove', gloveEmbMap=gloveEmbMap, vocab2index=vocab2index)
170 |                 with open('{}.preprocessed.pickle'.format(preprocess_glove_file_name), 'wb') as f:
171 |                     pickle.dump(data_glove, f)
172 |                 self.logger.info('Successfully save preprocessed Glove data file %s',
173 |                     '{}.preprocessed.pickle'.format(preprocess_glove_file_name))
174 |             else:
175 |                 self.logger.info('Preprocessed Glove data already exists, no preprocessing will be executed.')
176 | 
177 |     """ Generate the pickle file for extra features (NER, POS) of nodes and queries
178 |     """
179 |     def doPreprocessForExtraFeature(self, text_data, preprocess_extra_file_name):
180 |         if not os.path.isfile('{}.preprocessed.pickle'.format(preprocess_extra_file_name)):
181 |             data_extra = self.doPreprocess(text_data, mode='extra')
182 |             with open('{}.preprocessed.pickle'.format(preprocess_extra_file_name), 'wb') as f:
183 |                 pickle.dump(data_extra[0], f)
184 |             self.logger.info('Successfully save preprocessed Extra feature data file %s',
185 |                 '{}.preprocessed.pickle'.format(preprocess_extra_file_name))
186 |             if not os.path.isfile('data/pos_dict.pickle'):
187 |                 with open('data/pos_dict.pickle', 'wb') as f:
188 |                     pickle.dump(data_extra[2], f)
189 |                 self.logger.info('Successfully save pos dict data file pos_dict.pickle')
190 |             else:
191 |                 self.mergeTwoDictFile('data/pos_dict.pickle', data_extra[2])
192 |                 self.logger.info('Successfully merge current pos dict with pos_dict.pickle')
193 |             if not os.path.isfile('data/ner_dict.pickle'):
194 |                 with open('data/ner_dict.pickle', 'wb') as f:
195 |                     pickle.dump(data_extra[1], f)
196 |                 self.logger.info('Successfully save ner dict data file ner_dict.pickle')
197 |             else:
198 |                 self.mergeTwoDictFile('data/ner_dict.pickle', data_extra[1])
199 |                 self.logger.info('Successfully merge current ner dict with ner_dict.pickle')
200 |         else:
201 |             self.logger.info('Preprocessed Extra data already exists, no preprocessing will be executed.')
202 | 
203 |     """ Core preprocessing function
204 |     """
205 |     def doPreprocess(self, data_mb, mode, supports=None, ee=None, gloveEmbMap=None, vocab2index=None, bert_model=None):
206 |         data_gen = []
207 |         widgets = ['Progress: ', Percentage(), ' ', Bar('#'), ' ', Timer(),
208 |                    ' ', ETA(), ' ']
209 |         pbar = ProgressBar(widgets=widgets, maxval=len(data_mb)).start()
210 | 
211 |         data_count = 0
212 |         for index, data in enumerate(data_mb):
213 |             # if index < 19999:
214 |             #     data_count += 1
215 |             #     continue
216 |             # try:
217 |             if mode == 'supports':
218 |                 tmp = {}
219 |                 tmp['supports'] = [self.tokenizer.tokenize(support) for support in data['supports']]
220 |                 for index in range(len(tmp['supports'])):
221 |                     if len(tmp['supports'][index]) > self.max_support_length:
222 |                         tmp['supports'][index] = tmp['supports'][index][:self.max_support_length]
223 |                 data_gen.append(tmp)
224 |             elif mode == 'graph':
225 |                 preprocessGraphData = self.preprocessForGraph(data, supports[index]['supports'])
226 |                 data_gen.append(preprocessGraphData)
227 |             elif mode == 'elmo':
228 |                 preprocessElmoData = self.preprocessForElmo(data, ee)
229 |                 data_gen.append(preprocessElmoData)
230 |             elif mode == 'glove':
231 |                 preprocessGloveData = self.preprocessForGlove(data, gloveEmbMap, vocab2index)
232 |                 data_gen.append(preprocessGloveData)
233 |             elif mode == 'extra':
234 |                 preprocessExtraData = self.preprocessForExtra(data)
235 |                 data_gen.append(preprocessExtraData)
236 |             elif mode == 'bert':
237 |                 preprocessBertData = self.preprocessForBert(data, bert_model)
238 |                 data_gen.append(preprocessBertData)
239 |             # except:
240 |             #     print(index)
241 |             #     pass
242 |             data_count += 1
243 |             pbar.update(data_count)
244 |             # if data_count >= 96:
245 |             #     break
246 |         pbar.finish()
247 |         if mode == 'extra':
248 |             return [data_gen, self.ner_dict, self.tag_dict]
249 |         return data_gen
250 | 
251 |     """ build the vocabulary map for all tokens included in the dataset
252 |     """
253 |     def buildVocabMap(self, data_elmo):
254 |         vocab2index, index2vocab = {}, {}
255 |         count = 1
256 |         vocab2index['unk'] = 0
257 |         index2vocab[0] = 'unk'
258 |         for data_mb in data_elmo:
259 |             for candidate in data_mb['candidates']:
260 |                 for token in candidate:
261 |                     if not vocab2index.__contains__(token):
262 |                         vocab2index[token] = count
263 |                         index2vocab[count] = token
264 |                         count += 1
265 |             for token in data_mb['query']:
266 |                 if not vocab2index.__contains__(token):
267 |                     vocab2index[token] = count
268 |                     index2vocab[count] = token
269 |                     count += 1
270 |             for token in data_mb['query_full_token']:
271 |                 if not vocab2index.__contains__(token):
272 |                     vocab2index[token] = count
273 |                     index2vocab[count] = token
274 |                     count += 1
275 |         return vocab2index, index2vocab
276 | 
277 |     """ The core function to build the graph
278 |     """
279 |     def preprocessForGraph(self, data, supports):
280 |         if data.__contains__('annotations'):
281 |             data.pop('annotations')
282 | 
283 |         ## The first token in the query is joined with underscores, so we split it into separate words by
284 |         ## removing the underscores
285 |         first_blank_pos = data['query'].find(' ')
286 |         if first_blank_pos > 0:
287 |             first_token_in_query = data['query'][:first_blank_pos]
288 |         else:
289 |             first_token_in_query = data['query']
290 |         query = data['query'].replace('_', ' ')
291 |         data['query'] = self.tokenizer.tokenize(query)
292 |         ## query_full_token is the query string whose relation word has been split on '_'
293 |         data['query_full_token'] = query
294 | 
295 |         candidates_orig = list(data['candidates'])
296 | 
297 |         data['candidates'] = [self.tokenizer.tokenize(candidate) for candidate in data['candidates']]
298 | 
299 |         marked_candidate = {}
300 | 
301 |         ## find all matched candidates in the documents and mark their positions
302 |         if self.is_masked:
303 |             mask = [[self.ind(sindex, windex, cindex, candidate, marked_candidate)
304 |                      for windex, word_support in enumerate(support) for cindex, candidate in
305 |                      enumerate(data['candidates'])
306 |                      if self.check_masked(support, windex, candidate)] for sindex, support in
307 |                     enumerate(supports)]
308 |         else:
309 |             mask = [[self.ind(sindex, windex, cindex, candidate, marked_candidate)
310 |                      for windex, word_support in enumerate(support) for cindex, candidate in
311 |                      enumerate(data['candidates'])
312 |                      if self.check(support, windex, candidate)] for sindex, support in enumerate(supports)]
313 |         tok_unmarked_candidates = []
314 |         unmarked_candidates_index_map = {}
315 |         for candidate_index in range(len(data['candidates'])):
316 |             if not marked_candidate.__contains__(candidate_index):
317 |                 tok_unmarked_candidates.append(data['candidates'][candidate_index])
318 |                 unmarked_candidates_index_map[len(tok_unmarked_candidates) - 1] = candidate_index
319 |         if len(tok_unmarked_candidates) != 0:
320 |             unmarked_mask = [
321 |                 [self.ind(sindex, windex, unmarked_candidates_index_map[cindex], candidate, marked_candidate)
322 |                  for windex, word_support in enumerate(support) for cindex, candidate in
323 |                  enumerate(tok_unmarked_candidates)
324 |                  if self.check(support, windex, candidate, for_unmarked=True)] for sindex, support in
325 |                 enumerate(supports)]
326 |             mask = self.merge_two_masks(mask, unmarked_mask)
327 | 
328 |         nodes_id_name = []
329 |         count = 0
330 |         for e in [[[x[-1] for x in c][0] for c in s] for s in mask]:
331 |             u = []
332 |             for f in e:
333 |                 u.append((count, f))
334 |                 count += 1
335 | 
336 |             nodes_id_name.append(u)
337 | 
338 |         data['nodes_candidates_id'] = [[node_triple[-1] for node_triple in node][0]
339 |                                        for nodes_in_a_support in mask for node in nodes_in_a_support]
340 | 
341 |         ## find two kinds of edges between nodes:
342 |         ## edges_in connects nodes within the same document, edges_out connects nodes with the same string across different documents
343 |         edges_in, edges_out = [], []
344 |         for e0 in nodes_id_name:
345 |             for f0, w0 in e0:
346 |                 for f1, w1 in e0:
347 |                     if f0 != f1:
348 |                         edges_in.append((f0, f1))
349 | 
350 |                 for e1 in nodes_id_name:
351 |                     for f1, w1 in e1:
352 |                         if e0 != e1 and w0 == w1:
353 |                             edges_out.append((f0, f1))
354 | 
355 |         data['edges_in'] = edges_in
356 |         data['edges_out'] = edges_out
357 | 
358 |         data['nodes_mask'] = mask
359 | 
360 |         data['relation_index'] = len(first_token_in_query)
361 |         for index, answer in enumerate(candidates_orig):
362 |             if answer == data['answer']:
363 |                 data['answer_candidate_id'] = index
364 |                 break
365 |         return data
366 | 
367 |     """ generating ELMo embeddings for nodes and query
368 |     """
369 |     def preprocessForElmo(self, text_data, ee):
370 |         data_elmo = {}
371 | 
372 |         mask_ = [[x[:-1] for x in f] for e in text_data['nodes_mask'] for f in e]
373 |         supports, query, query_full_tokens = text_data['supports'], text_data['query'], text_data['query_full_token']
374 |         first_tokens_in_query = query[0].split('_')
375 | 
376 |         split_interval = self.elmo_split_interval
377 |         if len(supports) <= split_interval:
378 |             candidates, _ = ee.batch_to_embeddings(supports)
379 |             candidates = candidates.data.cpu().numpy()
380 |         else:
381 |             ## split long support data into several parts to avoid a possible OOM
382 |             count = 0
383 |             candidates = None
384 |             while count < len(supports):
385 |                 current_candidates, _ = \
386 |                     ee.batch_to_embeddings(supports[count:min(count + split_interval, len(supports))])
387 |                 current_candidates = current_candidates.data.cpu().numpy()
388 |                 if candidates is None:
389 |                     candidates = current_candidates
390 |                 else:
391 |                     if candidates.shape[2] > current_candidates.shape[2]:
392 |                         current_candidates = np.pad(current_candidates,
393 |                             ((0, 0), (0, 0), (0, candidates.shape[2] - current_candidates.shape[2]), (0, 0)), 'constant')
394 |                     elif current_candidates.shape[2] > candidates.shape[2]:
395 |                         candidates = np.pad(candidates,
396 |                             ((0, 0), (0, 0), (0, current_candidates.shape[2] - candidates.shape[2]), (0, 0)), 'constant')
397 |                     candidates = np.concatenate((candidates, current_candidates))
398 |                 count += split_interval
399 | 
400 |         data_elmo['nodes_elmo'] = [(candidates.transpose((0, 2, 1, 3))[np.array(m).T.tolist()]).astype(np.float16)
401 |                                    for m in mask_]
402 | 
403 |         query, _ = ee.batch_to_embeddings([query])
404 |         query = query.data.cpu().numpy()
405 |         data_elmo['query_elmo'] = (query.transpose((0, 2, 1, 3))).astype(np.float16)[0]
406 |         if len(first_tokens_in_query) == 1:
407 |             data_elmo['query_full_token_elmo'] = data_elmo['query_elmo']
408 |         else:
409 |             query_full_tokens, _ = ee.batch_to_embeddings([first_tokens_in_query])
410 |             query_full_tokens = query_full_tokens.cpu().numpy()
411 |             data_elmo['query_full_token_elmo'] = np.concatenate(
412 |                 (query_full_tokens.transpose((0, 2, 1, 3)).astype(np.float16)[0], data_elmo['query_elmo'][1:,:,:]), 0)
413 |         return data_elmo
414 | 
415 |     """ generating GLoVe embeddings for nodes and query
416 |     """
417 |     def preprocessForGlove(self, text_data, gloveEmbMap, vocab2index):
418 |         data = {}
419 |         nodes_glove = []
420 |         for candidate_id in text_data['nodes_candidates_id']:
421 |             candidate_token = text_data['candidates'][candidate_id]
422 |             node_glove = []
423 |             for token in candidate_token:
424 |                 node_glove.append(gloveEmbMap[vocab2index[token]])
425 |             nodes_glove.append(np.array(node_glove).astype(np.float32))
426 |         data['nodes_glove'] = nodes_glove
427 |         query_glove = []
428 |         for token in text_data['query']:
429 |             query_glove.append(gloveEmbMap[vocab2index[token]])
430 |         data['query_glove'] = np.array(query_glove).astype(np.float32)
431 |         query_full_token_glove = []
432 |         for token in text_data['query_full_token']:
433 |             query_full_token_glove.append(gloveEmbMap[vocab2index[token]])
434 |         data['query_full_token_glove'] = np.array(query_full_token_glove).astype(np.float32)
435 |         return data
436 | 
437 |     """ generating POS and NER tags for every token in nodes or query
438 |     """
439 |     def preprocessForExtra(self, data):
440 |         nodes_mask = data['nodes_mask']
441 |         supports, query = data['supports'], data['query']
442 |         recovered_support = []
443 |         for support in supports:
444 |             recovered_support.append(self.recoverTokens(support))
445 |         tokened_supports = [doc for doc in self.nlp.pipe(recovered_support, batch_size=1000)]
446 |         recovered_query = self.recoverTokens(query)
447 |         tokened_query = [doc for doc in self.nlp.pipe([recovered_query], batch_size=1000)]
448 |         tag_dict = self.tag_dict
449 |         ner_dict = self.ner_dict
450 |         pos = [self.postagFunc(tokened_support, tag_dict) for tokened_support in tokened_supports]
451 |         ner = [self.nertagFunc(tokened_support, ner_dict) for tokened_support in tokened_supports]
452 |         query_pos = self.postagFunc(tokened_query[0], tag_dict)
453 |         query_ner = self.nertagFunc(tokened_query[0], ner_dict)
454 |         nodes_ner, nodes_pos = [], []
455 |         ## give each node a NER and a POS tag; the node order in each sample is the same as in the Elmo data
456 |         for support_index, support_nodes in enumerate(nodes_mask):
457 |             for node_mask in support_nodes:
458 |                 node_pos, node_ner = [], []
459 |                 for mask in node_mask:
460 |                     if mask[1] > len(pos[support_index]) - 1:
461 |                         node_pos.append(tag_dict[''])
462 |                         node_ner.append(ner_dict[''])
463 |                     else:
464 |                         node_pos.append(pos[support_index][mask[1]])
465 |                         node_ner.append(ner[support_index][mask[1]])
466 |                 nodes_ner.append(np.array(node_ner))
467 |                 nodes_pos.append(np.array(node_pos))
468 | 
469 |         data = {'nodes_ner': nodes_ner, 'nodes_pos': nodes_pos, 'query_ner': query_ner, 'query_pos': query_pos}
470 |         return data
471 | 
472 |     """ Read the GLoVe file and generate an embedding matrix indexed by the token ids in vocab2index
473 |     """
474 |     def buildGloveEmbMap(self, vocab2index, dim=300):
475 |         vocab_size = len(vocab2index)
476 |         emb = np.zeros((vocab_size, dim))
477 |         emb[0] = 0
478 |         known_mask = np.zeros(vocab_size, dtype=bool)
479 |         with open(self.glove_file, encoding='utf-8') as f:
480 |             line_count = 0
481 |             for line in f:
482 |                 line_count += 1
483 |                 elems = line.split()
484 |                 token = self.normalize_text(' '.join(elems[0:-dim]))
485 |                 if token in vocab2index:
486 |                     emb[vocab2index[token]] = [float(v) for v in elems[-dim:]]
487 |                     known_mask[vocab2index[token]] = True
488 |         for index, mask in enumerate(known_mask):
489 |             if not mask:
490 |                 emb[index] = emb[0]
491 |         return emb
492 | 
493 |     def normalize_text(self, text):
494 |         return unicodedata.normalize('NFD', text)
495 | 
496 |     def recoverTokens(self, tokens):
497 |         res = ''
498 |         for token in tokens:
499 |             res = res + token + ' '
500 |         return res[:-1]
501 | 
502 |     def tagFunc(self, toks, pos_dict):
503 |         postag = []
504 |         for w in toks:
505 |             if not pos_dict.__contains__(w[1]):
506 |                 current_index = len(pos_dict)
507 |                 pos_dict[w[1]] = current_index
508 |             postag.append(pos_dict[w[1]])
509 |         return postag, pos_dict
510 | 
511 |     def postagFunc(self, toks, tag_dict):
512 |         postag = []
513 |         for w in toks:
514 |             if len(w.text) > 0:
515 |                 if not tag_dict.__contains__(w.tag_):
516 |                     current_index = len(tag_dict)
517 |                     tag_dict[w.tag_] = current_index
518 |                 postag.append(tag_dict[w.tag_])
519 |         return postag
520 | 
521 |     def nertagFunc(self, toks, ner_dict):
522 |         nertag = []
523 |         for w in toks:
524 |             if len(w.text) > 0:
525 |                 ner_type = '{}_{}'.format(w.ent_type_, w.ent_iob_)
526 |                 if not ner_dict.__contains__(ner_type):
527 |                     current_index = len(ner_dict)
528 |                     ner_dict[ner_type] = current_index
529 |                 nertag.append(ner_dict[ner_type])
530 |         return nertag
531 | 
532 | if __name__ == '__main__':
533 |     parser = argparse.ArgumentParser()
534 |     parser.add_argument('file_name', type=str)
535 |     parser.add_argument('--is_masked', type=str2bool, default=False, help='using masked data or not')
536 |     parser.add_argument('--use_glove', type=str2bool, default=True, help='Using Glove embedding or not')
537 |     parser.add_argument('--use_extra_feature', type=str2bool, default=True,
538 |                         help='Using extra features, e.g. NER and POS, or not')
539 | 
540 |     args = parser.parse_args()
541 |     file_name = args.file_name
542 |     is_masked = args.is_masked
543 |     use_glove = args.use_glove
544 |     use_extra_feature = args.use_extra_feature
545 |     options_file = 'data/elmo_2x4096_512_2048cnn_2xhighway_options.json'
546 |     weight_file = 'data/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5'
547 |     glove_file = 'data/glove.840B.300d.txt'
548 |     logger = config_logger('Preprocess')
549 | 
550 |     preprocesser = Preprocesser(file_name, logger, is_masked=is_masked,
551 |         use_glove=use_glove, use_extra_feature=use_extra_feature, options_file=options_file,
552 |         weight_file=weight_file, glove_file=glove_file)
553 |     preprocesser.preprocess()
554 | 
--------------------------------------------------------------------------------
/utils/ConfigLogger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import time
4 | 
5 | 
6 | def config_logger(log_prefix):
7 |     logger = logging.getLogger()
8 |     logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
9 |                         level=logging.INFO)
10 |     rq = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
11 |     log_path = os.getcwd() + '/logs/' + log_prefix + '/'
12 |     if not os.path.exists(log_path):
13 |         os.makedirs(log_path)
14 |     log_name = log_path + rq + '.log'
15 |     file_handler = logging.FileHandler(log_name, mode='w')
16 |     file_handler.setLevel(logging.INFO)
17 |     file_handler.setFormatter(
18 |         logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'))
19 |     logger.addHandler(file_handler)
20 |     return logger
21 | 
--------------------------------------------------------------------------------
/utils/Dataset.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy as np
3 | import scipy.sparse
4 | 
5 | """Dataset wrapper that loads the preprocessed pickles and generates padded mini-batches
6 | """
7 | class Dataset(object):
8 |     def __init__(self, filename_prefix, use_elmo, use_glove, use_extra_feature, use_full_query_token=False,
9 |                  add_query_node=False, max_nodes=500, max_query_size=25, max_candidates=80, max_candidates_len=10,
10 |                  use_edge=True):
11 |         if not use_elmo and not use_glove:
12 |             raise Exception('At least one of ELMo, GloVe should be used')
13 | 
14 |         self.max_nodes, self.max_query_size, self.max_candidates, self.max_candidates_len = max_nodes,\
15 |             max_query_size, max_candidates, max_candidates_len
16 | 
17 |         graph_file_name = '{}.preprocessed.pickle'.format(filename_prefix)
18 |         self.data = [d for d in pickle.load(open(graph_file_name, 'rb')) if len(d['nodes_candidates_id']) > 0]
19 |         self.use_elmo = use_elmo
20 |         self.data_elmo = None
21 |         if use_elmo:
22 |             elmo_file_name = '{}.elmo.preprocessed.pickle'.format(filename_prefix)
23 |             self.data_elmo = [d for d in pickle.load(open(elmo_file_name, 'rb')) if len(d['nodes_elmo']) > 0]
24 |         self.data_glove = None
25 |         self.use_glove = use_glove
26 |         if use_glove:
27 |             glove_file_name = '{}.glove.preprocessed.pickle'.format(filename_prefix)
28 |             self.data_glove = [d for d in pickle.load(open(glove_file_name, 'rb')) if len(d['nodes_glove']) > 0]
29 |         self.data_extra = None
30 |         self.use_extra_feature = use_extra_feature
31 |         if use_extra_feature:
32 |             extra_file_name = '{}.extra.preprocessed.pickle'.format(filename_prefix)
33 |             self.data_extra = [d for d in pickle.load(open(extra_file_name, 'rb')) if len(d['nodes_pos']) > 0]
34 |             self.ner_dict = pickle.load(open('data/ner_dict.pickle', 'rb'))
35 |             self.pos_dict = pickle.load(open('data/pos_dict.pickle', 'rb'))
36 |         self.idx = list(range(len(self)))
37 |         self.counter = len(self)
38 |         self.use_full_query_token = use_full_query_token
39 |         self.add_query_node = add_query_node
40 |         self.use_edge = use_edge
41 | 
42 |     def __len__(self):
43 |         return len(self.data)
44 | 
45 |     def getDataSize(self):
46 |         return len(self.data)
47 | 
48 |     def getPosAndNerDictSize(self):
49 |         return len(self.pos_dict), len(self.ner_dict)
50 | 
51 |     def buildElmoData(self, data_elmo_mb):
52 |         filt = lambda c: np.array([c[:, 0].mean(0), c[0, 1], c[-1, 2]])
53 | 
54 |         nodes_elmo_mb = []
55 |         for d in data_elmo_mb:
56 |             nodes_elmo_mb.append(np.pad(np.array([filt(c) for c in d['nodes_elmo']]),
57 |                 ((0, self.max_nodes - len(d['nodes_elmo'])), (0, 0), (0, 0)), mode='constant'))
58 |         nodes_elmo_mb = np.array(nodes_elmo_mb)
59 | 
60 |         query_elmo_mb = np.stack([np.pad(d['query_elmo'],
61 |             ((0, self.max_query_size - d['query_elmo'].shape[0]), (0, 0), (0, 0)), mode='constant') for d in
62 |             data_elmo_mb], 0)
63 |         return nodes_elmo_mb, query_elmo_mb
64 | 
65 |     """ Generating GloVe data"""
66 |     def buildGloveData(self, idx, data_glove_mb):
67 |         filt = lambda c: np.array(c[:].mean(0))
68 |         # print(idx)
69 |         # print([len(d['nodes_glove']) for d in data_glove_mb])
70 |         nodes_glove_mb = np.array([
71 |             np.pad(np.array([filt(c) for c in d['nodes_glove']]), ((0, self.max_nodes - len(d['nodes_glove'])), (0, 0)),
72 |                    mode='constant') for d in data_glove_mb])
73 |         if not self.use_full_query_token:
74 |             query_glove_mb = np.stack([np.pad(d['query_glove'],
75 |                 ((0, self.max_query_size - d['query_glove'].shape[0]), (0, 0)), mode='constant')
76 |                 for d in data_glove_mb], 0)
77 |         else:
78 |             query_glove_mb = np.stack([np.pad(d['query_full_token_glove'],
79 |                 ((0, self.max_query_size - d['query_full_token_glove'].shape[0]), (0, 0)), mode='constant')
80 |                 for d in data_glove_mb], 0)
81 |         return nodes_glove_mb, query_glove_mb
82 | 
83 |     """ Generate extra-feature data (POS and NER tags)"""
84 |     def buildExtraData(self, data_extra_mb):
85 |         filt = lambda c: np.argmax(np.bincount(c))
86 |         node_ner_mb = np.array(
87 |             [np.pad(np.array([filt(c) for c in d['nodes_ner']]), ((0, self.max_nodes - len(d['nodes_ner']))),
88 |                     mode='constant') for d in data_extra_mb])
89 |         node_pos_mb = np.array(
90 |             [np.pad(np.array([filt(c) for c in d['nodes_pos']]), ((0, self.max_nodes - len(d['nodes_pos']))),
91 |                     mode='constant') for d in data_extra_mb])
92 |         if not self.use_full_query_token:
93 |             query_pos_mb = np.stack([np.pad(d['query_pos'], (0, self.max_query_size - len(d['query_pos'])),
94 |                 mode='constant') for d in data_extra_mb], 0)
95 |             query_ner_mb = np.stack([np.pad(d['query_ner'], (0, self.max_query_size - len(d['query_ner'])),
96 |                 mode='constant') for d in data_extra_mb], 0)
97 |         else:
98 |             query_pos_mb = np.stack([np.pad(d['query_pos_full_token'],
99 |                 (0, self.max_query_size - len(d['query_pos_full_token'])),
100 |                 mode='constant') for d in data_extra_mb], 0)
101 |             query_ner_mb = np.stack([np.pad(d['query_ner_full_token'],
102 |                 (0, self.max_query_size - len(d['query_ner_full_token'])),
103 |                 mode='constant') for d in data_extra_mb], 0)
104 |         return node_ner_mb, node_pos_mb, query_ner_mb, query_pos_mb
105 | 
106 |     """ We build the edges in the graph using adjacency matrices
107 |     """
108 |     def buildEdgeData(self, data_mb):
109 |         adj_mb = []
110 |         for d in data_mb:
111 |             if self.use_edge:
112 |                 adj_ = []
113 | 
114 |                 if len(d['edges_in']) == 0:
115 |                     adj_.append(np.zeros((self.max_nodes, self.max_nodes)))
116 |                 else:
117 |                     adj = scipy.sparse.coo_matrix((np.ones(len(d['edges_in'])), np.array(d['edges_in']).T),
118 |                                                   shape=(self.max_nodes, self.max_nodes)).toarray()
119 | 
120 |                     adj_.append(adj)
121 | 
122 |                 if len(d['edges_out']) == 0:
123 |                     adj_.append(np.zeros((self.max_nodes, self.max_nodes)))
124 |                 else:
125 |                     adj = scipy.sparse.coo_matrix((np.ones(len(d['edges_out'])), np.array(d['edges_out']).T),
126 |                                                   shape=(self.max_nodes, self.max_nodes)).toarray()
127 | 
128 |                     adj_.append(adj)
129 | 
130 |                 adj = np.pad(np.ones((len(d['nodes_candidates_id']), len(d['nodes_candidates_id']))),
131 |                     ((0, self.max_nodes - len(d['nodes_candidates_id'])),
132 |                      (0, self.max_nodes - len(d['nodes_candidates_id']))), mode='constant') \
133 |                     - adj_[0] - adj_[1] - np.pad(np.eye(len(d['nodes_candidates_id'])),
134 |                     ((0, self.max_nodes - len(d['nodes_candidates_id'])),
135 |                      (0, self.max_nodes - len(d['nodes_candidates_id']))), mode='constant')
136 | 
137 |                 adj_.append(np.clip(adj, 0, 1))
138 | 
139 |                 adj = np.stack(adj_, 0)
140 | 
141 |                 d_ = adj.sum(-1)
142 |                 d_[np.nonzero(d_)] **= -1
143 |                 adj = adj * np.expand_dims(d_, -1)
144 |                 adj_mb.append(adj)
145 |             else:
146 |                 adj = np.pad(np.ones((len(d['nodes_candidates_id']), len(d['nodes_candidates_id']))),
147 |                     ((0, self.max_nodes - len(d['nodes_candidates_id'])),
148 |                      (0, self.max_nodes - len(d['nodes_candidates_id']))), mode='constant') \
149 |                     - np.pad(np.eye(len(d['nodes_candidates_id'])),
150 |                     ((0, self.max_nodes - len(d['nodes_candidates_id'])),
151 |                      (0, self.max_nodes - len(d['nodes_candidates_id']))), mode='constant')
152 |                 adj_mb.append(adj)
153 |         return adj_mb
154 | 
155 |     """ Truncate nodes and edges if the node number in a graph exceeds the upper bound"""
156 |     def truncateNodesAndEdge(self, data, data_elmo, data_glove, data_extra):
157 |         # get the number of nodes in each data sample and truncate the sample if it exceeds the maximum length,
158 |         # including the elmo data, glove data and extra feature data
159 |         nodes_length_mb = np.stack([len(d['nodes_candidates_id']) for d in data], 0)
160 |         exceed_nodes_th = nodes_length_mb > self.max_nodes
161 |         for index, exceed in enumerate(exceed_nodes_th):
162 |             if exceed:
163 |                 data[index]['edges_in'] = self.truncateEdges(data[index]['edges_in'])
164 |                 data[index]['edges_out'] = self.truncateEdges(data[index]['edges_out'])
165 |                 data[index]['nodes_candidates_id'] = data[index]['nodes_candidates_id'][:self.max_nodes]
166 |                 if self.use_elmo:
167 |                     data_elmo[index]['nodes_elmo'] = data_elmo[index]['nodes_elmo'][:self.max_nodes]
168 |                 if self.use_glove:
169 |                     data_glove[index]['nodes_glove'] = data_glove[index]['nodes_glove'][:self.max_nodes]
170 |                 if self.use_extra_feature:
171 |                     data_extra[index]['nodes_ner'] = data_extra[index]['nodes_ner'][:self.max_nodes]
172 |                     data_extra[index]['nodes_pos'] = data_extra[index]['nodes_pos'][:self.max_nodes]
173 |         return nodes_length_mb
174 | 
175 |     """ If some nodes are truncated, then the related edges should also be truncated"""
176 |     def truncateEdges(self, edges):
177 |         truncated_edges = []
178 |         for edge_pair in edges:
179 |             if edge_pair[0] >= self.max_nodes:
180 |                 break
181 |             if edge_pair[1] < self.max_nodes:
182 |                 truncated_edges.append(edge_pair)
183 |         return truncated_edges
184 | 
185 |     """ Truncate the query if it exceeds the max length"""
186 |     def truncateQuery(self, data, data_elmo, data_glove, data_extra):
187 |         query_length_mb = np.stack([len(d['query']) for d in data], 0)
188 |         exceed_query_th = query_length_mb > self.max_query_size
189 |         for index, exceed in enumerate(exceed_query_th):
190 |             if exceed:
191 |                 if self.use_elmo:
192 |                     data_elmo[index]['query_elmo'] = data_elmo[index]['query_elmo'][:self.max_query_size]
193 |                 if self.use_glove:
194 |                     data_glove[index]['query_glove'] = data_glove[index]['query_glove'][:self.max_query_size]
195 |                 if self.use_extra_feature:
196 |                     data_extra[index]['query_ner'] = data_extra[index]['query_ner'][:self.max_query_size]
197 |                     data_extra[index]['query_pos'] = data_extra[index]['query_pos'][:self.max_query_size]
198 |         return query_length_mb
199 | 
200 |     """ The core function to generate the next batch"""
201 |     def next_batch_pro(self, idx, epoch_finished):
202 |         data_mb = [self.data[i] for i in idx]
203 |         data_elmo_mb = None
204 |         data_glove_mb = None
205 |         data_extra_mb = None
206 |         if self.use_elmo:
207 |             data_elmo_mb = [self.data_elmo[i] for i in idx]
208 |         if self.use_glove:
209 |             data_glove_mb = [self.data_glove[i] for i in idx]
210 |         if self.use_extra_feature:
211 |             data_extra_mb = [self.data_extra[i] for i in idx]
212 | 
213 |         id_mb = [d['id'] for d in data_mb]
214 | 
215 |         answer_candidate_id_mb = [d['answer_candidate_id'] for d in data_mb]
216 | 
217 |         nodes_length_mb = self.truncateNodesAndEdge(data_mb, data_elmo_mb, data_glove_mb, data_extra_mb)
218 |         query_length_mb = self.truncateQuery(data_mb, data_elmo_mb, data_glove_mb, data_extra_mb)
219 | 
220 |         adj_mb = self.buildEdgeData(data_mb)
221 | 
222 |         nodes_elmo_mb, query_elmo_mb = None, None
223 |         if self.use_elmo:
224 |             nodes_elmo_mb, query_elmo_mb = self.buildElmoData(data_elmo_mb)
225 | 
226 |         nodes_glove_mb, query_glove_mb = None, None
227 |         if self.use_glove:
228 |             nodes_glove_mb, query_glove_mb = self.buildGloveData(idx, data_glove_mb)
229 | 
230 |         nodes_ner_mb, nodes_pos_mb, query_ner_mb, query_pos_mb = None, None, None, None
231 |         if self.use_extra_feature:
232 |             nodes_ner_mb, nodes_pos_mb, query_ner_mb, query_pos_mb = self.buildExtraData(data_extra_mb)
233 | 
234 |         if self.add_query_node:
235 |             bmask_mb = np.array([np.pad(np.array([i == np.array(d['nodes_candidates_id'])
236 |                 for i in range(len(d['candidates']) - 1)]),
237 |                 ((0, self.max_candidates - len(d['candidates']) + 1),
238 |                  (0, self.max_nodes - len(d['nodes_candidates_id']))), mode='constant')
239 |                 for d in data_mb])
240 |         else:
241 |             bmask_mb = np.array([np.pad(np.array([i == np.array(d['nodes_candidates_id'])
242 |                 for i in range(len(d['candidates']))]),
243 |                 ((0, self.max_candidates - len(d['candidates'])), (0, self.max_nodes - len(d['nodes_candidates_id']))),
244 |                 mode='constant') for d in data_mb])
245 | 
246 |         return epoch_finished, {'id_mb': id_mb, 'nodes_length_mb': nodes_length_mb,
247 |             'query_length_mb': query_length_mb, 'bmask_mb': bmask_mb, 'adj_mb': adj_mb,
248 |             'answer_candidates_id_mb': answer_candidate_id_mb,
249 |             'nodes_elmo_mb': nodes_elmo_mb, 'query_elmo_mb': query_elmo_mb,
250 |             'nodes_glove_mb': nodes_glove_mb, 'query_glove_mb': query_glove_mb,
251 |             'nodes_ner_mb': nodes_ner_mb, 'nodes_pos_mb': nodes_pos_mb,
252 |             'query_ner_mb': query_ner_mb, 'query_pos_mb': query_pos_mb}
253 | 
254 |     """ Generate data for the next batch"""
255 |     def next_batch(self, batch_dim=None, use_multi_gpu=False):
256 | 
257 |         epoch_finished = False
258 |         if batch_dim is not None:
259 |             if self.counter >= len(self):
260 |                 np.random.shuffle(self.idx)
261 |                 self.counter = 0
262 |             if self.counter + batch_dim >= len(self):
263 |                 idx = self.idx[self.counter:]
264 |                 epoch_finished = True
265 |             else:
266 |                 idx = self.idx[self.counter:self.counter + batch_dim]
267 |             self.counter += batch_dim
268 |         else:
269 |             idx = self.idx
270 |         # try:
271 |         if use_multi_gpu:
272 |             if len(idx) % 2 != 0:
273 |                 idx.append(idx[-1])
274 |         epoch_finished, feed_dict = self.next_batch_pro(idx, epoch_finished)
275 |         # except:
276 |         #     print(idx)
277 |         return epoch_finished, feed_dict
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caoyu-noob/BAG/180538a8e0de3a6a5465802a1e10feee4d564dd2/utils/__init__.py
--------------------------------------------------------------------------------
/utils/pytorch_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pickle
3 | 
4 | from scipy.sparse import coo_matrix
5 | 
6 | import numpy as np
7 | 
8 | from tqdm import tqdm
9 | from torch.utils.data import DataLoader, TensorDataset
10 | 
11 | def load_raw_data(args, file_name_prefix):
12 |     graph_file_name = '{}.preprocessed.pickle'.format(file_name_prefix)
13 |     base_data = [d for d in pickle.load(open(graph_file_name, 'rb')) if len(d['nodes_candidates_id']) > 0]
14 |     data_elmo = None
15 |     if args.use_elmo:
16 |         elmo_file_name = '{}.elmo.preprocessed.pickle'.format(file_name_prefix)
17 |         data_elmo = [d for d in pickle.load(open(elmo_file_name, 'rb')) if len(d['nodes_elmo']) > 0]
18 |     data_glove = None
19 |     if args.use_glove:
20 |         glove_file_name = '{}.glove.preprocessed.pickle'.format(file_name_prefix)
21 |         data_glove = [d for d in pickle.load(open(glove_file_name, 'rb')) if len(d['nodes_glove']) > 0]
22 |     data_extra = None
23 |     if args.use_extra_feature:
24 |         extra_file_name = '{}.extra.preprocessed.pickle'.format(file_name_prefix)
25 |         data_extra = [d for d in pickle.load(open(extra_file_name, 'rb')) if len(d['nodes_pos']) > 0]
26 |     ner_dict = pickle.load(open('data/ner_dict.pickle', 'rb'))
27 |     pos_dict = pickle.load(open('data/pos_dict.pickle', 'rb'))
28 |     return base_data, data_elmo, data_glove, data_extra, ner_dict, pos_dict
29 | 
30 | def truncate_edge(edges, max_nodes):
31 |     truncated_edges = []
32 |     for edge_pair in edges:
33 |         if edge_pair[0] >= max_nodes:
34 |             break
35 |         if edge_pair[1] < max_nodes:
36 |             truncated_edges.append(edge_pair)
37 |     return truncated_edges
38 | 
39 | """ Truncate nodes and edges if the node number in a graph exceeds the upper bound"""
40 | def truncate_node_and_edge(args, base_data, data_elmo, data_glove, data_extra):
41 |     # get the number of nodes in the data sample and truncate the sample if it exceeds the maximum length,
42 |     # including the elmo data, glove data and extra feature data
43 |     if len(base_data['nodes_candidates_id']) > args.max_nodes:
44 |         base_data['edges_in'] = truncate_edge(base_data['edges_in'], args.max_nodes)
45 |         base_data['edges_out'] = truncate_edge(base_data['edges_out'], args.max_nodes)
46 |         base_data['nodes_candidates_id'] = base_data['nodes_candidates_id'][:args.max_nodes]
47 |         if args.use_elmo:
48 |             data_elmo['nodes_elmo'] = data_elmo['nodes_elmo'][:args.max_nodes]
49 |         if args.use_glove:
50 |             data_glove['nodes_glove'] = data_glove['nodes_glove'][:args.max_nodes]
51 |         if args.use_extra_feature:
52 |             data_extra['nodes_ner'] = data_extra['nodes_ner'][:args.max_nodes]
53 |             data_extra['nodes_pos'] = data_extra['nodes_pos'][:args.max_nodes]
54 |     return len(base_data['nodes_candidates_id'])
55 | 
56 | """ Truncate the query if it exceeds the max length"""
57 | def truncate_query(args, base_data, data_elmo, data_glove, data_extra):
58 |     if len(base_data['query']) > args.max_query_size:
59 |         if args.use_elmo:
60 |             data_elmo['query_elmo'] = data_elmo['query_elmo'][:args.max_query_size]
61 |         if args.use_glove:
62 |             data_glove['query_glove'] = data_glove['query_glove'][:args.max_query_size]
63 |         if args.use_extra_feature:
64 |             data_extra['query_ner'] = data_extra['query_ner'][:args.max_query_size]
65 |             data_extra['query_pos'] = data_extra['query_pos'][:args.max_query_size]
66 |     return min(len(base_data['query']), args.max_query_size)
67 | 
68 | """ We build the edges in the graph using adjacency matrices
69 | """
70 | def build_edge_data(args, base_data):
71 |     if args.use_edge:
72 |         adj_ = []
73 | 
74 |         if len(base_data['edges_in']) == 0:
75 |             adj_.append(np.zeros((args.max_nodes, args.max_nodes), dtype=np.float32))
76 |         else:
77 |             adj = coo_matrix((np.ones(len(base_data['edges_in'])), np.array(base_data['edges_in']).T),
78 |                              shape=(args.max_nodes, args.max_nodes), dtype=np.float32).toarray()
79 | 
80 |             adj_.append(adj)
81 | 
82 |         if len(base_data['edges_out']) == 0:
83 |             adj_.append(np.zeros((args.max_nodes, args.max_nodes), dtype=np.float32))
84 |         else:
85 |             adj = coo_matrix((np.ones(len(base_data['edges_out'])), np.array(base_data['edges_out']).T),
86 |                              shape=(args.max_nodes, args.max_nodes), dtype=np.float32).toarray()
87 | 
88 |             adj_.append(adj)
89 | 
90 |         adj = np.pad(np.ones((len(base_data['nodes_candidates_id']), len(base_data['nodes_candidates_id'])), dtype=np.float32),
91 |             ((0, args.max_nodes - len(base_data['nodes_candidates_id'])),
92 |              (0, args.max_nodes - len(base_data['nodes_candidates_id']))), mode='constant') \
93 |             - adj_[0] - adj_[1] - np.pad(np.eye(len(base_data['nodes_candidates_id'])),
94 |             ((0, args.max_nodes - len(base_data['nodes_candidates_id'])),
95 |              (0, args.max_nodes - len(base_data['nodes_candidates_id']))), mode='constant')
96 | 
97 |         adj_.append(np.clip(adj, 0, 1, dtype=np.float32))
98 | 
99 |         adj = np.stack(adj_, 0)
100 | 
101 |         d_ = adj.sum(-1)
102 |         d_[np.nonzero(d_)] **= -1
103 |         adj = adj * np.expand_dims(d_, -1)
104 |         return torch.from_numpy(adj)
105 |     else:
106 |         adj = np.pad(np.ones((len(base_data['nodes_candidates_id']), len(base_data['nodes_candidates_id'])), dtype=np.float32),
107 |             ((0, args.max_nodes - len(base_data['nodes_candidates_id'])),
108 |              (0, args.max_nodes - len(base_data['nodes_candidates_id']))), mode='constant') \
109 |             - np.pad(np.eye(len(base_data['nodes_candidates_id']), dtype=np.float32),
110 |             ((0, args.max_nodes - len(base_data['nodes_candidates_id'])),
111 |              (0, args.max_nodes - len(base_data['nodes_candidates_id']))), mode='constant')
112 |         return torch.from_numpy(adj)
113 | 
114 | def build_elmo_data(args, elmo):
115 |     filt = lambda c: np.array([c[:, 0].mean(0), c[0, 1], c[-1, 2]])
116 | 
117 |     nodes_elmo = np.pad(np.array([filt(c) for c in elmo['nodes_elmo']], dtype=np.float32),
118 |         ((0, args.max_nodes - len(elmo['nodes_elmo'])), (0, 0), (0, 0)), mode='constant')
119 |     query_elmo = np.pad(np.array(elmo['query_elmo'], dtype=np.float32),
120 |         ((0, args.max_query_size - elmo['query_elmo'].shape[0]), (0, 0), (0, 0)), mode='constant')
121 |     return torch.from_numpy(nodes_elmo), torch.from_numpy(query_elmo)
122 | 
123 | """ Generating GloVe data"""
124 | def build_glove_data(args, glove):
125 |     filt = lambda c: np.array(c[:].mean(0))
126 |     nodes_glove = np.pad(np.array([filt(c) for c in glove['nodes_glove']]),
127 |         ((0, args.max_nodes - len(glove['nodes_glove'])), (0, 0)), mode='constant')
128 |     if not args.use_full_query_token:
129 |         query_glove = np.pad(glove['query_glove'],
130 |             ((0, args.max_query_size - glove['query_glove'].shape[0]), (0, 0)), mode='constant')
131 |     else:
132 |         query_glove = np.pad(glove['query_full_token_glove'],
133 |             ((0, args.max_query_size - glove['query_full_token_glove'].shape[0]), (0, 0)), mode='constant')
134 |     return torch.from_numpy(nodes_glove), torch.from_numpy(query_glove)
135 | 
136 | def build_extra_feature(args, extra):
137 |     filt = lambda c: np.argmax(np.bincount(c))
138 |     nodes_ner = np.pad(np.array([filt(c) for c in extra['nodes_ner']], dtype=np.int64),
139 |         ((0, args.max_nodes - len(extra['nodes_ner']))), mode='constant')
140 |     nodes_pos = np.pad(np.array([filt(c) for c in extra['nodes_pos']], dtype=np.int64),
141 |         ((0, args.max_nodes - len(extra['nodes_pos']))), mode='constant')
142 |     if not args.use_full_query_token:
143 |         query_pos = np.pad(np.array(extra['query_pos'], dtype=np.int64),
144 |             (0, args.max_query_size - len(extra['query_pos'])), mode='constant')
145 |         query_ner = np.pad(np.array(extra['query_ner'], dtype=np.int64),
146 |             (0, args.max_query_size - len(extra['query_ner'])), mode='constant')
147 |     else:
148 |         query_pos = np.pad(np.array(extra['query_pos_full_token'], dtype=np.int64),
149 |             (0, args.max_query_size - len(extra['query_pos_full_token'])), mode='constant')
150 |         query_ner = np.pad(np.array(extra['query_ner_full_token'], dtype=np.int64),
151 |             (0, args.max_query_size - len(extra['query_ner_full_token'])), mode='constant')
152 |     return torch.from_numpy(nodes_ner), torch.from_numpy(nodes_pos), torch.from_numpy(query_ner), \
153 |            torch.from_numpy(query_pos)
154 | 
155 | '''Build the output mask that shows the relationship between each node and each candidate'''
156 | def build_output_nodes_mask(args, base_data):
157 |     if args.add_query_node:
158 |         output_mask = np.pad(np.array([i == np.array(base_data['nodes_candidates_id'])
159 |             for i in range(len(base_data['candidates']) - 1)]),
160 |             ((0, args.max_candidates - len(base_data['candidates']) + 1),
161 |              (0, args.max_nodes - len(base_data['nodes_candidates_id']))), mode='constant')
162 |     else:
163 |         output_mask = np.pad(np.array([i == np.array(base_data['nodes_candidates_id'])
164 |             for i in range(len(base_data['candidates']))]),
165 |             ((0, args.max_candidates - len(base_data['candidates'])),
166 |              (0, args.max_nodes - len(base_data['nodes_candidates_id']))), mode='constant')
167 |     return torch.from_numpy(output_mask)
168 | 
169 | '''Build PyTorch tensor data so it can be used in the dataset loader'''
170 | def build_tensor_data(args, base_data, data_elmo, data_glove, data_extra):
171 |     nodes_mask, query_lengths, adjs, output_masks, labels = [], [], [], [], []
172 |     for i in tqdm(range(len(base_data)), desc='Building Tensor data'):
173 |         base = base_data[i]
174 |         elmo, glove, extra = None, None, None
175 |         if args.use_elmo:
176 |             elmo = data_elmo[i]
177 |         if args.use_glove:
178 |             glove = data_glove[i]
179 |         if args.use_extra_feature:
180 |             extra = data_extra[i]
181 |         cur_node_len = truncate_node_and_edge(args, base, elmo, glove, extra)
182 |         cur_query_len = truncate_query(args, base, elmo, glove, extra)
183 |         adj_data = build_edge_data(args, base)
184 |         if args.use_elmo:
185 |             elmo_data = build_elmo_data(args, elmo)
186 |             data_elmo[i] = elmo_data
187 |         if args.use_glove:
188 |             glove_data = build_glove_data(args, glove)
189 |             data_glove[i] = glove_data
190 |         if args.use_extra_feature:
191 |             extra_data = build_extra_feature(args, extra)
192 |             data_extra[i] = extra_data
193 |         output_masks.append(build_output_nodes_mask(args, base))
194 |         cur_nodes_mask = torch.zeros(args.max_nodes)
195 |         cur_nodes_mask[:cur_node_len] = 1
196 |         nodes_mask.append(cur_nodes_mask)
197 |         query_lengths.append(cur_query_len)
198 |         adjs.append(adj_data)
199 |         labels.append(base['answer_candidate_id'])
200 |     data_size = len(base_data)
201 |     tensor_dataset = [torch.cat([x.unsqueeze(0) for x in adjs], dim=0)]
202 |     nodes_elmo, query_elmo = torch.zeros(data_size, 1, dtype=torch.bool), torch.zeros(data_size, 1, dtype=torch.bool)
203 |     if args.use_elmo:
204 |         nodes_elmo = torch.cat([x[0].unsqueeze(0) for x in data_elmo], dim=0)
205 |         query_elmo = torch.cat([x[1].unsqueeze(0) for x in data_elmo], dim=0)
206 |     tensor_dataset.extend([nodes_elmo, query_elmo])
207 |     nodes_glove, query_glove = torch.zeros(data_size, 1, dtype=torch.bool), torch.zeros(data_size, 1, dtype=torch.bool)
208 |     if args.use_glove:
209 |         nodes_glove = torch.cat([x[0].unsqueeze(0) for x in data_glove], dim=0)
210 |         query_glove = torch.cat([x[1].unsqueeze(0) for x in data_glove], dim=0)
211 |     tensor_dataset.extend([nodes_glove, query_glove])
212 |     nodes_ner, nodes_pos = torch.zeros(data_size, 1, dtype=torch.bool), torch.zeros(data_size, 1, dtype=torch.bool)
213 |     query_ner, query_pos = torch.zeros(data_size, 1, dtype=torch.bool), torch.zeros(data_size, 1, dtype=torch.bool)
214 |     if args.use_extra_feature:
215 |         nodes_ner = torch.cat([x[0].unsqueeze(0) for x in data_extra], dim=0)
216 |         nodes_pos = torch.cat([x[1].unsqueeze(0) for x in data_extra], dim=0)
217 |         query_ner = torch.cat([x[2].unsqueeze(0) for x in data_extra], dim=0)
218 |         query_pos = torch.cat([x[3].unsqueeze(0) for x in data_extra], dim=0)
219 |     tensor_dataset.extend([nodes_ner, nodes_pos, query_ner, query_pos])
220 |     tensor_dataset.append(torch.cat([x.unsqueeze(0) for x in nodes_mask], dim=0))
221 |     tensor_dataset.append(torch.LongTensor(query_lengths))
222 |     tensor_dataset.append(torch.cat([x.unsqueeze(0) for x in output_masks], dim=0))
223 |     tensor_dataset.append(torch.LongTensor(labels))
224 |     return tensor_dataset
225 | 
226 | '''Build the list of ids and the corresponding answer candidate texts so that the prediction json file can be generated'''
227 | def build_id_candidate_list(base_data):
228 |     res = []
229 |     for d in base_data:
230 |         res.append((d['id'], [' '.join(token) for token in d['candidates']]))
231 |     return res
232 | 
233 | def get_pytorch_dataloader(args, file_name_prefix, shuffle=False, for_evaluation=True):
234 |     base_data, data_elmo, data_glove, data_extra, ner_dict, pos_dict = load_raw_data(args, file_name_prefix)
235 |     id_candidate_list = None
236 |     if for_evaluation:
237 |         id_candidate_list = build_id_candidate_list(base_data)
238 |     tensor_dataset = build_tensor_data(args, base_data, data_elmo, data_glove, data_extra)
239 |     tensor_dataset = TensorDataset(*tensor_dataset)
240 |     data_loader = DataLoader(tensor_dataset, batch_size=args.batch_size, shuffle=shuffle)
241 |     return data_loader, len(ner_dict), len(pos_dict), id_candidate_list
--------------------------------------------------------------------------------
/utils/str2bool.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | def str2bool(text):
4 |     if text.lower() in ('true', 'yes', 'y'):
5 |         return True
6 |     elif text.lower() in ('false', 'no', 'n'):
7 |         return False
8 |     else:
9 |         raise argparse.ArgumentTypeError('Boolean value expected.')
--------------------------------------------------------------------------------
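For reference, below is a minimal sketch of how `get_pytorch_dataloader` can be consumed. The argument values in the `SimpleNamespace` are illustrative assumptions, not the project's defaults (the real argument set is defined in BAG-pytorch.py), and `dev.json` is a hypothetical file-name prefix.

```python
# Sketch: iterate over batches produced by utils/pytorch_dataset.py.
# All argument values below are illustrative; BAG-pytorch.py defines the real ones.
from types import SimpleNamespace
from utils.pytorch_dataset import get_pytorch_dataloader

args = SimpleNamespace(use_elmo=True, use_glove=True, use_extra_feature=True,
                       use_edge=True, use_full_query_token=False, add_query_node=False,
                       max_nodes=500, max_query_size=25, max_candidates=80, batch_size=8)
loader, ner_dict_size, pos_dict_size, id_candidates = get_pytorch_dataloader(
    args, 'dev.json', shuffle=False, for_evaluation=True)
for batch in loader:
    adj = batch[0]      # (batch, 3, max_nodes, max_nodes) when use_edge is True
    labels = batch[-1]  # answer candidate ids
    break
```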