├── README.md
├── data
│   └── Math_23K.json
├── main.py
└── src
    ├── expressions_transfer.py
    ├── logger.py
    ├── masked_cross_entropy.py
    ├── models.py
    ├── pre_data.py
    └── train_and_evaluate.py

/README.md:
--------------------------------------------------------------------------------
1 | # MultiMath
2 | Solving Math Word Problems with Multi-Encoders and Multi-Decoders (COLING 2020)
3 | 
4 | Our paper: https://www.aclweb.org/anthology/2020.coling-main.262.pdf
5 | 
6 | 
7 | 
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # coding: utf-8
2 | import os
3 | import time
4 | import torch.optim
5 | import torch.nn as nn
6 | 
7 | from src.logger import *
8 | from src.models import *
9 | from src.train_and_evaluate import *
10 | from src.expressions_transfer import *
11 | 
12 | batch_size = 64
13 | embedding_size = 128
14 | hidden_size = 512
15 | n_epochs = 80
16 | learning_rate = 1e-3
17 | weight_decay = 1e-5
18 | beam_size = 5
19 | n_layers = 2
20 | hop_size = 2
21 | 
22 | from pyltp import Postagger, Parser
23 | LTP_DATA_DIR = "../ltp_data_v3.4.0"
24 | pos_model_path = os.path.join(LTP_DATA_DIR, "pos.model")
25 | par_model_path = os.path.join(LTP_DATA_DIR, "parser.model")
26 | postagger = Postagger()
27 | postagger.load(pos_model_path)
28 | parser = Parser()
29 | parser.load(par_model_path)
30 | for d in ["results"] + ["models_" + str(i) for i in range(5)]: os.makedirs(d, exist_ok=True)  # output dirs must exist before train() writes checkpoints/results
31 | 
32 | def read_data_json(filename):
33 |     with open(filename, 'r') as f:
34 |         return json.load(f)
35 | 
36 | def write_data_json(data, filename):
37 |     with open(filename, 'w') as f:
38 |         json.dump(data, f, ensure_ascii=False, indent=4)
39 | 
40 | def generate_train_test():
41 |     data = load_raw_data("data/Math_23K.json")
42 |     pairs, generate_nums, copy_nums = transfer_num(data)
43 |     temp_pairs = []
44 |     for p in pairs:
45 |         # the expression (p[2]) is stored twice: the copy at index 3 is the
46 |         # infix form converted to prefix/postfix below (both branches of the
47 |         # former "8883" special case appended this same tuple)
48 |         temp_pairs.append((p[0], p[1], p[2], p[2], p[3], p[4]))
49 | 
50 |     pre_temp_pairs = []
51 |     for p in temp_pairs:
52 |         postags = postagger.postag(p[1])
53 |         postags = ' '.join(postags).split(' ')
54 |         arcs = parser.parse(p[1], postags)
55 |         parse_tree = [arc.head - 1 for arc in arcs]
56 |         pre_temp_pairs.append((p[0], p[1], postags, parse_tree,
57 |                                from_infix_to_prefix(p[3]), from_infix_to_postfix(p[3]), p[4], p[5]))
58 | 
59 |     pairs = pre_temp_pairs
60 | 
61 |     fold_size = int(len(pairs) * 0.2)
62 |     fold_pairs = []
63 |     for split_fold in range(4):
64 |         fold_start = fold_size * split_fold
65 |         fold_end = fold_size * (split_fold + 1)
66 |         fold_pairs.append(pairs[fold_start:fold_end])
67 |     fold_pairs.append(pairs[(fold_size * 4):])
68 | 
69 |     for fold in range(5):
70 |         pairs_tested = []
71 |         pairs_trained = []
72 |         for fold_t in range(5):
73 |             if fold_t == fold:
74 |                 pairs_tested += fold_pairs[fold_t]
75 |             else:
76 |                 pairs_trained += fold_pairs[fold_t]
77 |         write_data_json(pairs_trained, "data/train_" + str(fold) + ".json")
78 |         write_data_json(pairs_tested, "data/test_" + str(fold) + ".json")
79 | 
80 | 
81 | def train(fold):
82 |     data = load_raw_data("data/Math_23K.json")
83 |     pairs, generate_nums, copy_nums = transfer_num(data)
84 | 
85 |     elogger = Logger("MultiMath_" + str(fold))
86 |     pairs_trained = read_data_json("data/train_" + str(fold) + ".json")
87 |     pairs_tested = read_data_json("data/test_" + str(fold) + ".json")
88 | 
89 |     best_acc_fold = []
90 | 
91 |     input1_lang, input2_lang, output1_lang, output2_lang, train_pairs, test_pairs = prepare_data(pairs_trained, pairs_tested, 5, generate_nums, copy_nums)
92 | 
93 |     emb_vectors = word2vec(train_pairs, embedding_size, input1_lang)
94 |     np.save("data/emb_" + str(fold) + ".npy", emb_vectors)
95 |     emb_vectors = np.load("data/emb_" + str(fold) + ".npy")
96 |     embed_model = nn.Embedding(input1_lang.n_words, embedding_size, padding_idx=0)
97 |     embed_model.weight.data.copy_(torch.from_numpy(emb_vectors))
98 | 
99 |     # Initialize models
100 |     encoder = EncoderSeq(input1_size=input1_lang.n_words, input2_size=input2_lang.n_words,
101 |                          embed_model=embed_model, embedding1_size=embedding_size, embedding2_size=embedding_size // 4,
102 |                          hidden_size=hidden_size, n_layers=n_layers, hop_size=hop_size)
103 |     numencoder = NumEncoder(node_dim=hidden_size, hop_size=hop_size)
104 |     predict = Prediction(hidden_size=hidden_size, op_nums=output1_lang.n_words - copy_nums - 1 - len(generate_nums),
105 |                          input_size=len(generate_nums))
106 |     generate = GenerateNode(hidden_size=hidden_size, op_nums=output1_lang.n_words - copy_nums - 1 - len(generate_nums),
107 |                             embedding_size=embedding_size)
108 |     merge = Merge(hidden_size=hidden_size, embedding_size=embedding_size)
109 |     decoder = AttnDecoderRNN(hidden_size=hidden_size, embedding_size=embedding_size,
110 |                              input_size=output2_lang.n_words, output_size=output2_lang.n_words, n_layers=n_layers)
111 |     # the embedding layer is only for generated number embeddings, operators, and paddings
112 | 
113 |     encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate, weight_decay=weight_decay)
114 |     numencoder_optimizer = torch.optim.Adam(numencoder.parameters(), lr=learning_rate, weight_decay=weight_decay)
115 |     predict_optimizer = torch.optim.Adam(predict.parameters(), lr=learning_rate, weight_decay=weight_decay)
116 |     generate_optimizer = torch.optim.Adam(generate.parameters(), lr=learning_rate, weight_decay=weight_decay)
117 |     merge_optimizer = torch.optim.Adam(merge.parameters(), lr=learning_rate, weight_decay=weight_decay)
118 |     decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate, weight_decay=weight_decay)
119 | 
120 |     encoder_scheduler = torch.optim.lr_scheduler.StepLR(encoder_optimizer, step_size=20, gamma=0.5)
121 |     numencoder_scheduler = torch.optim.lr_scheduler.StepLR(numencoder_optimizer, step_size=20, gamma=0.5)
122 |     predict_scheduler = torch.optim.lr_scheduler.StepLR(predict_optimizer, step_size=20, gamma=0.5)
123 |     generate_scheduler = torch.optim.lr_scheduler.StepLR(generate_optimizer, step_size=20, gamma=0.5)
124 |     merge_scheduler = torch.optim.lr_scheduler.StepLR(merge_optimizer, step_size=20, gamma=0.5)
125 |     decoder_scheduler = torch.optim.lr_scheduler.StepLR(decoder_optimizer, step_size=20, gamma=0.5)
126 | 
127 |     # Move models to GPU
128 |     if USE_CUDA:
129 |         encoder.cuda()
130 |         numencoder.cuda()
131 |         predict.cuda()
132 |         generate.cuda()
133 |         merge.cuda()
134 |         decoder.cuda()
135 | 
136 |     elogger.log(str(encoder))
137 |     elogger.log(str(numencoder))
138 |     elogger.log(str(predict))
139 |     elogger.log(str(generate))
140 |     elogger.log(str(merge))
141 |     elogger.log(str(decoder))
142 | 
143 |     generate_num1_ids = []
144 |     generate_num2_ids = []
145 |     for num in generate_nums:
146 |         generate_num1_ids.append(output1_lang.word2index[num])
147 |         generate_num2_ids.append(output2_lang.word2index[num])
148 | 
149 |     for epoch in range(n_epochs):
150 |         loss_total = 0
151 |         id_batches, input1_batches, input2_batches, input_lengths, output1_batches, output1_lengths, output2_batches, output2_lengths, \
152 |             nums_batches, num_stack_batches, num_pos_batches, num_order_batches, num_size_batches, parse_graph_batches = prepare_train_batch(train_pairs, batch_size)
153 |         print("fold:", fold + 1)
154 |         print("epoch:", epoch + 1)
155 |         start = time.time()
156 |         for idx in range(len(input_lengths)):
157 |             loss = train_double(
158 |                 input1_batches[idx], input2_batches[idx], input_lengths[idx], output1_batches[idx], output1_lengths[idx], output2_batches[idx], output2_lengths[idx],
159 |                 num_stack_batches[idx], num_size_batches[idx], generate_num1_ids, generate_num2_ids, copy_nums,
160 |                 encoder, numencoder, predict, generate, merge, decoder,
161 |                 encoder_optimizer, numencoder_optimizer, predict_optimizer, generate_optimizer, merge_optimizer, decoder_optimizer,
162 |                 input1_lang, output1_lang, output2_lang, num_pos_batches[idx], num_order_batches[idx], parse_graph_batches[idx],
163 |                 beam_size=5, use_teacher_forcing=0.83, english=False)
164 |             loss_total += loss
165 | 
166 |         print("loss:", loss_total / len(input_lengths))
167 |         print("training time", time_since(time.time() - start))
168 |         print("--------------------------------")
169 |         elogger.log("epoch: %d, loss: %.4f" % (epoch + 1, loss_total / len(input_lengths)))
170 | 
171 |         if epoch % 10 == 0 or epoch > n_epochs - 5:
172 |             value_ac = 0
173 |             equation_ac = 0
174 |             eval_total = 0
175 |             result_list = []
176 |             start = time.time()
177 |             for test_batch in test_pairs:
178 |                 parse_graph = get_parse_graph_batch([test_batch[5]], [test_batch[4]])
179 |                 result_type, test_res, score = evaluate_double(test_batch[2], test_batch[3], test_batch[5], generate_num1_ids, generate_num2_ids,
180 |                                                                encoder, numencoder, predict, generate, merge, decoder,
181 |                                                                input1_lang, output1_lang, output2_lang, test_batch[11], test_batch[13], parse_graph, beam_size=beam_size)
182 |                 if result_type == "tree":
183 |                     val_ac, equ_ac, _, _ = compute_prefix_tree_result(test_res, test_batch[6], output1_lang, test_batch[10], test_batch[12])
184 |                     result = out_expression_list(test_res, output1_lang, test_batch[10])
185 |                     result_list.append([test_batch[0], "tree", result, score])
186 |                 else:
187 |                     if test_res[-1] == output2_lang.word2index["EOS"]:
188 |                         test_res = test_res[:-1]
189 |                     val_ac, equ_ac, _, _ = compute_postfix_tree_result(test_res, test_batch[8][:-1], output2_lang, test_batch[10], test_batch[12])
190 |                     result = out_expression_list(test_res, output2_lang, test_batch[10])
191 |                     result_list.append([test_batch[0], "attn", result, score])
192 | 
193 |                 if val_ac:
194 |                     value_ac += 1
195 |                 if equ_ac:
196 |                     equation_ac += 1
197 |                 eval_total += 1
198 |             print(equation_ac, value_ac, eval_total)
199 |             print("test_answer_acc", float(equation_ac) / eval_total, float(value_ac) / eval_total)
200 |             print("testing time", time_since(time.time() - start))
201 |             print("------------------------------------------------------")
202 |             torch.save(encoder.state_dict(), "models_" + str(fold) + "/encoder")
203 |             torch.save(numencoder.state_dict(), "models_" + str(fold) + "/numencoder")
204 |             torch.save(predict.state_dict(), "models_" + str(fold) + "/predict")
205 |             torch.save(generate.state_dict(), "models_" + str(fold) + "/generate")
206 |             torch.save(merge.state_dict(), "models_" + str(fold) + "/merge")
207 |             torch.save(decoder.state_dict(), "models_" + str(fold) + "/decoder")
208 |             write_data_json(result_list, "results/result_" + str(fold) + ".json")
209 |             elogger.log("epoch: %d, test_equ_acc: %.4f, test_ans_acc: %.4f"
210 |                         % (epoch + 1, float(equation_ac) / eval_total, float(value_ac) / eval_total))
211 | 
212 |             if epoch == n_epochs - 1:
213 |                 best_acc_fold.append((equation_ac, value_ac, eval_total))
214 | 
215 |         encoder_scheduler.step()
216 |         numencoder_scheduler.step()
217 |         predict_scheduler.step()
218 |         generate_scheduler.step()
219 |         merge_scheduler.step()
220 |         decoder_scheduler.step()
221 | 
222 | 
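For context, a minimal driver (ours, not part of the repository — it simply mirrors the commented-out calls in `__main__` at the bottom of this file) would regenerate the splits once and then train every fold:

```python
# Hypothetical end-to-end driver; equivalent to uncommenting the calls in __main__.
generate_train_test()   # writes data/train_<k>.json and data/test_<k>.json for k = 0..4
for fold in range(5):
    train(fold)         # logs to logs/, checkpoints to models_<fold>/, results to results/
```
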
223 | def test():
224 |     data = load_raw_data("data/Math_23K.json")
225 |     pairs, generate_nums, copy_nums = transfer_num(data)
226 | 
227 |     fold = 0
228 |     pairs_trained = read_data_json("data/train_" + str(fold) + ".json")
229 |     pairs_tested = read_data_json("data/test_" + str(fold) + ".json")
230 | 
231 |     input1_lang, input2_lang, output1_lang, output2_lang, train_pairs, test_pairs = prepare_data(pairs_trained, pairs_tested, 5, generate_nums, copy_nums)
232 | 
233 |     emb_vectors = np.load("data/emb_" + str(fold) + ".npy")
234 |     embed_model = nn.Embedding(input1_lang.n_words, embedding_size)
235 |     embed_model.weight.data.copy_(torch.from_numpy(emb_vectors))
236 | 
237 |     # Initialize models
238 |     encoder = EncoderSeq(input1_size=input1_lang.n_words, input2_size=input2_lang.n_words,
239 |                          embed_model=embed_model, embedding1_size=embedding_size, embedding2_size=embedding_size // 4,
240 |                          hidden_size=hidden_size, n_layers=n_layers, hop_size=hop_size)
241 |     numencoder = NumEncoder(node_dim=hidden_size, hop_size=hop_size)
242 |     predict = Prediction(hidden_size=hidden_size, op_nums=output1_lang.n_words - copy_nums - 1 - len(generate_nums),
243 |                          input_size=len(generate_nums))
244 |     generate = GenerateNode(hidden_size=hidden_size, op_nums=output1_lang.n_words - copy_nums - 1 - len(generate_nums),
245 |                             embedding_size=embedding_size)
246 |     merge = Merge(hidden_size=hidden_size, embedding_size=embedding_size)
247 |     decoder = AttnDecoderRNN(hidden_size=hidden_size, embedding_size=embedding_size,
248 |                              input_size=output2_lang.n_words, output_size=output2_lang.n_words, n_layers=n_layers)
249 | 
250 |     encoder.load_state_dict(torch.load("models_" + str(fold) + "/encoder", map_location="cpu"))
251 |     numencoder.load_state_dict(torch.load("models_" + str(fold) + "/numencoder", map_location="cpu"))
252 |     predict.load_state_dict(torch.load("models_" + str(fold) + "/predict", map_location="cpu"))
253 |     generate.load_state_dict(torch.load("models_" + str(fold) + "/generate", map_location="cpu"))
254 |     merge.load_state_dict(torch.load("models_" + str(fold) + "/merge", map_location="cpu"))
255 |     decoder.load_state_dict(torch.load("models_" + str(fold) + "/decoder", map_location="cpu"))
256 | 
257 |     if USE_CUDA:
258 |         encoder.cuda()
259 |         numencoder.cuda()
260 |         predict.cuda()
261 |         generate.cuda()
262 |         merge.cuda()
263 |         decoder.cuda()
264 | 
265 |     generate_num1_ids = []
266 |     generate_num2_ids = []
267 |     for num in generate_nums:
268 |         generate_num1_ids.append(output1_lang.word2index[num])
269 |         generate_num2_ids.append(output2_lang.word2index[num])
270 | 
271 |     pair = pairs_tested[211][:]
272 |     pair[1] = ['快车', '每', '小时', '行驶', 'NUM', '千米', ',', '慢车', '每', '小时', '行驶', 'NUM', '千米', ',',
273 |                '两车', '相向', '而', '行', ',', '经过', 'NUM', '小时', '相遇', ',', '相遇', '时', '快车', '比', '慢车', '多行', '多少', '千米', '?']
274 |     postags = postagger.postag(pair[1])
275 |     postags = ' '.join(postags).split(' ')
276 |     arcs = parser.parse(pair[1], postags)
277 |     parse_tree = [arc.head - 1 for arc in arcs]
278 |     pair[2] = postags
279 |     pair[3] = parse_tree
280 |     pair[4] = ['*', '-', 'N0', 'N1', 'N2']
281 |     pair[5] = ['N0', 'N1', '-', 'N2', '*']
282 |     pair[6] = ['85', '58', '5']
283 |     pair[7] = [4, 11, 20]
284 |     #
285 |     # pair = pairs_tested[211][:]
286 |     # pair[1] = ['慢车', '每', '小时', '行驶', 'NUM', '千米', ',', '快车', '每', '小时', '行驶', 'NUM', '千米', ',',
287 |     #            '两车', '相向', '而', '行', ',', '经过', 'NUM', '小时', '相遇', ',', '相遇', '时', '慢车', '比', '快车', '少行', '多少', '千米', '?']
288 |     # postags = postagger.postag(pair[1])
289 |     # postags = ' '.join(postags).split(' ')
290 |     # arcs = parser.parse(pair[1], postags)
291 |     # parse_tree = [arc.head - 1 for arc in arcs]
292 |     # pair[2] = postags
293 |     # pair[3] = parse_tree
294 | 
295 |     # pair = pairs_tested[45][:]
296 |     # pair[1] = ['妈妈', '有', 'NUM', '米', '蓝', '带子', ',',
297 |     #            'NUM', '米', '红带子', '.', '蓝', '带子', '是', '红带子', '的', '几分', '之', '几', '?']
298 |     # pair[2] = ['/', 'N0', 'N1']
299 |     # pair[3] = ['N0', 'N1', '/']
300 |     # pair[4] = ['3', '12']
301 |     # pair[5] = [2, 7]
302 | 
303 |     # pair = pairs_tested[45][:]
304 |     # pair[1] = ['妈妈', '有', 'NUM', '米', '蓝', '带子', ',',
305 |     #            'NUM', '米', '红带子', '.', '蓝', '带子', '的', '长', '是', '红带子', '的', '几倍', '?']
306 |     # pair[2] = ['/', 'N0', 'N1']
307 |     # pair[3] = ['N0', 'N1', '/']
308 |     # pair[4] = ['12', '3']
309 |     # pair[5] = [2, 7]
310 | 
311 |     test_pairs = []
312 |     num_stack = []
313 |     for word in pair[4]:
314 |         temp_num = []
315 |         flag_not = True
316 |         if word not in output1_lang.index2word:
317 |             flag_not = False
318 |             for i, j in enumerate(pair[6]):
319 |                 if j == word:
320 |                     temp_num.append(i)
321 | 
322 |         if not flag_not and len(temp_num) != 0:
323 |             num_stack.append(temp_num)
324 |         if not flag_not and len(temp_num) == 0:
325 |             num_stack.append([_ for _ in range(len(pair[6]))])
326 | 
327 |     num_stack.reverse()
328 |     input1_cell = indexes_from_sentence(input1_lang, pair[1])
329 |     texts_cell = texts_from_sentence(input1_lang, pair[1])
330 |     input2_cell = indexes_from_sentence(input2_lang, pair[2])
331 |     output1_cell = indexes_from_sentence(output1_lang, pair[4], True)
332 |     output2_cell = indexes_from_sentence(output2_lang, pair[5], False)
333 |     num_list = num_list_processed(pair[6])
334 |     num_order = num_order_processed(num_list)
335 |     test_pairs.append((pair[0], texts_cell, input1_cell, input2_cell, pair[3], len(input1_cell),
336 |                        output1_cell, len(output1_cell), output2_cell, len(output2_cell),
337 |                        pair[6], pair[7], num_stack, num_order))
338 | 
339 |     for test_batch in test_pairs:
340 |         parse_graph = get_parse_graph_batch([test_batch[5]], [test_batch[4]])
341 |         result_type, test_res, score = evaluate_double(test_batch[2], test_batch[3], test_batch[5], generate_num1_ids, generate_num2_ids,
342 |                                                        encoder, numencoder, predict, generate, merge, decoder,
343 |                                                        input1_lang, output1_lang, output2_lang, test_batch[11], test_batch[13], parse_graph, beam_size=beam_size)
344 |         if result_type == "tree":
345 |             val_ac, equ_ac, _, _ = compute_prefix_tree_result(test_res, test_batch[6], output1_lang, test_batch[10], test_batch[12])
346 |             result = out_expression_list(test_res, output1_lang, test_batch[10])
347 |         else:
348 |             if test_res[-1] == output2_lang.word2index["EOS"]:
349 |                 test_res = test_res[:-1]
350 |             val_ac, equ_ac, _, _ = compute_postfix_tree_result(test_res, test_batch[8][:-1], output2_lang, test_batch[10], test_batch[12])
351 |             result = out_expression_list(test_res, output2_lang, test_batch[10])
352 |         print(result)
353 | 
354 | if __name__ == '__main__':
355 |     # train(0)
356 |     # train(1)
357 |     # train(2)
358 |     # train(3)
359 |     # train(4)
360 |     # test()
361 |     print('test')
362 | 
--------------------------------------------------------------------------------
/src/expressions_transfer.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import re
3 | 
4 | 
5 | # An expression tree node
6 | class Et:
7 |     # Constructor to create a node
8 |     def __init__(self, value):
9 |         self.value = value
10 |         self.left = None
11 |         self.right = None
12 | 
13 | 
14 | # Returns root of constructed tree for given postfix expression
15 | def construct_exp_tree(postfix):
16 |     stack = []
17 | 
18 |     # Traverse through every character of input expression
19 |     for char in postfix:
20 | 
21 |         # if operand, simply push into stack
22 |         if char not in ["+", "-", "*", "/", "^"]:
23 |             t = Et(char)
24 |             stack.append(t)
25 |         # Operator
26 |         else:
27 |             # Pop two top nodes
28 |             t = Et(char)
29 |             t1 = stack.pop()
30 |             t2 = stack.pop()
31 | 
32 |             # make them children
33 |             t.right = t1
34 |             t.left = t2
35 | 
36 |             # Add this subexpression to stack
37 |             stack.append(t)
38 |     # Only element will be the root of expression tree
39 |     t = stack.pop()
40 |     return t
41 | 
42 | 
43 | def from_infix_to_postfix(expression):
44 |     st = list()
45 |     res = list()
46 |     priority = {"+": 0, "-": 0, "*": 1, "/": 1, "^": 2}
47 |     for e in expression:
48 |         if e in ["(", "["]:
49 |             st.append(e)
50 |         elif e == ")":
51 |             c = st.pop()
52 |             while c != "(":
53 |                 res.append(c)
54 |                 c = st.pop()
55 |         elif e == "]":
56 |             c = st.pop()
57 |             while c != "[":
58 |                 res.append(c)
59 |                 c = st.pop()
60 |         elif e in priority:
61 |             while len(st) > 0 and st[-1] not in ["(", "["] and priority[e] <= priority[st[-1]]:
62 |                 res.append(st.pop())
63 |             st.append(e)
64 |         else:
65 |             res.append(e)
66 |     while len(st) > 0:
67 |         res.append(st.pop())
68 |     return res
69 | 
70 | 
71 | def from_infix_to_prefix(expression):
72 |     st = list()
73 |     res = list()
74 |     priority = {"+": 0, "-": 0, "*": 1, "/": 1, "^": 2}
75 |     expression = deepcopy(expression)
76 |     expression.reverse()
77 |     for e in expression:
78 |         if e in [")", "]"]:
79 |             st.append(e)
80 |         elif e == "(":
81 |             c = st.pop()
82 |             while c != ")":
83 |                 res.append(c)
84 |                 c = st.pop()
85 |         elif e == "[":
86 |             c = st.pop()
87 |             while c != "]":
88 |                 res.append(c)
89 |                 c = st.pop()
90 |         elif e in priority:
91 |             while len(st) > 0 and st[-1] not in [")", "]"] and priority[e] < priority[st[-1]]:
92 |                 res.append(st.pop())
93 |             st.append(e)
94 |         else:
95 |             res.append(e)
96 |     while len(st) > 0:
97 |         res.append(st.pop())
98 |     res.reverse()
99 |     return res
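As a quick illustration (our example, not part of the repository), both converters take token lists rather than strings and honor the usual operator precedences:

```python
from src.expressions_transfer import from_infix_to_postfix, from_infix_to_prefix

infix = ["(", "N0", "+", "N1", ")", "*", "N2"]
print(from_infix_to_postfix(infix))  # ['N0', 'N1', '+', 'N2', '*']
print(from_infix_to_prefix(infix))   # ['*', '+', 'N0', 'N1', 'N2']
```
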
100 | 
101 | 
102 | def out_expression_list(test, output_lang, num_list, num_stack=None):
103 |     max_index = output_lang.n_words
104 |     res = []
105 |     for i in test:
106 |         # if i == 0:
107 |         #     return res
108 |         if i < max_index - 1:
109 |             idx = output_lang.index2word[i]
110 |             if idx[0] == "N":
111 |                 if int(idx[1:]) >= len(num_list):
112 |                     return None
113 |                 res.append(num_list[int(idx[1:])])
114 |             else:
115 |                 res.append(idx)
116 |         else:
117 |             pos_list = num_stack.pop()
118 |             c = num_list[pos_list[0]]
119 |             res.append(c)
120 |     return res
121 | 
122 | 
123 | def compute_postfix_expression(post_fix):
124 |     st = list()
125 |     operators = ["+", "-", "^", "*", "/"]
126 |     for p in post_fix:
127 |         if p not in operators:
128 |             pos1 = re.search(r"\d+\(", p)
129 |             pos2 = re.search(r"\)\d+", p)
130 |             if pos1:
131 |                 st.append(eval(p[pos1.start(): pos1.end() - 1] + "+" + p[pos1.end() - 1:]))
132 |             elif pos2:
133 |                 st.append(eval(p[:pos2.start() + 1] + "+" + p[pos2.start() + 1: pos2.end()]))
134 |             # pos = re.search(r"\d+\(", p)
135 |             # if pos:
136 |             #     st.append(eval(p[pos.start(): pos.end() - 1] + "+" + p[pos.end() - 1:]))
137 |             elif p[-1] == "%":
138 |                 st.append(float(p[:-1]) / 100)
139 |             else:
140 |                 st.append(eval(p))
141 |         elif p == "+" and len(st) > 1:
142 |             a = st.pop()
143 |             b = st.pop()
144 |             st.append(a + b)
145 |         elif p == "*" and len(st) > 1:
146 |             a = st.pop()
147 |             b = st.pop()
148 |             st.append(a * b)
149 |         elif p == "/" and len(st) > 1:
150 |             a = st.pop()
151 |             b = st.pop()
152 |             if a == 0:
153 |                 return None
154 |             st.append(b / a)
155 |         elif p == "-" and len(st) > 1:
156 |             a = st.pop()
157 |             b = st.pop()
158 |             st.append(b - a)
159 |         elif p == "^" and len(st) > 1:
160 |             a = st.pop()
161 |             b = st.pop()
162 |             st.append(b ** a)  # the exponent is on top of the stack
163 |         else:
164 |             return None
165 |     if len(st) == 1:
166 |         return st.pop()
167 |     return None
168 | 
169 | 
170 | def compute_prefix_expression(pre_fix):
171 |     st = list()
172 |     operators = ["+", "-", "^", "*", "/"]
173 |     pre_fix = deepcopy(pre_fix)
174 |     pre_fix.reverse()
175 |     for p in pre_fix:
176 |         if p not in operators:
177 |             pos1 = re.search(r"\d+\(", p)
178 |             pos2 = re.search(r"\)\d+", p)
179 |             if pos1:
180 |                 st.append(eval(p[pos1.start(): pos1.end() - 1] + "+" + p[pos1.end() - 1:]))
181 |             elif pos2:
182 |                 st.append(eval(p[:pos2.start() + 1] + "+" + p[pos2.start() + 1: pos2.end()]))
183 |             # pos = re.search(r"\d+\(", p)
184 |             # if pos:
185 |             #     st.append(eval(p[pos.start(): pos.end() - 1] + "+" + p[pos.end() - 1:]))
186 |             elif p[-1] == "%":
187 |                 st.append(float(p[:-1]) / 100)
188 |             else:
189 |                 st.append(eval(p))
190 |         elif p == "+" and len(st) > 1:
191 |             a = st.pop()
192 |             b = st.pop()
193 |             st.append(a + b)
194 |         elif p == "*" and len(st) > 1:
195 |             a = st.pop()
196 |             b = st.pop()
197 |             st.append(a * b)
198 |         elif p == "/" and len(st) > 1:
199 |             a = st.pop()
200 |             b = st.pop()
201 |             if b == 0:
202 |                 return None
203 |             st.append(a / b)
204 |         elif p == "-" and len(st) > 1:
205 |             a = st.pop()
206 |             b = st.pop()
207 |             st.append(a - b)
208 |         elif p == "^" and len(st) > 1:
209 |             a = st.pop()
210 |             b = st.pop()
211 |             # if float(eval(b)) != 2.0 or float(eval(b)) != 3.0:
212 |             #     return None
213 |             st.append(a ** b)
214 |         else:
215 |             return None
216 |     if len(st) == 1:
217 |         return st.pop()
218 |     return None
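A hypothetical sanity check of the two evaluators above. Because the postfix version pops the top of the stack first, non-commutative operators combine as `b op a` — which is why its `^` case appends `b ** a`:

```python
from src.expressions_transfer import compute_prefix_expression, compute_postfix_expression

# (85 - 58) * 5, written in prefix and postfix form
print(compute_prefix_expression(["*", "-", "85", "58", "5"]))   # 135
print(compute_postfix_expression(["85", "58", "-", "5", "*"]))  # 135
print(compute_prefix_expression(["/", "50%", "2"]))             # 0.25 (percent strings become floats)
```
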
219 | 
220 | 
--------------------------------------------------------------------------------
/src/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | class Logger:
5 |     def __init__(self, exp_name):
6 |         # create the log directory up front so opening the file cannot fail
7 |         os.makedirs("./logs", exist_ok=True)
8 |         self.file = open("./logs/{}.log".format(exp_name), 'w')
9 | 
10 |     def log(self, content):
11 |         self.file.write(content + "\n")
12 |         self.file.flush()
--------------------------------------------------------------------------------
/src/masked_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional
3 | 
4 | 
5 | def sequence_mask(sequence_length, max_len=None):
6 |     if max_len is None:
7 |         max_len = sequence_length.data.max()
8 |     batch_size = sequence_length.size(0)
9 |     seq_range = torch.arange(0, max_len).long()
10 |     seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
11 |     if sequence_length.is_cuda:
12 |         seq_range_expand = seq_range_expand.cuda()
13 |     seq_length_expand = (sequence_length.unsqueeze(1).expand_as(seq_range_expand))
14 |     return seq_range_expand < seq_length_expand
15 | 
16 | 
17 | def masked_cross_entropy(logits, target, length):
18 |     """
19 |     Args:
20 |         logits: A Variable containing a FloatTensor of size
21 |             (batch, max_len, num_classes) which contains the
22 |             unnormalized probability for each class.
23 |         target: A Variable containing a LongTensor of size
24 |             (batch, max_len) which contains the index of the true
25 |             class for each corresponding step.
26 |         length: A Variable containing a LongTensor of size (batch,)
27 |             which contains the length of each data in a batch.
28 |     Returns:
29 |         loss: An average loss value masked by the length.
30 |     """
31 |     if torch.cuda.is_available():
32 |         length = torch.LongTensor(length).cuda()
33 |     else:
34 |         length = torch.LongTensor(length)
35 | 
36 |     # logits_flat: (batch * max_len, num_classes)
37 |     logits_flat = logits.view(-1, logits.size(-1))
38 |     # log_probs_flat: (batch * max_len, num_classes)
39 |     log_probs_flat = functional.log_softmax(logits_flat, dim=1)
40 |     # target_flat: (batch * max_len, 1)
41 |     target_flat = target.view(-1, 1)
42 |     # losses_flat: (batch * max_len, 1)
43 |     losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
44 | 
45 |     # losses: (batch, max_len)
46 |     losses = losses_flat.view(*target.size())
47 |     # mask: (batch, max_len)
48 |     mask = sequence_mask(sequence_length=length, max_len=target.size(1))
49 |     losses = losses * mask.float()
50 |     loss = losses.sum() / length.float().sum()
51 |     # if loss.item() > 10:
52 |     #     print(losses, target)
53 |     return loss
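A tiny usage sketch for `masked_cross_entropy` (shapes invented; this assumes a CPU-only run, since the helper moves `length` to the GPU whenever CUDA is available and the caller is then expected to pass CUDA tensors):

```python
import torch
from src.masked_cross_entropy import masked_cross_entropy

batch, max_len, num_classes = 2, 4, 6
logits = torch.randn(batch, max_len, num_classes)         # unnormalized scores
target = torch.randint(0, num_classes, (batch, max_len))  # gold class per step
length = [4, 2]                                           # second sequence has only 2 real steps
loss = masked_cross_entropy(logits, target, length)       # averaged over the 6 unmasked positions
print(loss.item())
```
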
54 | 
55 | 
56 | def masked_cross_entropy_without_logit(logits, target, length):
57 |     """
58 |     Args:
59 |         logits: A Variable containing a FloatTensor of size
60 |             (batch, max_len, num_classes) which contains the
61 |             already-normalized probability for each class.
62 |         target: A Variable containing a LongTensor of size
63 |             (batch, max_len) which contains the index of the true
64 |             class for each corresponding step.
65 |         length: A Variable containing a LongTensor of size (batch,)
66 |             which contains the length of each data in a batch.
67 |     Returns:
68 |         loss: An average loss value masked by the length.
69 |     """
70 |     if torch.cuda.is_available():
71 |         length = torch.LongTensor(length).cuda()
72 |     else:
73 |         length = torch.LongTensor(length)
74 | 
75 |     # logits_flat: (batch * max_len, num_classes)
76 |     logits_flat = logits.view(-1, logits.size(-1))
77 | 
78 |     # log_probs_flat: (batch * max_len, num_classes); the log is taken directly, the 1e-12 guards against log(0)
79 |     log_probs_flat = torch.log(logits_flat + 1e-12)
80 | 
81 |     # target_flat: (batch * max_len, 1)
82 |     target_flat = target.view(-1, 1)
83 |     # losses_flat: (batch * max_len, 1)
84 |     losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
85 | 
86 |     # losses: (batch, max_len)
87 |     losses = losses_flat.view(*target.size())
88 | 
89 |     # mask: (batch, max_len)
90 |     mask = sequence_mask(sequence_length=length, max_len=target.size(1))
91 |     losses = losses * mask.float()
92 |     loss = losses.sum() / length.float().sum()
93 |     # if loss.item() > 10:
94 |     #     print(losses, target)
95 |     return loss
96 | 
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | def replace_masked_values(tensor, mask, replace_with):
7 |     return tensor.masked_fill((1 - mask).bool(), replace_with)
8 | 
9 | def clones(module, N):
10 |     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
11 | 
12 | 
13 | class EncoderRNN(nn.Module):
14 |     def __init__(self, input_size, embedding_size, hidden_size, n_layers=2, dropout=0.5):
15 |         super(EncoderRNN, self).__init__()
16 | 
17 |         self.input_size = input_size
18 |         self.embedding_size = embedding_size
19 |         self.hidden_size = hidden_size
20 |         self.n_layers = n_layers
21 |         self.dropout = dropout
22 | 
23 |         self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=0)
24 |         self.em_dropout = nn.Dropout(dropout)
25 |         self.gru = nn.GRU(embedding_size, hidden_size, n_layers, dropout=dropout, bidirectional=True)
26 | 
27 |     def forward(self, input_seqs, input_lengths, hidden=None):
28 |         # Note: we run this all at once (over multiple batches of multiple sequences)
29 |         embedded = self.embedding(input_seqs)  # S x B x E
30 |         embedded = self.em_dropout(embedded)
31 |         packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
32 |         outputs, hidden = self.gru(packed, hidden)
33 |         outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(outputs)  # unpack (back to padded)
34 |         outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]  # Sum bidirectional outputs
35 |         # S x B x H
36 |         return outputs, hidden
37 | 
38 | 
39 | class Attn(nn.Module):
40 |     def __init__(self, hidden_size):
41 |         super(Attn, self).__init__()
42 |         self.hidden_size = hidden_size
43 |         self.attn = nn.Linear(hidden_size * 2, hidden_size)
44 |         self.score = nn.Linear(hidden_size, 1, bias=False)
45 |         self.softmax = nn.Softmax(dim=1)
46 | 
47 |     def forward(self, hidden, encoder_outputs, seq_mask=None):
48 |         max_len = encoder_outputs.size(0)
49 |         repeat_dims = [1] * hidden.dim()
50 |         repeat_dims[0] = max_len
51 |         hidden = hidden.repeat(*repeat_dims)  # S x B x H
52 |         # For each position of encoder outputs
53 |         this_batch_size = encoder_outputs.size(1)
54 |         energy_in = torch.cat((hidden, encoder_outputs), 2).view(-1, 2 * self.hidden_size)
55 |         attn_energies = self.score(torch.tanh(self.attn(energy_in)))  # (S x B) x 1
56 |         attn_energies = attn_energies.squeeze(1)
57 |         attn_energies = attn_energies.view(max_len, this_batch_size).transpose(0, 1)  # B x S
58 |         if seq_mask is not None:
59 |             attn_energies = attn_energies.masked_fill_(seq_mask.bool(), -1e12)
60 |         attn_energies = self.softmax(attn_energies)
61 |         # 
Normalize energies to weights in range 0 to 1, resize to B x 1 x S 62 | return attn_energies.unsqueeze(1) 63 | 64 | 65 | class AttnDecoderRNN(nn.Module): 66 | def __init__( 67 | self, hidden_size, embedding_size, input_size, output_size, n_layers=2, dropout=0.5): 68 | super(AttnDecoderRNN, self).__init__() 69 | 70 | # Keep for reference 71 | self.embedding_size = embedding_size 72 | self.hidden_size = hidden_size 73 | self.input_size = input_size 74 | self.output_size = output_size 75 | self.n_layers = n_layers 76 | self.dropout = dropout 77 | 78 | # Define layers 79 | self.em_dropout = nn.Dropout(dropout) 80 | self.embedding = nn.Embedding(input_size, embedding_size, padding_idx=0) 81 | self.gru = nn.GRU(hidden_size + embedding_size, hidden_size, n_layers, dropout=dropout) 82 | self.concat = nn.Linear(hidden_size * 2, hidden_size) 83 | self.out = nn.Linear(hidden_size, output_size) 84 | # Choose attention model 85 | self.attn = Attn(hidden_size) 86 | 87 | def forward(self, input_seq, last_hidden, encoder_outputs, seq_mask): 88 | # Get the embedding of the current input word (last output word) 89 | batch_size = input_seq.size(0) 90 | embedded = self.embedding(input_seq) 91 | embedded = self.em_dropout(embedded) 92 | embedded = embedded.view(1, batch_size, self.embedding_size) # S=1 x B x N 93 | 94 | # Calculate attention from current RNN state and all encoder outputs; 95 | # apply to encoder outputs to get weighted average 96 | attn_weights = self.attn(last_hidden[-1].unsqueeze(0), encoder_outputs, seq_mask) 97 | context = attn_weights.bmm(encoder_outputs.transpose(0, 1)) # B x S=1 x N 98 | 99 | # Get current hidden state from input word and last hidden state 100 | rnn_output, hidden = self.gru(torch.cat((embedded, context.transpose(0, 1)), 2), last_hidden) 101 | 102 | # Attentional vector using the RNN hidden state and context vector 103 | # concatenated together (Luong eq. 
5) 104 | output = self.out(torch.tanh(self.concat(torch.cat((rnn_output.squeeze(0), context.squeeze(1)), 1)))) 105 | 106 | # Return final output, hidden state 107 | return output, hidden 108 | 109 | 110 | class TreeNode: # the class save the tree node 111 | def __init__(self, embedding, left_flag=False): 112 | self.embedding = embedding 113 | self.left_flag = left_flag 114 | 115 | 116 | class Score(nn.Module): 117 | def __init__(self, input_size, hidden_size): 118 | super(Score, self).__init__() 119 | self.input_size = input_size 120 | self.hidden_size = hidden_size 121 | self.attn = nn.Linear(hidden_size + input_size, hidden_size) 122 | self.score = nn.Linear(hidden_size, 1, bias=False) 123 | 124 | def forward(self, hidden, num_embeddings, num_mask=None): 125 | max_len = num_embeddings.size(1) 126 | repeat_dims = [1] * hidden.dim() 127 | repeat_dims[1] = max_len 128 | hidden = hidden.repeat(*repeat_dims) # B x O x H 129 | # For each position of encoder outputs 130 | this_batch_size = num_embeddings.size(0) 131 | energy_in = torch.cat((hidden, num_embeddings), 2).view(-1, self.input_size + self.hidden_size) 132 | score = self.score(torch.tanh(self.attn(energy_in))) # (B x O) x 1 133 | score = score.squeeze(1) 134 | score = score.view(this_batch_size, -1) # B x O 135 | if num_mask is not None: 136 | score = score.masked_fill_(num_mask.bool(), -1e12) 137 | return score 138 | 139 | 140 | class TreeAttn(nn.Module): 141 | def __init__(self, input_size, hidden_size): 142 | super(TreeAttn, self).__init__() 143 | self.input_size = input_size 144 | self.hidden_size = hidden_size 145 | self.attn = nn.Linear(hidden_size + input_size, hidden_size) 146 | self.score = nn.Linear(hidden_size, 1) 147 | 148 | def forward(self, hidden, encoder_outputs, seq_mask=None): 149 | max_len = encoder_outputs.size(0) 150 | 151 | repeat_dims = [1] * hidden.dim() 152 | repeat_dims[0] = max_len 153 | hidden = hidden.repeat(*repeat_dims) # S x B x H 154 | this_batch_size = encoder_outputs.size(1) 155 | 156 | energy_in = torch.cat((hidden, encoder_outputs), 2).view(-1, self.input_size + self.hidden_size) 157 | 158 | score_feature = torch.tanh(self.attn(energy_in)) 159 | attn_energies = self.score(score_feature) # (S x B) x 1 160 | attn_energies = attn_energies.squeeze(1) 161 | attn_energies = attn_energies.view(max_len, this_batch_size).transpose(0, 1) # B x S 162 | if seq_mask is not None: 163 | attn_energies = attn_energies.masked_fill_(seq_mask.bool(), -1e12) 164 | attn_energies = nn.functional.softmax(attn_energies, dim=1) # B x S 165 | 166 | return attn_energies.unsqueeze(1) 167 | 168 | 169 | class EncoderSeq(nn.Module): 170 | def __init__(self, input1_size, input2_size, embed_model, embedding1_size, 171 | embedding2_size, hidden_size, n_layers=2, hop_size=2, dropout=0.5): 172 | super(EncoderSeq, self).__init__() 173 | 174 | self.input1_size = input1_size 175 | self.input2_size = input2_size 176 | self.embedding1_size = embedding1_size 177 | self.embedding2_size = embedding2_size 178 | self.hidden_size = hidden_size 179 | self.n_layers = n_layers 180 | self.dropout = dropout 181 | self.hop_size = hop_size 182 | 183 | self.embedding1 = embed_model 184 | self.embedding2 = nn.Embedding(input2_size, embedding2_size, padding_idx=0) 185 | self.em_dropout = nn.Dropout(dropout) 186 | self.gru = nn.GRU(embedding1_size+embedding2_size, hidden_size, n_layers, dropout=dropout, bidirectional=True) 187 | self.parse_gnn = clones(Parse_Graph_Module(hidden_size), hop_size) 188 | 189 | def forward(self, input1_var, input2_var, 
input_length, parse_graph, hidden=None): 190 | # Note: we run this all at once (over multiple batches of multiple sequences) 191 | embedded1 = self.embedding1(input1_var) # S x B x E 192 | embedded2 = self.embedding2(input2_var) 193 | embedded = torch.cat((embedded1, embedded2), dim=2) 194 | embedded = self.em_dropout(embedded) 195 | packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_length) 196 | pade_hidden = hidden 197 | pade_outputs, pade_hidden = self.gru(packed, pade_hidden) 198 | pade_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(pade_outputs) 199 | 200 | pade_outputs = pade_outputs[:, :, :self.hidden_size] + pade_outputs[:, :, self.hidden_size:] # S x B x H 201 | pade_outputs = pade_outputs.transpose(0, 1) 202 | for i in range(self.hop_size): 203 | pade_outputs = self.parse_gnn[i](pade_outputs, parse_graph[:,2]) 204 | pade_outputs = pade_outputs.transpose(0, 1) 205 | # problem_output = pade_outputs[-1, :, :self.hidden_size] + pade_outputs[0, :, self.hidden_size:] 206 | 207 | return pade_outputs, pade_hidden 208 | 209 | 210 | class Parse_Graph_Module(nn.Module): 211 | def __init__(self, hidden_size): 212 | super(Parse_Graph_Module, self).__init__() 213 | 214 | self.hidden_size = hidden_size 215 | self.node_fc1 = nn.Linear(hidden_size, hidden_size) 216 | self.node_fc2 = nn.Linear(hidden_size, hidden_size) 217 | self.node_out = nn.Linear(hidden_size * 2, hidden_size) 218 | 219 | def normalize(self, graph, symmetric=True): 220 | d = graph.sum(1) 221 | if symmetric: 222 | D = torch.diag(torch.pow(d, -0.5)) 223 | return D.mm(graph).mm(D) 224 | else : 225 | D = torch.diag(torch.pow(d,-1)) 226 | return D.mm(graph) 227 | 228 | def forward(self, node, graph): 229 | graph = graph.float() 230 | batch_size = node.size(0) 231 | for i in range(batch_size): 232 | graph[i] = self.normalize(graph[i]) 233 | 234 | node_info = torch.relu(self.node_fc1(torch.matmul(graph, node))) 235 | node_info = torch.relu(self.node_fc2(torch.matmul(graph, node_info))) 236 | 237 | agg_node_info = torch.cat((node, node_info), dim=2) 238 | agg_node_info = torch.relu(self.node_out(agg_node_info)) 239 | 240 | return agg_node_info 241 | 242 | 243 | class NumEncoder(nn.Module): 244 | def __init__(self, node_dim, hop_size=2): 245 | super(NumEncoder, self).__init__() 246 | 247 | self.node_dim = node_dim 248 | self.hop_size = hop_size 249 | self.num_gnn = clones(Num_Graph_Module(node_dim), hop_size) 250 | 251 | def forward(self, encoder_outputs, num_encoder_outputs, num_pos_pad, num_order_pad): 252 | num_embedding = num_encoder_outputs.clone() 253 | batch_size = num_embedding.size(0) 254 | num_mask = (num_pos_pad > -1).long() 255 | node_mask = (num_order_pad > 0).long() 256 | greater_graph_mask = num_order_pad.unsqueeze(-1).expand(batch_size, -1, num_order_pad.size(-1)) > \ 257 | num_order_pad.unsqueeze(1).expand(batch_size, num_order_pad.size(-1), -1) 258 | lower_graph_mask = num_order_pad.unsqueeze(-1).expand(batch_size, -1, num_order_pad.size(-1)) <= \ 259 | num_order_pad.unsqueeze(1).expand(batch_size, num_order_pad.size(-1), -1) 260 | greater_graph_mask = greater_graph_mask.long() 261 | lower_graph_mask = lower_graph_mask.long() 262 | 263 | diagmat = torch.diagflat(torch.ones(num_embedding.size(1), dtype=torch.long, device=num_embedding.device)) 264 | diagmat = diagmat.unsqueeze(0).expand(num_embedding.size(0), -1, -1) 265 | graph_ = node_mask.unsqueeze(1) * node_mask.unsqueeze(-1) * (1-diagmat) 266 | graph_greater = graph_ * greater_graph_mask + diagmat 267 | graph_lower = graph_ * lower_graph_mask + 
diagmat 268 | 269 | for i in range(self.hop_size): 270 | num_embedding = self.num_gnn[i](num_embedding, graph_greater, graph_lower) 271 | 272 | # gnn_info_vec = torch.zeros((batch_size, 1, encoder_outputs.size(-1)), 273 | # dtype=torch.float, device=num_embedding.device) 274 | # gnn_info_vec = torch.cat((encoder_outputs.transpose(0, 1), gnn_info_vec), dim=1) 275 | gnn_info_vec = torch.zeros((batch_size, encoder_outputs.size(0)+1, encoder_outputs.size(-1)), 276 | dtype=torch.float, device=num_embedding.device) 277 | clamped_number_indices = replace_masked_values(num_pos_pad, num_mask, gnn_info_vec.size(1)-1) 278 | gnn_info_vec.scatter_(1, clamped_number_indices.unsqueeze(-1).expand(-1, -1, num_embedding.size(-1)), num_embedding) 279 | gnn_info_vec = gnn_info_vec[:, :-1, :] 280 | gnn_info_vec = gnn_info_vec.transpose(0, 1) 281 | gnn_info_vec = encoder_outputs + gnn_info_vec 282 | num_embedding = num_encoder_outputs + num_embedding 283 | problem_output = torch.max(gnn_info_vec, 0).values 284 | 285 | return gnn_info_vec, num_embedding, problem_output 286 | 287 | 288 | class Num_Graph_Module(nn.Module): 289 | def __init__(self, node_dim): 290 | super(Num_Graph_Module, self).__init__() 291 | 292 | self.node_dim = node_dim 293 | self.node1_fc1 = nn.Linear(node_dim, node_dim) 294 | self.node1_fc2 = nn.Linear(node_dim, node_dim) 295 | self.node2_fc1 = nn.Linear(node_dim, node_dim) 296 | self.node2_fc2 = nn.Linear(node_dim, node_dim) 297 | self.graph_weight = nn.Linear(node_dim * 4, node_dim) 298 | self.node_out = nn.Linear(node_dim * 2, node_dim) 299 | 300 | def normalize(self, graph, symmetric=True): 301 | d = graph.sum(1) 302 | if symmetric: 303 | D = torch.diag(torch.pow(d, -0.5)) 304 | return D.mm(graph).mm(D) 305 | else : 306 | D = torch.diag(torch.pow(d,-1)) 307 | return D.mm(graph) 308 | 309 | def forward(self, node, graph1, graph2): 310 | graph1 = graph1.float() 311 | graph2 = graph2.float() 312 | batch_size = node.size(0) 313 | 314 | for i in range(batch_size): 315 | graph1[i] = self.normalize(graph1[i], False) 316 | graph2[i] = self.normalize(graph2[i], False) 317 | 318 | node_info1 = torch.relu(self.node1_fc1(torch.matmul(graph1, node))) 319 | node_info1 = torch.relu(self.node1_fc2(torch.matmul(graph1, node_info1))) 320 | node_info2 = torch.relu(self.node2_fc1(torch.matmul(graph2, node))) 321 | node_info2 = torch.relu(self.node2_fc2(torch.matmul(graph2, node_info2))) 322 | gate = torch.cat((node_info1, node_info2, node_info1+node_info2, node_info1-node_info2), dim=2) 323 | gate = torch.sigmoid(self.graph_weight(gate)) 324 | node_info = gate * node_info1 + (1-gate) * node_info2 325 | agg_node_info = torch.cat((node, node_info), dim=2) 326 | agg_node_info = torch.relu(self.node_out(agg_node_info)) 327 | 328 | return agg_node_info 329 | 330 | 331 | class Prediction(nn.Module): 332 | # a seq2tree decoder with Problem aware dynamic encoding 333 | 334 | def __init__(self, hidden_size, op_nums, input_size, dropout=0.5): 335 | super(Prediction, self).__init__() 336 | 337 | # Keep for reference 338 | self.hidden_size = hidden_size 339 | self.input_size = input_size 340 | self.op_nums = op_nums 341 | 342 | # Define layers 343 | self.dropout = nn.Dropout(dropout) 344 | 345 | self.embedding_weight = nn.Parameter(torch.randn(1, input_size, hidden_size)) 346 | 347 | # for Computational symbols and Generated numbers 348 | self.concat_l = nn.Linear(hidden_size, hidden_size) 349 | self.concat_r = nn.Linear(hidden_size * 2, hidden_size) 350 | self.concat_lg = nn.Linear(hidden_size, hidden_size) 351 | 
self.concat_rg = nn.Linear(hidden_size * 2, hidden_size) 352 | 353 | self.ops = nn.Linear(hidden_size * 2, op_nums) 354 | 355 | self.attn = TreeAttn(hidden_size, hidden_size) 356 | self.score = Score(hidden_size * 2, hidden_size) 357 | 358 | def forward(self, node_stacks, left_childs, encoder_outputs, num_pades, padding_hidden, seq_mask, mask_nums): 359 | current_embeddings = [] 360 | 361 | for st in node_stacks: 362 | if len(st) == 0: 363 | current_embeddings.append(padding_hidden) 364 | else: 365 | current_node = st[-1] 366 | current_embeddings.append(current_node.embedding) 367 | 368 | current_node_temp = [] 369 | for l, c in zip(left_childs, current_embeddings): 370 | if l is None: 371 | c = self.dropout(c) 372 | g = torch.tanh(self.concat_l(c)) 373 | t = torch.sigmoid(self.concat_lg(c)) 374 | current_node_temp.append(g * t) 375 | else: 376 | ld = self.dropout(l) 377 | c = self.dropout(c) 378 | g = torch.tanh(self.concat_r(torch.cat((ld, c), 1))) 379 | t = torch.sigmoid(self.concat_rg(torch.cat((ld, c), 1))) 380 | current_node_temp.append(g * t) 381 | 382 | current_node = torch.stack(current_node_temp) 383 | 384 | current_embeddings = self.dropout(current_node) 385 | 386 | current_attn = self.attn(current_embeddings.transpose(0, 1), encoder_outputs, seq_mask) 387 | current_context = current_attn.bmm(encoder_outputs.transpose(0, 1)) # B x 1 x N 388 | 389 | # the information to get the current quantity 390 | batch_size = current_embeddings.size(0) 391 | # predict the output (this node corresponding to output(number or operator)) with PADE 392 | 393 | repeat_dims = [1] * self.embedding_weight.dim() 394 | repeat_dims[0] = batch_size 395 | embedding_weight = self.embedding_weight.repeat(*repeat_dims) # B x input_size x N 396 | embedding_weight = torch.cat((embedding_weight, num_pades), dim=1) # B x O x N 397 | 398 | leaf_input = torch.cat((current_node, current_context), 2) 399 | leaf_input = leaf_input.squeeze(1) 400 | leaf_input = self.dropout(leaf_input) 401 | 402 | # p_leaf = nn.functional.softmax(self.is_leaf(leaf_input), 1) 403 | # max pooling the embedding_weight 404 | embedding_weight_ = self.dropout(embedding_weight) 405 | num_score = self.score(leaf_input.unsqueeze(1), embedding_weight_, mask_nums) 406 | 407 | # num_score = nn.functional.softmax(num_score, 1) 408 | 409 | op = self.ops(leaf_input) 410 | 411 | # return p_leaf, num_score, op, current_embeddings, current_attn 412 | 413 | return num_score, op, current_node, current_context, embedding_weight 414 | 415 | 416 | class GenerateNode(nn.Module): 417 | def __init__(self, hidden_size, op_nums, embedding_size, dropout=0.5): 418 | super(GenerateNode, self).__init__() 419 | 420 | self.embedding_size = embedding_size 421 | self.hidden_size = hidden_size 422 | 423 | self.embeddings = nn.Embedding(op_nums, embedding_size) 424 | self.em_dropout = nn.Dropout(dropout) 425 | self.generate_l = nn.Linear(hidden_size * 2 + embedding_size, hidden_size) 426 | self.generate_r = nn.Linear(hidden_size * 2 + embedding_size, hidden_size) 427 | self.generate_lg = nn.Linear(hidden_size * 2 + embedding_size, hidden_size) 428 | self.generate_rg = nn.Linear(hidden_size * 2 + embedding_size, hidden_size) 429 | 430 | def forward(self, node_embedding, node_label, current_context): 431 | node_label_ = self.embeddings(node_label) 432 | node_label = self.em_dropout(node_label_) 433 | node_embedding = node_embedding.squeeze(1) 434 | current_context = current_context.squeeze(1) 435 | node_embedding = self.em_dropout(node_embedding) 436 | current_context = 
self.em_dropout(current_context) 437 | 438 | l_child = torch.tanh(self.generate_l(torch.cat((node_embedding, current_context, node_label), 1))) 439 | l_child_g = torch.sigmoid(self.generate_lg(torch.cat((node_embedding, current_context, node_label), 1))) 440 | r_child = torch.tanh(self.generate_r(torch.cat((node_embedding, current_context, node_label), 1))) 441 | r_child_g = torch.sigmoid(self.generate_rg(torch.cat((node_embedding, current_context, node_label), 1))) 442 | l_child = l_child * l_child_g 443 | r_child = r_child * r_child_g 444 | return l_child, r_child, node_label_ 445 | 446 | 447 | class Merge(nn.Module): 448 | def __init__(self, hidden_size, embedding_size, dropout=0.5): 449 | super(Merge, self).__init__() 450 | 451 | self.embedding_size = embedding_size 452 | self.hidden_size = hidden_size 453 | 454 | self.em_dropout = nn.Dropout(dropout) 455 | self.merge = nn.Linear(hidden_size * 2 + embedding_size, hidden_size) 456 | self.merge_g = nn.Linear(hidden_size * 2 + embedding_size, hidden_size) 457 | 458 | def forward(self, node_embedding, sub_tree_1, sub_tree_2): 459 | sub_tree_1 = self.em_dropout(sub_tree_1) 460 | sub_tree_2 = self.em_dropout(sub_tree_2) 461 | node_embedding = self.em_dropout(node_embedding) 462 | 463 | sub_tree = torch.tanh(self.merge(torch.cat((node_embedding, sub_tree_1, sub_tree_2), 1))) 464 | sub_tree_g = torch.sigmoid(self.merge_g(torch.cat((node_embedding, sub_tree_1, sub_tree_2), 1))) 465 | sub_tree = sub_tree * sub_tree_g 466 | return sub_tree 467 | -------------------------------------------------------------------------------- /src/pre_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import json 3 | import copy 4 | import re 5 | import numpy as np 6 | 7 | 8 | PAD_token = 0 9 | 10 | 11 | class Lang: 12 | """ 13 | class to save the vocab and two dict: the word->index and index->word 14 | """ 15 | def __init__(self): 16 | self.word2index = {} 17 | self.word2count = {} 18 | self.index2word = [] 19 | self.n_words = 0 # Count word tokens 20 | self.num_start = 0 21 | 22 | def add_sen_to_vocab(self, sentence): # add words of sentence to vocab 23 | for word in sentence: 24 | if re.search("N\d+|NUM|\d+", word): 25 | continue 26 | if word not in self.index2word: 27 | self.word2index[word] = self.n_words 28 | self.word2count[word] = 1 29 | self.index2word.append(word) 30 | self.n_words += 1 31 | else: 32 | self.word2count[word] += 1 33 | 34 | def trim(self, min_count): # trim words below a certain count threshold 35 | keep_words = [] 36 | 37 | for k, v in self.word2count.items(): 38 | if v >= min_count: 39 | keep_words.append(k) 40 | 41 | print('keep_words %s / %s = %.4f' % ( 42 | len(keep_words), len(self.index2word), len(keep_words) / len(self.index2word) 43 | )) 44 | 45 | # Reinitialize dictionaries 46 | self.word2index = {} 47 | self.word2count = {} 48 | self.index2word = [] 49 | self.n_words = 0 # Count default tokens 50 | 51 | for word in keep_words: 52 | self.word2index[word] = self.n_words 53 | self.index2word.append(word) 54 | self.n_words += 1 55 | 56 | def build_input_lang(self, trim_min_count): # build the input lang vocab and dict 57 | if trim_min_count > 0: 58 | self.trim(trim_min_count) 59 | self.index2word = ["PAD", "NUM", "UNK"] + self.index2word 60 | else: 61 | self.index2word = ["PAD", "NUM"] + self.index2word 62 | self.word2index = {} 63 | self.n_words = len(self.index2word) 64 | for i, j in enumerate(self.index2word): 65 | self.word2index[j] = i 66 | 67 | def 
build_input_lang_for_pos(self):
68 |         self.index2word = ["PAD", "UNK"] + self.index2word
69 |         self.n_words = len(self.index2word)
70 |         for i, j in enumerate(self.index2word):
71 |             self.word2index[j] = i
72 | 
73 |     def build_output_lang(self, generate_num, copy_nums):  # build the output lang vocab and dict
74 |         self.index2word = ["PAD", "EOS"] + self.index2word + generate_num + ["N" + str(i) for i in range(copy_nums)] +\
75 |                           ["SOS", "UNK"]
76 |         self.n_words = len(self.index2word)
77 |         for i, j in enumerate(self.index2word):
78 |             self.word2index[j] = i
79 | 
80 |     def build_output_lang_for_tree(self, generate_num, copy_nums):  # build the output lang vocab and dict
81 |         self.num_start = len(self.index2word)
82 | 
83 |         self.index2word = self.index2word + generate_num + ["N" + str(i) for i in range(copy_nums)] + ["UNK"]
84 |         self.n_words = len(self.index2word)
85 |         for i, j in enumerate(self.index2word):
86 |             self.word2index[j] = i
87 | 
88 | 
89 | def load_raw_data(filename):  # load the json data to list(dict()) for Math23K
90 |     print("Reading lines...")
91 |     f = open(filename, encoding="utf-8")
92 |     js = ""
93 |     data = []
94 |     for i, s in enumerate(f):
95 |         js += s
96 |         i += 1
97 |         if i % 7 == 0:  # every 7 lines form one json record
98 |             data_d = json.loads(js)
99 |             if "千米/小时" in data_d["equation"]:
100 |                 data_d["equation"] = data_d["equation"][:-5]
101 |             data.append(data_d)
102 |             js = ""
103 | 
104 |     return data
105 | 
106 | 
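For reference, `load_raw_data` relies on the released Math23K file storing each record on exactly seven physical lines. An invented record (values are illustrative only) looks like this; the keys the code actually reads are `id`, `segmented_text`, and `equation`, which always carries an `x=` prefix — hence the `[2:]` slice in `transfer_num` below:

```python
# Invented Math23K-style record, shown as the dict json.loads() would return.
record = {
    "id": "1",
    "original_text": "快车每小时行驶85千米……",         # raw problem text
    "segmented_text": "快车 每 小时 行驶 85 千米 …",  # whitespace-tokenized text
    "equation": "x=(85-58)*5",                       # solution equation, prefixed with "x="
    "ans": "135",
}
```
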
107 | def transfer_num(data):  # transfer num into "NUM"
108 |     print("Transfer numbers...")
109 |     pattern = re.compile(r"\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")
110 |     pairs = []
111 |     generate_nums = []
112 |     generate_nums_dict = {}
113 |     copy_nums = 0
114 |     for d in data:
115 |         idx = d["id"]
116 |         nums = []
117 |         input_seq = []
118 |         seg = d["segmented_text"].strip().split(" ")
119 |         equations = d["equation"][2:]
120 | 
121 |         for s in seg:
122 |             pos = re.search(pattern, s)
123 |             if pos and pos.start() == 0:
124 |                 nums.append(s[pos.start(): pos.end()])
125 |                 input_seq.append("NUM")
126 |                 if pos.end() < len(s):
127 |                     input_seq.append(s[pos.end():])
128 |             elif s != "":
129 |                 input_seq.append(s)
130 |         if copy_nums < len(nums):
131 |             copy_nums = len(nums)
132 | 
133 |         nums_fraction = []
134 | 
135 |         for num in nums:
136 |             if re.search(r"\d*\(\d+/\d+\)\d*", num):
137 |                 nums_fraction.append(num)
138 |         nums_fraction = sorted(nums_fraction, key=lambda x: len(x), reverse=True)
139 | 
140 |         def seg_and_tag(st):  # seg the equation and tag the num
141 |             res = []
142 |             for n in nums_fraction:
143 |                 if n in st:
144 |                     p_start = st.find(n)
145 |                     p_end = p_start + len(n)
146 |                     if p_start > 0:
147 |                         res += seg_and_tag(st[:p_start])
148 |                     if nums.count(n) == 1:
149 |                         res.append("N" + str(nums.index(n)))
150 |                     else:
151 |                         res.append(n)
152 |                     if p_end < len(st):
153 |                         res += seg_and_tag(st[p_end:])
154 |                     return res
155 |             pos_st = re.search(r"\d+\.\d+%?|\d+%?", st)
156 |             if pos_st:
157 |                 p_start = pos_st.start()
158 |                 p_end = pos_st.end()
159 |                 if p_start > 0:
160 |                     res += seg_and_tag(st[:p_start])
161 |                 st_num = st[p_start:p_end]
162 |                 if nums.count(st_num) == 1:
163 |                     res.append("N" + str(nums.index(st_num)))
164 |                 else:
165 |                     res.append(st_num)
166 |                 if p_end < len(st):
167 |                     res += seg_and_tag(st[p_end:])
168 |                 return res
169 |             for ss in st:
170 |                 res.append(ss)
171 |             return res
172 | 
173 |         out_seq = seg_and_tag(equations)
174 |         for s in out_seq:  # tag the num which is generated
175 |             if s[0].isdigit() and s not in generate_nums and s not in nums:
176 |                 generate_nums.append(s)
177 |                 generate_nums_dict[s] = 0
178 |             if s in generate_nums and s not in nums:
179 |                 generate_nums_dict[s] = generate_nums_dict[s] + 1
180 | 
181 |         num_pos = []
182 |         for i, j in enumerate(input_seq):
183 |             if j == "NUM":
184 |                 num_pos.append(i)
185 |         assert len(nums) == len(num_pos)
186 |         # pairs.append((input_seq, out_seq, nums, num_pos, d["ans"]))
187 |         pairs.append((idx, input_seq, out_seq, nums, num_pos))
188 | 
189 |     temp_g = []
190 |     for g in generate_nums:
191 |         if generate_nums_dict[g] >= 5:
192 |             temp_g.append(g)
193 |     return pairs, temp_g, copy_nums
194 | 
195 | 
196 | # Return a list of indexes, one for each word in the sentence, plus EOS
197 | def indexes_from_sentence(lang, sentence, tree=False):
198 |     res = []
199 |     for word in sentence:
200 |         if len(word) == 0:
201 |             continue
202 |         if word in lang.word2index:
203 |             res.append(lang.word2index[word])
204 |         else:
205 |             res.append(lang.word2index["UNK"])
206 |     if "EOS" in lang.index2word and not tree:
207 |         res.append(lang.word2index["EOS"])
208 |     return res
209 | 
210 | 
211 | def texts_from_sentence(lang, sentence, tree=False):
212 |     res = []
213 |     for word in sentence:
214 |         if len(word) == 0:
215 |             continue
216 |         if word in lang.word2index:
217 |             res.append(word)
218 |         else:
219 |             res.append("UNK")
220 |     if "EOS" in lang.index2word and not tree:
221 |         res.append("EOS")
222 |     return res
223 | 
224 | 
225 | def num_list_processed(num_list):
226 |     st = []
227 |     for p in num_list:
228 |         pos1 = re.search(r"\d+\(", p)
229 |         pos2 = re.search(r"\)\d+", p)
230 |         if pos1:
231 |             st.append(eval(p[pos1.start(): pos1.end() - 1] + "+" + p[pos1.end() - 1:]))
232 |         elif pos2:
233 |             st.append(eval(p[:pos2.start() + 1] + "+" + p[pos2.start() + 1: pos2.end()]))
234 |         elif p[-1] == "%":
235 |             st.append(float(p[:-1]) / 100)
236 |         else:
237 |             st.append(eval(p))
238 |     return st
239 | 
240 | 
241 | def num_order_processed(num_list):
242 |     num_order = []
243 |     num_array = np.asarray(num_list)
244 |     for num in num_array:
245 |         num_order.append(sum(num > num_array) + 1)
246 | 
247 |     return num_order
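A small, invented example of the two helpers above: `num_list_processed` normalizes surface number strings to Python numbers (percentages and mixed fractions like `3(1/5)` included), and `num_order_processed` assigns each value its 1-based ascending rank, with ties sharing a rank:

```python
from src.pre_data import num_list_processed, num_order_processed

values = num_list_processed(["85", "58", "5%", "3(1/5)"])  # [85, 58, 0.05, 3.2]
print(num_order_processed(values))                         # ranks: [4, 3, 1, 2]
```
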
249 | 
250 | def prepare_data(pairs_trained, pairs_tested, trim_min_count, generate_nums, copy_nums):
251 |     input1_lang = Lang()
252 |     input2_lang = Lang()
253 |     output1_lang = Lang()
254 |     output2_lang = Lang()
255 |     train_pairs = []
256 |     test_pairs = []
257 | 
258 |     print("Indexing words...")
259 |     for pair in pairs_trained:
260 |         if pair[-1]:
261 |             input1_lang.add_sen_to_vocab(pair[1])
262 |             input2_lang.add_sen_to_vocab(pair[2])
263 |             output1_lang.add_sen_to_vocab(pair[4])
264 |             output2_lang.add_sen_to_vocab(pair[5])
265 | 
266 |     input1_lang.build_input_lang(trim_min_count)
267 |     input2_lang.build_input_lang_for_pos()
268 |     output1_lang.build_output_lang_for_tree(generate_nums, copy_nums)
269 |     output2_lang.build_output_lang(generate_nums, copy_nums)
270 | 
271 |     for pair in pairs_trained:
272 |         num_stack = []
273 |         for word in pair[4]:
274 |             temp_num = []
275 |             flag_not = True
276 |             if word not in output1_lang.index2word:
277 |                 flag_not = False
278 |                 for i, j in enumerate(pair[6]):
279 |                     if j == word:
280 |                         temp_num.append(i)
281 | 
282 |             if not flag_not and len(temp_num) != 0:
283 |                 num_stack.append(temp_num)
284 |             if not flag_not and len(temp_num) == 0:
285 |                 num_stack.append([_ for _ in range(len(pair[6]))])
286 | 
287 |         num_stack.reverse()
288 |         input1_cell = indexes_from_sentence(input1_lang, pair[1])
289 |         texts_cell = texts_from_sentence(input1_lang, pair[1])
290 |         input2_cell = indexes_from_sentence(input2_lang, pair[2])
291 |         output1_cell = indexes_from_sentence(output1_lang, pair[4], True)
292 |         output2_cell = indexes_from_sentence(output2_lang, pair[5], False)
293 |         num_list = num_list_processed(pair[6])
294 |         num_order = num_order_processed(num_list)
295 |         train_pairs.append((pair[0], texts_cell, input1_cell, input2_cell, pair[3], len(input1_cell),
296 |                             output1_cell, len(output1_cell), output2_cell, len(output2_cell),
297 |                             pair[6], pair[7], num_stack, num_order))
298 |     print('Indexed %d words in input language, %d words in output1, %d words in output2' %
299 |           (input1_lang.n_words, output1_lang.n_words, output2_lang.n_words))
300 |     print('Number of training data %d' % (len(train_pairs)))
301 |     for pair in pairs_tested:
302 |         num_stack = []
303 |         for word in pair[4]:
304 |             temp_num = []
305 |             flag_not = True
306 |             if word not in output1_lang.index2word:
307 |                 flag_not = False
308 |                 for i, j in enumerate(pair[6]):
309 |                     if j == word:
310 |                         temp_num.append(i)
311 | 
312 |             if not flag_not and len(temp_num) != 0:
313 |                 num_stack.append(temp_num)
314 |             if not flag_not and len(temp_num) == 0:
315 |                 num_stack.append([_ for _ in range(len(pair[6]))])
316 | 
317 |         num_stack.reverse()
318 |         input1_cell = indexes_from_sentence(input1_lang, pair[1])
319 |         texts_cell = texts_from_sentence(input1_lang, pair[1])
320 |         input2_cell = indexes_from_sentence(input2_lang, pair[2])
321 |         output1_cell = indexes_from_sentence(output1_lang, pair[4], True)
322 |         output2_cell = indexes_from_sentence(output2_lang, pair[5], False)
323 |         num_list = num_list_processed(pair[6])
324 |         num_order = num_order_processed(num_list)
325 |         test_pairs.append((pair[0], texts_cell, input1_cell, input2_cell, pair[3], len(input1_cell),
326 |                            output1_cell, len(output1_cell), output2_cell, len(output2_cell),
327 |                            pair[6], pair[7], num_stack, num_order))
328 |     print('Number of testing data %d' % (len(test_pairs)))
329 |     return input1_lang, input2_lang, output1_lang, output2_lang, train_pairs, test_pairs
330 | 
331 | 
332 | # Pad a sequence up to max_length with the PAD token
333 | def pad_seq(seq, seq_len, max_length):
334 |     seq += [PAD_token for _ in range(max_length - seq_len)]
335 |     return seq
336 | 
337 | 
338 | # prepare the batches
339 | def prepare_train_batch(pairs_to_batch, batch_size):
340 |     pairs = copy.deepcopy(pairs_to_batch)
341 |     random.shuffle(pairs)  # shuffle the pairs
342 |     pos = 0
343 |     id_batches = []
344 |     input_lengths = []
345 |     output1_lengths = []
346 |     output2_lengths = []
347 |     nums_batches = []
348 |     batches = []
349 |     input1_batches = []
350 |     input2_batches = []
351 |     output1_batches = []
352 |     output2_batches = []
353 |     num_stack_batches = []  # save the num stacks that resolve ambiguous number tokens
354 |     num_pos_batches = []
355 |     num_order_batches = []
356 |     num_size_batches = []
357 |     parse_graph_batches = []
358 |     while pos + batch_size < len(pairs):
359 |         batches.append(pairs[pos:pos + batch_size])
360 |         pos += batch_size
361 |     batches.append(pairs[pos:])
362 | 
363 |     for batch in batches:
364 |         batch = sorted(batch, key=lambda tp: tp[5], reverse=True)  # sort by input length, longest first
365 |         input_length = []
366 |         output1_length = []
367 |         output2_length = []
368 |         for _, _, _, _, _, i, _, j, _, k, _, _, _, _ in batch:
369 |             input_length.append(i)
370 |             output1_length.append(j)
371 |             output2_length.append(k)
372 |         input_lengths.append(input_length)
373 |         output1_lengths.append(output1_length)
374 |         output2_lengths.append(output2_length)
375 |         input_len_max = input_length[0]
376 |         output1_len_max = max(output1_length)
377 |         output2_len_max = max(output2_length)
378 |         id_batch = []
379 |         input1_batch = []
380 |         input2_batch = []
381 |         output1_batch = []
382 |         output2_batch = []
383 |         num_batch = []
384 |         num_stack_batch = []
385 |         num_pos_batch = []
386 |         num_order_batch = []
387 |         num_size_batch = []
388 |         parse_tree_batch = []
389 |         for idx, _, i1, i2, parse_tree, li, j, lj, k, lk, num, num_pos, num_stack, num_order in batch:
390 |             id_batch.append(idx)
391 |             input1_batch.append(pad_seq(i1, li, input_len_max))
392 |             input2_batch.append(pad_seq(i2, li, input_len_max))
393 |             output1_batch.append(pad_seq(j, lj, output1_len_max))
394 |             output2_batch.append(pad_seq(k, lk, output2_len_max))
395 |             num_batch.append(len(num))
396 |             num_stack_batch.append(num_stack)
397 |             num_pos_batch.append(num_pos)
398 |             num_order_batch.append(num_order)
399 |             num_size_batch.append(len(num_pos))
400 |             parse_tree_batch.append(parse_tree)
401 | 
402 |         id_batches.append(id_batch)
403 |         input1_batches.append(input1_batch)
404 |         input2_batches.append(input2_batch)
405 |         output1_batches.append(output1_batch)
406 |         output2_batches.append(output2_batch)
407 |         nums_batches.append(num_batch)
408 |         num_stack_batches.append(num_stack_batch)
409 |         num_pos_batches.append(num_pos_batch)
410 |         num_order_batches.append(num_order_batch)
411 |         num_size_batches.append(num_size_batch)
412 |         parse_graph_batches.append(get_parse_graph_batch(input_length, parse_tree_batch))
413 | 
414 |     return id_batches, input1_batches, input2_batches, input_lengths, output1_batches, output1_lengths, output2_batches, output2_lengths, \
415 |            nums_batches, num_stack_batches, num_pos_batches, num_order_batches, num_size_batches, parse_graph_batches
416 | 
417 | 
418 | def get_parse_graph_batch(input_length, parse_tree_batch):
419 |     batch_graph = []
420 |     max_len = max(input_length)
421 |     for i in range(len(input_length)):
422 |         parse_tree = parse_tree_batch[i]
423 |         diag_ele = [1] * input_length[i] + [0] * (max_len - input_length[i])
424 |         graph1 = np.diag([1] * max_len) + np.diag(diag_ele[1:], 1) + np.diag(diag_ele[1:], -1)
425 |         graph2 = copy.deepcopy(graph1)
426 |         graph3 = copy.deepcopy(graph1)
427 |         for j in range(len(parse_tree)):
428 |             if parse_tree[j] != -1:
429 |                 graph1[j, parse_tree[j]] = 1  # child -> head edges
430 |                 graph2[parse_tree[j], j] = 1  # head -> child edges
431 |                 graph3[j, parse_tree[j]] = 1  # undirected: both directions
432 |                 graph3[parse_tree[j], j] = 1
433 |         graph = [graph1.tolist(), graph2.tolist(), graph3.tolist()]
434 |         batch_graph.append(graph)
435 |     batch_graph = np.array(batch_graph)
436 |     return batch_graph
437 | 
438 | 
439 | def word2vec(train_pairs, embedding_size, input_lang):
440 |     sentences = []
441 |     for train in train_pairs:
442 |         sentence = train[1]
443 |         sentences.append(sentence)
444 | 
445 |     from gensim.models import word2vec
446 |     model = word2vec.Word2Vec(sentences, size=embedding_size, min_count=1)  # gensim 3.x keyword; gensim>=4 renamed it to vector_size
447 | 
448 |     emb_vectors = []
449 |     emb_vectors.append(np.zeros((embedding_size)))  # index 0 is PAD: an all-zero vector
450 |     for i in range(1, input_lang.n_words):
451 |         emb_vectors.append(np.array(model.wv[input_lang.index2word[i]]))
452 | 
453 |     return emb_vectors
--------------------------------------------------------------------------------
/src/train_and_evaluate.py:
--------------------------------------------------------------------------------
1 | from src.masked_cross_entropy import *
2 | from src.pre_data import *
3 | from src.expressions_transfer import *
4 | from src.models import *
5 | import math
6 | import torch
7 | import torch.optim
8 | import torch.nn.functional as f
9 | import time
10 | 
11 | MAX_OUTPUT_LENGTH = 45
12 | MAX_INPUT_LENGTH = 120
13 | USE_CUDA = torch.cuda.is_available()
14 | 
15 | 
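# A quick orientation for the beam-search bookkeeping below (a summary of
# the code, example values hypothetical): a Beam bundles a hypothesis'
# running log-probability score, its latest input token(s), the decoder
# hidden state, and all outputs produced so far; each search step expands
# every live hypothesis and keeps only the top candidates, e.g.
#   beam_list = sorted(temp_list, key=lambda b: b.score, reverse=True)[:beam_size]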
16 | class Beam:  # the class saves a beam-search node
17 |     def __init__(self, score, input_var, hidden, all_output):
18 |         self.score = score
19 |         self.input_var = input_var
20 |         self.hidden = hidden
21 |         self.all_output = all_output
22 | 
23 | 
24 | def time_since(s):  # format elapsed seconds as hours/minutes/seconds
25 |     m = math.floor(s / 60)
26 |     s -= m * 60
27 |     h = math.floor(m / 60)
28 |     m -= h * 60
29 |     return '%dh %dm %ds' % (h, m, s)
30 | 
31 | 
32 | def generate_tree_input(target, decoder_output, nums_stack_batch, num_start, unk):
33 |     # when the decoder input is a copied num that appears at several positions, choose the position with the max score
34 |     target_input = copy.deepcopy(target)
35 |     for i in range(len(target)):
36 |         if target[i] == unk:
37 |             num_stack = nums_stack_batch[i].pop()
38 |             max_score = -float("1e12")
39 |             for num in num_stack:
40 |                 if decoder_output[i, num_start + num] > max_score:
41 |                     target[i] = num + num_start
42 |                     max_score = decoder_output[i, num_start + num]
43 |         if target_input[i] >= num_start:
44 |             target_input[i] = 0
45 |     return torch.LongTensor(target), torch.LongTensor(target_input)
46 | 
47 | 
48 | def generate_decoder_input(target, decoder_output, nums_stack_batch, num_start, unk):
49 |     # when the decoder input is a copied num that appears at several positions, choose the position with the max score
50 |     if USE_CUDA:
51 |         decoder_output = decoder_output.cpu()
52 |     for i in range(target.size(0)):
53 |         if target[i] == unk:
54 |             num_stack = nums_stack_batch[i].pop()
55 |             max_score = -float("1e12")
56 |             for num in num_stack:
57 |                 if decoder_output[i, num_start + num] > max_score:
58 |                     target[i] = num + num_start
59 |                     max_score = decoder_output[i, num_start + num]
60 |     return target
61 | 
62 | 
63 | def compute_prefix_tree_result(test_res, test_tar, output_lang, num_list, num_stack):
64 |     # print(test_res, test_tar)
65 | 
66 |     if len(num_stack) == 0 and test_res == test_tar:
67 |         return True, True, test_res, test_tar
68 |     test = out_expression_list(test_res, output_lang, num_list)
69 |     tar = out_expression_list(test_tar, output_lang, num_list, copy.deepcopy(num_stack))
70 |     # print(test, tar)
71 |     if test is None:
72 |         return False, False, test, tar
73 |     if test == tar:
74 |         return True, True, test, tar
75 |     try:
76 |         if abs(compute_prefix_expression(test) - compute_prefix_expression(tar)) < 1e-4:
77 |             return True, False, test, tar
78 |         else:
79 |             return False, False, test, tar
80 |     except:
81 |         return False, False, test, tar
82 | 
83 | 
84 | def compute_postfix_tree_result(test_res, test_tar, output_lang, num_list, num_stack):
85 |     # print(test_res, test_tar)
86 | 
87 |     if len(num_stack) == 0 and test_res == test_tar:
88 |         return True, True, test_res, test_tar
89 |     test = out_expression_list(test_res, output_lang, num_list)
90 |     tar = out_expression_list(test_tar, output_lang, num_list, copy.deepcopy(num_stack))
91 |     # print(test, tar)
92 |     if test is None:
93 |         return False, False, test, tar
94 |     if test == tar:
95 |         return True, True, test, tar
96 |     try:
97 |         if abs(compute_postfix_expression(test) - compute_postfix_expression(tar)) < 1e-4:
98 |             return True, False, test, tar
99 |         else:
100 |             return False, False, test, tar
101 |     except:
102 |         return False, False, test, tar
103 | 
104 | 
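# A worked sketch of the tolerance check above (hypothetical values):
# with num_list = ["2", "3"], a target prefix expression ["*", "N0", "N1"]
# evaluates to 2 * 3 = 6; a prediction that maps to the same value through
# a different token sequence (say ["*", "N1", "N0"], also 6) fails the
# exact-match test but passes |6 - 6| < 1e-4, so the first returned flag
# (value accuracy) is True while the second (exact equation match) is False.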
105 | def get_all_number_encoder_outputs(encoder_outputs, num_pos, batch_size, num_size, hidden_size):
106 |     indices = list()
107 |     sen_len = encoder_outputs.size(0)
108 |     masked_index = []
109 |     temp_1 = [1 for _ in range(hidden_size)]
110 |     temp_0 = [0 for _ in range(hidden_size)]
111 |     for b in range(batch_size):
112 |         for i in num_pos[b]:
113 |             indices.append(i + b * sen_len)
114 |             masked_index.append(temp_0)
115 |         indices += [0 for _ in range(len(num_pos[b]), num_size)]
116 |         masked_index += [temp_1 for _ in range(len(num_pos[b]), num_size)]
117 |     indices = torch.LongTensor(indices)
118 |     masked_index = torch.ByteTensor(masked_index)
119 |     masked_index = masked_index.view(batch_size, num_size, hidden_size)
120 |     if USE_CUDA:
121 |         indices = indices.cuda()
122 |         masked_index = masked_index.cuda()
123 |     all_outputs = encoder_outputs.transpose(0, 1).contiguous()
124 |     all_embedding = all_outputs.view(-1, encoder_outputs.size(2))  # S x B x H -> (B x S) x H
125 |     all_num = all_embedding.index_select(0, indices)
126 |     all_num = all_num.view(batch_size, num_size, hidden_size)
127 |     return all_num.masked_fill_(masked_index.bool(), 0.0), masked_index
128 | 
129 | 
130 | def copy_list(l):
131 |     r = []
132 |     if len(l) == 0:
133 |         return r
134 |     for i in l:
135 |         if type(i) is list:
136 |             r.append(copy_list(i))
137 |         else:
138 |             r.append(i)
139 |     return r
140 | 
141 | 
142 | class TreeBeam:  # the class saves a beam-search node for the tree decoder
143 |     def __init__(self, score, node_stack, embedding_stack, left_childs, out):
144 |         self.score = score
145 |         self.embedding_stack = copy_list(embedding_stack)
146 |         self.node_stack = copy_list(node_stack)
147 |         self.left_childs = copy_list(left_childs)
148 |         self.out = copy.deepcopy(out)
149 | 
150 | 
151 | class TreeEmbedding:  # the class saves a subtree embedding
152 |     def __init__(self, embedding, terminal=False):
153 |         self.embedding = embedding
154 |         self.terminal = terminal
155 | 
156 | 
157 | def train_tree_double(encoder_outputs, problem_output, all_nums_encoder_outputs, target, target_length,
158 |                       output_lang, batch_size, padding_hidden, seq_mask,
159 |                       num_mask, num_pos, num_order_pad, nums_stack_batch, unk,
160 |                       encoder, numencoder, predict, generate, merge):
161 |     # Prepare input and output variables
162 |     node_stacks = [[TreeNode(_)] for _ in problem_output.split(1, dim=0)]
163 | 
164 |     max_target_length = max(target_length)
165 | 
166 |     all_node_outputs = []
167 |     # all_leafs = []
168 | 
169 |     # copy_num_len = [len(_) for _ in num_pos]
170 |     # num_size = max(copy_num_len)
171 |     # nums_encoder_outputs, masked_index = get_all_number_encoder_outputs(encoder_outputs, num_pos, batch_size, num_size,
172 |     #                                                                     encoder.hidden_size)
173 |     # all_nums_encoder_outputs = numencoder(nums_encoder_outputs, num_order_pad)
174 |     # all_nums_encoder_outputs = all_nums_encoder_outputs.masked_fill_(masked_index.bool(), 0.0)
175 | 
176 |     num_start = output_lang.num_start
177 |     embeddings_stacks = [[] for _ in range(batch_size)]
178 |     left_childs = [None for _ in range(batch_size)]
179 |     for t in range(max_target_length):
180 |         num_score, op, current_embeddings, current_context, current_nums_embeddings = predict(
181 |             node_stacks, left_childs, encoder_outputs, all_nums_encoder_outputs, padding_hidden, seq_mask, num_mask)
182 | 
183 |         # all_leafs.append(p_leaf)
184 |         outputs = torch.cat((op, num_score), 1)
185 |         all_node_outputs.append(outputs)
186 | 
187 |         target_t, generate_input = generate_tree_input(target[t].tolist(), outputs, nums_stack_batch, num_start, unk)
188 |         target[t] = target_t
189 |         if USE_CUDA:
190 |             generate_input = generate_input.cuda()
191 |         left_child, right_child, node_label = generate(current_embeddings, generate_input, current_context)
192 |         left_childs = []
193 |         for idx, l, r, node_stack, i, o in zip(range(batch_size), left_child.split(1), right_child.split(1),
194 |                                                node_stacks, target[t].tolist(), embeddings_stacks):
195 |             if len(node_stack) != 0:
196 |                 node = node_stack.pop()
197 |             else:
198 |                 left_childs.append(None)
199 |                 continue
200 | 
201 |             if i < 
num_start: 202 | node_stack.append(TreeNode(r)) 203 | node_stack.append(TreeNode(l, left_flag=True)) 204 | o.append(TreeEmbedding(node_label[idx].unsqueeze(0), False)) 205 | else: 206 | current_num = current_nums_embeddings[idx, i - num_start].unsqueeze(0) 207 | while len(o) > 0 and o[-1].terminal: 208 | sub_stree = o.pop() 209 | op = o.pop() 210 | current_num = merge(op.embedding, sub_stree.embedding, current_num) 211 | o.append(TreeEmbedding(current_num, True)) 212 | if len(o) > 0 and o[-1].terminal: 213 | left_childs.append(o[-1].embedding) 214 | else: 215 | left_childs.append(None) 216 | 217 | # all_leafs = torch.stack(all_leafs, dim=1) # B x S x 2 218 | all_node_outputs = torch.stack(all_node_outputs, dim=1) # B x S x N 219 | 220 | target = target.transpose(0, 1).contiguous() 221 | if USE_CUDA: 222 | # all_leafs = all_leafs.cuda() 223 | all_node_outputs = all_node_outputs.cuda() 224 | target = target.cuda() 225 | 226 | # op_target = target < num_start 227 | # loss_0 = masked_cross_entropy_without_logit(all_leafs, op_target.long(), target_length) 228 | loss = masked_cross_entropy(all_node_outputs, target, target_length) 229 | # loss = loss_0 + loss_1 230 | return loss # , loss_0.item(), loss_1.item() 231 | 232 | 233 | def train_attn_double(encoder_outputs, decoder_hidden, target, target_length, 234 | output_lang, batch_size, seq_mask, 235 | num_start, nums_stack_batch, unk, 236 | decoder, beam_size, use_teacher_forcing): 237 | # Prepare input and output variables 238 | decoder_input = torch.LongTensor([output_lang.word2index["SOS"]] * batch_size) 239 | 240 | max_target_length = max(target_length) 241 | all_decoder_outputs = torch.zeros(max_target_length, batch_size, decoder.output_size) 242 | 243 | # Move new Variables to CUDA 244 | if USE_CUDA: 245 | all_decoder_outputs = all_decoder_outputs.cuda() 246 | 247 | if random.random() < use_teacher_forcing: 248 | # Run through decoder one time step at a time 249 | for t in range(max_target_length): 250 | if USE_CUDA: 251 | decoder_input = decoder_input.cuda() 252 | 253 | decoder_output, decoder_hidden = decoder( 254 | decoder_input, decoder_hidden, encoder_outputs, seq_mask) 255 | all_decoder_outputs[t] = decoder_output 256 | decoder_input = generate_decoder_input( 257 | target[t], decoder_output, nums_stack_batch, num_start, unk) 258 | target[t] = decoder_input 259 | else: 260 | beam_list = list() 261 | score = torch.zeros(batch_size) 262 | if USE_CUDA: 263 | score = score.cuda() 264 | beam_list.append(Beam(score, decoder_input, decoder_hidden, all_decoder_outputs)) 265 | # Run through decoder one time step at a time 266 | for t in range(max_target_length): 267 | beam_len = len(beam_list) 268 | beam_scores = torch.zeros(batch_size, decoder.output_size * beam_len) 269 | all_hidden = torch.zeros(decoder_hidden.size(0), batch_size * beam_len, decoder_hidden.size(2)) 270 | all_outputs = torch.zeros(max_target_length, batch_size * beam_len, decoder.output_size) 271 | if USE_CUDA: 272 | beam_scores = beam_scores.cuda() 273 | all_hidden = all_hidden.cuda() 274 | all_outputs = all_outputs.cuda() 275 | 276 | for b_idx in range(len(beam_list)): 277 | decoder_input = beam_list[b_idx].input_var 278 | decoder_hidden = beam_list[b_idx].hidden 279 | 280 | # rule_mask = generate_rule_mask(decoder_input, num_batch, output_lang.word2index, batch_size, 281 | # num_start, copy_nums, generate_nums, english) 282 | if USE_CUDA: 283 | # rule_mask = rule_mask.cuda() 284 | decoder_input = decoder_input.cuda() 285 | 286 | decoder_output, decoder_hidden = decoder( 
287 |                     decoder_input, decoder_hidden, encoder_outputs, seq_mask)
288 | 
289 |                 # score = f.log_softmax(decoder_output, dim=1) + rule_mask
290 |                 score = f.log_softmax(decoder_output, dim=1)
291 |                 beam_score = beam_list[b_idx].score
292 |                 beam_score = beam_score.unsqueeze(1)
293 |                 repeat_dims = [1] * beam_score.dim()
294 |                 repeat_dims[1] = score.size(1)
295 |                 beam_score = beam_score.repeat(*repeat_dims)
296 |                 score += beam_score
297 |                 beam_scores[:, b_idx * decoder.output_size: (b_idx + 1) * decoder.output_size] = score
298 |                 all_hidden[:, b_idx * batch_size:(b_idx + 1) * batch_size, :] = decoder_hidden
299 | 
300 |                 beam_list[b_idx].all_output[t] = decoder_output
301 |                 all_outputs[:, batch_size * b_idx: batch_size * (b_idx + 1), :] = \
302 |                     beam_list[b_idx].all_output
303 |             topv, topi = beam_scores.topk(beam_size, dim=1)
304 |             beam_list = list()
305 | 
306 |             for k in range(beam_size):
307 |                 temp_topk = topi[:, k]
308 |                 temp_input = temp_topk % decoder.output_size
309 |                 temp_input = temp_input.data
310 |                 if USE_CUDA:
311 |                     temp_input = temp_input.cpu()
312 |                 temp_beam_pos = temp_topk // decoder.output_size  # floor division: keep the index tensor integral (plain / yields floats on recent PyTorch)
313 | 
314 |                 indices = torch.LongTensor(range(batch_size))
315 |                 if USE_CUDA:
316 |                     indices = indices.cuda()
317 |                 indices += temp_beam_pos * batch_size
318 | 
319 |                 temp_hidden = all_hidden.index_select(1, indices)
320 |                 temp_output = all_outputs.index_select(1, indices)
321 | 
322 |                 beam_list.append(Beam(topv[:, k], temp_input, temp_hidden, temp_output))
323 |         all_decoder_outputs = beam_list[0].all_output
324 | 
325 |         for t in range(max_target_length):
326 |             target[t] = generate_decoder_input(
327 |                 target[t], all_decoder_outputs[t], nums_stack_batch, num_start, unk)
328 |     # Loss calculation and backpropagation
329 | 
330 |     if USE_CUDA:
331 |         target = target.cuda()
332 | 
333 |     loss = masked_cross_entropy(
334 |         all_decoder_outputs.transpose(0, 1).contiguous(),  # -> batch x seq
335 |         target.transpose(0, 1).contiguous(),  # -> batch x seq
336 |         target_length
337 |     )
338 | 
339 |     return loss
340 | 
341 | 
342 | def train_double(input1_batch, input2_batch, input_length, target1_batch, target1_length, target2_batch, target2_length,
343 |                  num_stack_batch, num_size_batch, generate_num1_ids, generate_num2_ids, copy_nums,
344 |                  encoder, numencoder, predict, generate, merge, decoder,
345 |                  encoder_optimizer, numencoder_optimizer, predict_optimizer, generate_optimizer, merge_optimizer, decoder_optimizer,
346 |                  input_lang, output1_lang, output2_lang, num_pos_batch, num_order_batch, parse_graph_batch,
347 |                  beam_size=5, use_teacher_forcing=0.83, english=False):
348 |     # sequence mask for attention
349 |     seq_mask = []
350 |     max_len = max(input_length)
351 |     for i in input_length:
352 |         seq_mask.append([0 for _ in range(i)] + [1 for _ in range(i, max_len)])
353 |     seq_mask = torch.ByteTensor(seq_mask)
354 | 
355 |     num_mask = []
356 |     max_num_size = max(num_size_batch) + len(generate_num1_ids)
357 |     for i in num_size_batch:
358 |         d = i + len(generate_num1_ids)
359 |         num_mask.append([0] * d + [1] * (max_num_size - d))
360 |     num_mask = torch.ByteTensor(num_mask)
361 | 
362 |     num_pos_pad = []
363 |     max_num_pos_size = max(num_size_batch)
364 |     for i in range(len(num_pos_batch)):
365 |         temp = num_pos_batch[i] + [-1] * (max_num_pos_size - len(num_pos_batch[i]))
366 |         num_pos_pad.append(temp)
367 |     num_pos_pad = torch.LongTensor(num_pos_pad)
368 | 
369 |     num_order_pad = []
370 |     max_num_order_size = max(num_size_batch)
371 |     for i in range(len(num_order_batch)):
372 |         temp = num_order_batch[i] + [0] * 
(max_num_order_size - len(num_order_batch[i]))
373 |         num_order_pad.append(temp)
374 |     num_order_pad = torch.LongTensor(num_order_pad)
375 | 
376 |     num_stack1_batch = copy.deepcopy(num_stack_batch)
377 |     num_stack2_batch = copy.deepcopy(num_stack_batch)
378 |     num_start2 = output2_lang.n_words - copy_nums - 2
379 |     unk1 = output1_lang.word2index["UNK"]
380 |     unk2 = output2_lang.word2index["UNK"]
381 | 
382 |     # Turn padded arrays into (batch_size x max_len) tensors, transpose into (max_len x batch_size)
383 |     input1_var = torch.LongTensor(input1_batch).transpose(0, 1)
384 |     input2_var = torch.LongTensor(input2_batch).transpose(0, 1)
385 |     target1 = torch.LongTensor(target1_batch).transpose(0, 1)
386 |     target2 = torch.LongTensor(target2_batch).transpose(0, 1)
387 |     parse_graph_pad = torch.LongTensor(parse_graph_batch)
388 | 
389 |     padding_hidden = torch.FloatTensor([0.0 for _ in range(predict.hidden_size)]).unsqueeze(0)
390 |     batch_size = len(input_length)
391 | 
392 |     encoder.train()
393 |     numencoder.train()
394 |     predict.train()
395 |     generate.train()
396 |     merge.train()
397 |     decoder.train()
398 | 
399 |     if USE_CUDA:
400 |         input1_var = input1_var.cuda()
401 |         input2_var = input2_var.cuda()
402 |         seq_mask = seq_mask.cuda()
403 |         padding_hidden = padding_hidden.cuda()
404 |         num_mask = num_mask.cuda()
405 |         num_pos_pad = num_pos_pad.cuda()
406 |         num_order_pad = num_order_pad.cuda()
407 |         parse_graph_pad = parse_graph_pad.cuda()
408 | 
409 |     # Zero gradients of all six optimizers
410 |     encoder_optimizer.zero_grad()
411 |     numencoder_optimizer.zero_grad()
412 |     predict_optimizer.zero_grad()
413 |     generate_optimizer.zero_grad()
414 |     merge_optimizer.zero_grad()
415 |     decoder_optimizer.zero_grad()
416 |     # Run words through encoder
417 | 
418 |     encoder_outputs, encoder_hidden = encoder(input1_var, input2_var, input_length, parse_graph_pad)
419 |     copy_num_len = [len(_) for _ in num_pos_batch]
420 |     num_size = max(copy_num_len)
421 |     num_encoder_outputs, masked_index = get_all_number_encoder_outputs(encoder_outputs, num_pos_batch,
422 |                                                                        batch_size, num_size, encoder.hidden_size)
423 |     encoder_outputs, num_outputs, problem_output = numencoder(encoder_outputs, num_encoder_outputs,
424 |                                                               num_pos_pad, num_order_pad)
425 |     num_outputs = num_outputs.masked_fill_(masked_index.bool(), 0.0)
426 | 
427 |     decoder_hidden = encoder_hidden[:decoder.n_layers]  # Use last (forward) hidden state from encoder
428 | 
429 |     loss_0 = train_tree_double(encoder_outputs, problem_output, num_outputs, target1, target1_length,
430 |                                output1_lang, batch_size, padding_hidden, seq_mask,
431 |                                num_mask, num_pos_batch, num_order_pad, num_stack1_batch, unk1,
432 |                                encoder, numencoder, predict, generate, merge)
433 | 
434 |     loss_1 = train_attn_double(encoder_outputs, decoder_hidden, target2, target2_length,
435 |                                output2_lang, batch_size, seq_mask,
436 |                                num_start2, num_stack2_batch, unk2,
437 |                                decoder, beam_size, use_teacher_forcing)
438 | 
439 |     loss = loss_0 + loss_1  # joint objective: tree-decoder loss plus attention-decoder loss
440 |     loss.backward()
441 | 
442 |     encoder_optimizer.step()
443 |     numencoder_optimizer.step()
444 |     predict_optimizer.step()
445 |     generate_optimizer.step()
446 |     merge_optimizer.step()
447 |     decoder_optimizer.step()
448 |     return loss.item()  # , loss_0.item(), loss_1.item()
449 | 
450 | 
451 | def evaluate_tree_double(encoder_outputs, problem_output, all_nums_encoder_outputs,
452 |                          output_lang, batch_size, padding_hidden, seq_mask, num_mask,
453 |                          max_length, num_pos, num_order_pad,
454 |                          encoder, numencoder, predict, generate, merge, beam_size):
455 |     # Prepare input and output 
variables 456 | node_stacks = [[TreeNode(_)] for _ in problem_output.split(1, dim=0)] 457 | 458 | num_start = output_lang.num_start 459 | # B x P x N 460 | embeddings_stacks = [[] for _ in range(batch_size)] 461 | left_childs = [None for _ in range(batch_size)] 462 | 463 | beams = [TreeBeam(0.0, node_stacks, embeddings_stacks, left_childs, [])] 464 | 465 | for t in range(max_length): 466 | current_beams = [] 467 | while len(beams) > 0: 468 | b = beams.pop() 469 | if len(b.node_stack[0]) == 0: 470 | current_beams.append(b) 471 | continue 472 | # left_childs = torch.stack(b.left_childs) 473 | left_childs = b.left_childs 474 | 475 | num_score, op, current_embeddings, current_context, current_nums_embeddings = predict( 476 | b.node_stack, left_childs, encoder_outputs, all_nums_encoder_outputs, padding_hidden, 477 | seq_mask, num_mask) 478 | 479 | out_score = nn.functional.log_softmax(torch.cat((op, num_score), dim=1), dim=1) 480 | 481 | topv, topi = out_score.topk(beam_size) 482 | 483 | for tv, ti in zip(topv.split(1, dim=1), topi.split(1, dim=1)): 484 | current_node_stack = copy_list(b.node_stack) 485 | current_left_childs = [] 486 | current_embeddings_stacks = copy_list(b.embedding_stack) 487 | current_out = copy.deepcopy(b.out) 488 | 489 | out_token = int(ti) 490 | current_out.append(out_token) 491 | 492 | node = current_node_stack[0].pop() 493 | 494 | if out_token < num_start: 495 | generate_input = torch.LongTensor([out_token]) 496 | if USE_CUDA: 497 | generate_input = generate_input.cuda() 498 | left_child, right_child, node_label = generate(current_embeddings, generate_input, current_context) 499 | 500 | current_node_stack[0].append(TreeNode(right_child)) 501 | current_node_stack[0].append(TreeNode(left_child, left_flag=True)) 502 | 503 | current_embeddings_stacks[0].append(TreeEmbedding(node_label[0].unsqueeze(0), False)) 504 | else: 505 | current_num = current_nums_embeddings[0, out_token - num_start].unsqueeze(0) 506 | 507 | while len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal: 508 | sub_stree = current_embeddings_stacks[0].pop() 509 | op = current_embeddings_stacks[0].pop() 510 | current_num = merge(op.embedding, sub_stree.embedding, current_num) 511 | current_embeddings_stacks[0].append(TreeEmbedding(current_num, True)) 512 | if len(current_embeddings_stacks[0]) > 0 and current_embeddings_stacks[0][-1].terminal: 513 | current_left_childs.append(current_embeddings_stacks[0][-1].embedding) 514 | else: 515 | current_left_childs.append(None) 516 | current_beams.append(TreeBeam(b.score+float(tv), current_node_stack, current_embeddings_stacks, 517 | current_left_childs, current_out)) 518 | beams = sorted(current_beams, key=lambda x: x.score, reverse=True) 519 | beams = beams[:beam_size] 520 | flag = True 521 | for b in beams: 522 | if len(b.node_stack[0]) != 0: 523 | flag = False 524 | if flag: 525 | break 526 | 527 | return beams[0] 528 | 529 | 530 | def evaluate_attn_double(encoder_outputs, decoder_hidden, 531 | output_lang, batch_size, seq_mask, max_length, 532 | decoder, beam_size): 533 | # Create starting vectors for decoder 534 | decoder_input = torch.LongTensor([output_lang.word2index["SOS"]]) # SOS 535 | beam_list = list() 536 | score = 0 537 | beam_list.append(Beam(score, decoder_input, decoder_hidden, [])) 538 | 539 | # Run through decoder 540 | for di in range(max_length): 541 | temp_list = list() 542 | beam_len = len(beam_list) 543 | for xb in beam_list: 544 | if int(xb.input_var[0]) == output_lang.word2index["EOS"]: 545 | 
temp_list.append(xb) 546 | beam_len -= 1 547 | if beam_len == 0: 548 | return beam_list[0] 549 | beam_scores = torch.zeros(decoder.output_size * beam_len) 550 | hidden_size_0 = decoder_hidden.size(0) 551 | hidden_size_2 = decoder_hidden.size(2) 552 | all_hidden = torch.zeros(beam_len, hidden_size_0, 1, hidden_size_2) 553 | if USE_CUDA: 554 | beam_scores = beam_scores.cuda() 555 | all_hidden = all_hidden.cuda() 556 | all_outputs = [] 557 | current_idx = -1 558 | 559 | for b_idx in range(len(beam_list)): 560 | decoder_input = beam_list[b_idx].input_var 561 | if int(decoder_input[0]) == output_lang.word2index["EOS"]: 562 | continue 563 | current_idx += 1 564 | decoder_hidden = beam_list[b_idx].hidden 565 | 566 | # rule_mask = generate_rule_mask(decoder_input, [num_list], output_lang.word2index, 567 | # 1, num_start, copy_nums, generate_nums, english) 568 | if USE_CUDA: 569 | # rule_mask = rule_mask.cuda() 570 | decoder_input = decoder_input.cuda() 571 | 572 | decoder_output, decoder_hidden = decoder( 573 | decoder_input, decoder_hidden, encoder_outputs, seq_mask) 574 | # score = f.log_softmax(decoder_output, dim=1) + rule_mask.squeeze() 575 | score = f.log_softmax(decoder_output, dim=1) 576 | score += beam_list[b_idx].score 577 | beam_scores[current_idx * decoder.output_size: (current_idx + 1) * decoder.output_size] = score 578 | all_hidden[current_idx] = decoder_hidden 579 | all_outputs.append(beam_list[b_idx].all_output) 580 | topv, topi = beam_scores.topk(beam_size) 581 | 582 | for k in range(beam_size): 583 | word_n = int(topi[k]) 584 | word_input = word_n % decoder.output_size 585 | temp_input = torch.LongTensor([word_input]) 586 | indices = int(word_n / decoder.output_size) 587 | 588 | temp_hidden = all_hidden[indices] 589 | temp_output = all_outputs[indices]+[word_input] 590 | temp_list.append(Beam(float(topv[k]), temp_input, temp_hidden, temp_output)) 591 | 592 | temp_list = sorted(temp_list, key=lambda x: x.score, reverse=True) 593 | 594 | if len(temp_list) < beam_size: 595 | beam_list = temp_list 596 | else: 597 | beam_list = temp_list[:beam_size] 598 | return beam_list[0] 599 | 600 | 601 | def evaluate_double(input1_batch, input2_batch, input_length, generate_num1_ids, generate_num2_ids, 602 | encoder, numencoder, predict, generate, merge, decoder, 603 | input_lang, output1_lang, output2_lang, num_pos_batch, num_order_batch, parse_graph_batch, 604 | beam_size=5, english=False, max_length=MAX_OUTPUT_LENGTH): 605 | 606 | seq_mask = torch.ByteTensor(1, input_length).fill_(0) 607 | num_pos_pad = torch.LongTensor([num_pos_batch]) 608 | num_order_pad = torch.LongTensor([num_order_batch]) 609 | parse_graph_pad = torch.LongTensor(parse_graph_batch) 610 | # Turn padded arrays into (batch_size x max_len) tensors, transpose into (max_len x batch_size) 611 | input1_var = torch.LongTensor(input1_batch).unsqueeze(1) 612 | input2_var = torch.LongTensor(input2_batch).unsqueeze(1) 613 | 614 | num_mask = torch.ByteTensor(1, len(num_pos_batch) + len(generate_num1_ids)).fill_(0) 615 | 616 | # Set to not-training mode to disable dropout 617 | encoder.eval() 618 | numencoder.eval() 619 | predict.eval() 620 | generate.eval() 621 | merge.eval() 622 | decoder.eval() 623 | 624 | padding_hidden = torch.FloatTensor([0.0 for _ in range(predict.hidden_size)]).unsqueeze(0) 625 | 626 | batch_size = 1 627 | 628 | if USE_CUDA: 629 | input1_var = input1_var.cuda() 630 | input2_var = input2_var.cuda() 631 | seq_mask = seq_mask.cuda() 632 | padding_hidden = padding_hidden.cuda() 633 | num_mask = num_mask.cuda() 634 | 
num_pos_pad = num_pos_pad.cuda() 635 | num_order_pad = num_order_pad.cuda() 636 | parse_graph_pad = parse_graph_pad.cuda() 637 | # Run words through encoder 638 | 639 | encoder_outputs, encoder_hidden = encoder(input1_var, input2_var, [input_length], parse_graph_pad) 640 | num_size = len(num_pos_batch) 641 | num_encoder_outputs, masked_index = get_all_number_encoder_outputs(encoder_outputs, [num_pos_batch], batch_size, 642 | num_size, encoder.hidden_size) 643 | encoder_outputs, num_outputs, problem_output = numencoder(encoder_outputs, num_encoder_outputs, 644 | num_pos_pad, num_order_pad) 645 | decoder_hidden = encoder_hidden[:decoder.n_layers] # Use last (forward) hidden state from encoder 646 | 647 | tree_beam = evaluate_tree_double(encoder_outputs, problem_output, num_outputs, 648 | output1_lang, batch_size, padding_hidden, seq_mask, num_mask, 649 | max_length, num_pos_batch, num_order_pad, 650 | encoder, numencoder, predict, generate, merge, beam_size) 651 | 652 | attn_beam = evaluate_attn_double(encoder_outputs, decoder_hidden, 653 | output2_lang, batch_size, seq_mask, max_length, 654 | decoder, beam_size) 655 | 656 | if tree_beam.score >= attn_beam.score: 657 | return "tree", tree_beam.out, tree_beam.score 658 | else: 659 | return "attn", attn_beam.all_output, attn_beam.score --------------------------------------------------------------------------------
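Note on the final selection above: evaluate_double runs both decoders over the same encoded problem and simply returns whichever hypothesis has the higher beam score. By the naming used here, a "tree" result is presumably a prefix expression to be checked with compute_prefix_tree_result, while an "attn" result is a postfix sequence for compute_postfix_tree_result.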