├── README.md └── src ├── calc_5_fold.py ├── calc_score.py ├── calculator.py ├── callbacks.py ├── define.py ├── f1.py ├── make_train_valid.py ├── modules ├── __init__.py ├── common.py └── neural_solver_machine_v1.py ├── permute_stack.py ├── predict.py ├── predict_5_fold.py ├── pytorch_base.py ├── stack_machine.py ├── text_num_utils.py ├── torch_solver.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # End-to-End Math Solver 2 | 3 | The implementation of the NAACL 2019 paper [Semantically-Aligned Equation Generation for Solving and Reasoning Math Word Problems](https://arxiv.org/abs/1811.00720). 4 | 5 | ## Usage 6 | 7 | 0. Download required files: 8 | - Math23K dataset. If it is not in valid JSON format, you need to fix it first. 9 | - [Chinese FastText word vectors](https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zh.300.vec.gz) 10 | 1. Preprocess the dataset: 11 | 12 | ``` 13 | cd src/ 14 | mkdir ../data/ 15 | python make_train_valid.py ~/storage/MWP/data/math23k_fix.json ~/storage/cc.zh.300.vec ../data/train_valid.pkl --valid_ratio 0 --char_based 16 | ``` 17 | Note that warnings here are expected, because some of the problems use operators other than `+, -, *, /`. The purposes of the flags are as follows: 18 | - `--valid_ratio 0`: Set the ratio of data held out as the validation set. It should be set to 0 when running 5-fold cross-validation. 19 | - `--char_based`: Tokenize the problem text into characters rather than words. 20 | 2. Start 5-fold cross-validation: 21 | ``` 22 | mkdir ../models 23 | python train.py ../data/train_valid.pkl ../models/model.pkl --five_fold 24 | ``` 25 | -------------------------------------------------------------------------------- /src/calc_5_fold.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pdb 4 | import pickle 5 | import sys 6 | import traceback 7 | import json 8 | import torch 9 | from utils import MWPDataset 10 | from calc_score import tofloat 11 | 12 | 13 | def main(args): 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | # preprocessor = Preprocessor(args.embedding_path) 17 | # train, valid = preprocessor.get_train_valid_dataset(args.data_path) 18 | 19 | with open(args.pickle_path, 'rb') as f: 20 | data = pickle.load(f) 21 | preprocessor = data['preprocessor'] 22 | problems = data['train']._problems 23 | 24 | if args.arch == 'NSMv1': 25 | from torch_solver import TorchSolver 26 | solver = TorchSolver( 27 | preprocessor.get_word_dim(), 28 | args.dim_hidden, 29 | batch_size=args.batch_size, 30 | n_epochs=10000, 31 | device=args.device) 32 | 33 | if args.arch == 'NSMv2': 34 | from torch_solver_v2 import TorchSolverV2 35 | solver = TorchSolverV2( 36 | preprocessor.get_word_dim(), 37 | args.dim_hidden, 38 | batch_size=args.batch_size, 39 | n_epochs=10000, 40 | device=args.device) 41 | 42 | elif args.arch == 'seq2seq': 43 | from torch_seq2seq import TorchSeq2Seq 44 | solver = TorchSeq2Seq( 45 | preprocessor.get_vocabulary_size(), 46 | preprocessor.get_word_dim(), 47 | args.dim_hidden, 48 | embedding=preprocessor.get_embedding(), 49 | batch_size=args.batch_size, 50 | n_epochs=10000, 51 | device=args.device) 52 | 53 | accuracys = [] 54 | for fold in range(5): 55 | # make valid dataset 56 | fold_indices = [int(len(problems) * 0.2) * i for i in range(6)] 57 | start = fold_indices[fold] 58 | end = fold_indices[fold + 1] 59 | valid = MWPDataset(problems[start:end], 60 |
preprocessor.indices_to_embeddings) 61 | 62 | # make prediction 63 | predict_filename = '{}.fold{}.{}'.format(args.output, 64 | fold, 65 | args.epoch) 66 | with open(predict_filename) as f: 67 | predicts = json.load(f) 68 | ys_ = [predict['ans'] for predict in predicts] 69 | ys_ = torch.tensor(ys_) 70 | 71 | # make answer list 72 | ys = torch.tensor([tofloat(p['ans']) for p in valid]) 73 | accuracy = (ys == ys_).float().mean().item() 74 | print('Accuracy = {}'.format(accuracy)) 75 | accuracys.append(accuracy) 76 | 77 | print('mean = {}'.format(sum(accuracys) / 5)) 78 | 79 | 80 | def _parse_args(): 81 | parser = argparse.ArgumentParser( 82 | description="Script to train the MWP solver.") 83 | # parser.add_argument('data_path', type=str, 84 | # help='Path to the data.') 85 | # parser.add_argument('embedding_path', type=str, 86 | # help='Path to the embedding.') 87 | parser.add_argument('pickle_path', type=str, 88 | help='Path to the train valid pickle.') 89 | parser.add_argument('output', type=str, 90 | help='Dest to dump prediction.') 91 | parser.add_argument('--dim_hidden', type=int, default=256, 92 | help='Hidden state dimension of the encoder.') 93 | parser.add_argument('--batch_size', type=int, default=32, 94 | help='Batch size.') 95 | parser.add_argument('--device', default=None, 96 | help='Device used to train.') 97 | parser.add_argument('--to_test', type=str, 98 | default='valid', help='To dump train or valid.') 99 | parser.add_argument('--arch', type=str, 100 | default='NSMv1', help='To dump train or valid.') 101 | parser.add_argument('--epoch', type=int, 102 | default=14, help='Index of the epoch to use.') 103 | args = parser.parse_args() 104 | return args 105 | 106 | 107 | class DumpHook: 108 | def __init__(self): 109 | self.outputs = [] 110 | self.batch_outputs = [] 111 | 112 | def forward_hook(self, module, inputs, outputs): 113 | self.batch_outputs.append(outputs) 114 | 115 | def flush_batch(self): 116 | self.outputs.append(self.batch_outputs) 117 | self.batch_outputs = [] 118 | 119 | 120 | if __name__ == '__main__': 121 | args = _parse_args() 122 | try: 123 | main(args) 124 | except KeyboardInterrupt: 125 | pass 126 | except BaseException: 127 | type, value, tb = sys.exc_info() 128 | traceback.print_exc() 129 | pdb.post_mortem(tb) 130 | -------------------------------------------------------------------------------- /src/calc_score.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pdb 3 | import sys 4 | import traceback 5 | import json 6 | import pickle 7 | import re 8 | import numpy as np 9 | 10 | 11 | def tofloat(text): 12 | if '%' in text: 13 | return float(text[:-1]) / 100 14 | 15 | if '/' in text: 16 | match = re.search(r'(\d*)\(\((\d+)\)/\((\d+)\)\)', text) 17 | a = 0 if match.group(1) == '' else int(match.group(1)) 18 | return a + int(match.group(2)) / int(match.group(3)) 19 | 20 | return float(text) 21 | 22 | 23 | def main(args): 24 | with open(args.pickle, 'rb') as f: 25 | data = pickle.load(f)[args.to_test] 26 | with open(args.predict) as f: 27 | raw = json.load(f) 28 | 29 | predict = np.array([p['ans'] if p['ans'] is not None else 0 30 | for p in raw]) 31 | confidence = np.array([p['confidence'] for p in raw]) 32 | answer = np.array([tofloat(d['ans']) for d in data]) 33 | 34 | if args.retrieval is not None: 35 | with open(args.retrieval) as f: 36 | raw = json.load(f) 37 | retrieval = np.array([p['ans'] if p['ans'] is not None else 0 38 | for p in raw]) 39 | predict[confidence < args.threshold] = \ 40 
| retrieval[confidence < args.threshold] 41 | 42 | correct = np.abs(predict - answer) < args.epsilon 43 | accuracy = np.mean(correct) 44 | print('Accuracy = {}'.format(accuracy)) 45 | 46 | correct_confidence = confidence[np.where(correct)] 47 | incorrect_confidence = confidence[np.where(~correct)] 48 | print('Correct Confidence mean={}, std={}' 49 | .format(np.mean(correct_confidence), 50 | np.std(correct_confidence)) 51 | ) 52 | print('Incorrect Confidence mean={}, std={}' 53 | .format(np.mean(incorrect_confidence), 54 | np.std(incorrect_confidence)) 55 | ) 56 | 57 | if args.log is not None: 58 | with open(args.log, 'w') as f: 59 | log = { 60 | 'accuracy': accuracy, 61 | 'correct': list(predict - answer) 62 | } 63 | json.dump(log, f, indent=' ') 64 | 65 | 66 | def _parse_args(): 67 | parser = argparse.ArgumentParser( 68 | description="Calculate accuracy") 69 | parser.add_argument('pickle', type=str, 70 | help='') 71 | parser.add_argument('predict', type=str, 72 | help='') 73 | parser.add_argument('--to_test', type=str, 74 | default='valid', help='To dump train or valid.') 75 | parser.add_argument('--epsilon', type=float, 76 | default=1e-4, 77 | help='Error that is tolerant as correct.') 78 | parser.add_argument('--log', type=str, default=None, 79 | help='Destination of log file.') 80 | parser.add_argument('--retrieval', type=str, default=None, 81 | help='') 82 | parser.add_argument('--threshold', type=float, default=-0.6) 83 | args = parser.parse_args() 84 | return args 85 | 86 | 87 | if __name__ == '__main__': 88 | args = _parse_args() 89 | try: 90 | main(args) 91 | except KeyboardInterrupt: 92 | pass 93 | except BaseException: 94 | type, value, tb = sys.exc_info() 95 | traceback.print_exc() 96 | pdb.post_mortem(tb) 97 | -------------------------------------------------------------------------------- /src/calculator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pdb 3 | import re 4 | import sys 5 | import traceback 6 | from enum import Enum 7 | 8 | 9 | num_map = { 10 | '零': 0, '一': 1, '二': 2, '三': 3, '四': 4, 11 | '五': 5, '六': 6, '七': 7, '八': 8, '九': 9, 12 | '兩': 2 13 | } 14 | order_map1 = { 15 | '十': 10, '百': 100, '千': 1000 16 | } 17 | order_map2 = { 18 | '萬': 10000, '億': 100000000, '兆': 1000000000000 19 | } 20 | symbol_pattern = re.compile(r'^[ \d\+\-\*/\.\(\)]*$') 21 | word_pattern = re.compile( 22 | r'^( 等於|多少|次方|加上|減掉|減去|乘以|乘上|除以|然後|會是|是多少|' 23 | r'[後再點的一二三四五六七八九十百千萬億兆加減乘除' 24 | r'\d\+\-\*/\.^\(\)=])*[\??]?$') 25 | num_pattern = re.compile('[零一二三四五六七八九十百千萬億兆點兩]+') 26 | 27 | 28 | class ProblemType(Enum): 29 | SYMBOL = 'pure symbol' 30 | WORD = 'math word problem' 31 | SPOKEN = 'spoken equation' 32 | 33 | 34 | def chinese_to_number(inputs): 35 | if '點' in inputs: 36 | round_part, float_part = inputs.split('點') 37 | else: 38 | round_part, float_part = inputs, '' 39 | 40 | num_ord1 = 0 41 | num_ord2 = 0 42 | number = 0 43 | 44 | # parse the round part 45 | if round_part[0] == '十': 46 | num_ord2 = 10 47 | else: 48 | num_ord1 = num_map[round_part[0]] 49 | 50 | for i, char in enumerate(round_part[1:]): 51 | if char in num_map: 52 | # deal with cases "一萬五、一百五" 53 | if i == len(round_part[1:]) - 1: # last character 54 | if round_part[i] in order_map1: 55 | num_ord1 = num_map[char] * order_map1[round_part[i]] // 10 56 | elif round_part[i] in order_map2: 57 | num_ord1 = num_map[char] * order_map2[round_part[i]] // 10 58 | else: 59 | num_ord1 = num_ord1 * 10 + num_map[char] 60 | else: 61 | num_ord1 = num_ord1 * 10 + 
num_map[char] 62 | elif char in order_map1: 63 | num_ord2 += num_ord1 * order_map1[char] 64 | num_ord1 = 0 65 | elif char in order_map2: 66 | number += (num_ord2 + num_ord1) * order_map2[char] 67 | num_ord1 = 0 68 | num_ord2 = 0 69 | number += num_ord2 + num_ord1 70 | 71 | # parse the float part 72 | base = 0.1 73 | for char in float_part: 74 | number += num_map[char] * base 75 | base *= 0.1 76 | 77 | return number 78 | 79 | 80 | def spoken_to_symbol(spoken): 81 | spoken = re.sub(r'的([點零一二三四五六七八九十百千萬億兆]*)次方', r'**\1', spoken) 82 | spoken = re.sub(r'[上去以掉再然後會是多少等於=?]', '', spoken) 83 | operator_map = { 84 | '加': '+', 85 | '減': '-', 86 | '乘': '*', 87 | '除': '/', 88 | } 89 | for k, v in operator_map.items(): 90 | spoken = spoken.replace(k, v) 91 | 92 | operands = re.split(r'\*\*|[\+\-\*/]', spoken) 93 | operators = re.findall(r'\*\*|[\+\-\*/]', spoken) 94 | equation = str(chinese_to_number(operands[0])) 95 | for operand, operator in zip(operands[1:], operators): 96 | equation += operator + str(chinese_to_number(operand)) 97 | 98 | return equation 99 | 100 | 101 | def classify_input(inputs): 102 | """ Classify if the input is math word problem or spoken equation or pure 103 | symbol. 104 | """ 105 | if re.match(symbol_pattern, inputs): 106 | return ProblemType.SYMBOL 107 | 108 | if re.match(word_pattern, inputs): 109 | return ProblemType.SPOKEN 110 | 111 | return ProblemType.WORD 112 | 113 | 114 | def main(args): 115 | inputs_type = classify_input(args.inputs) 116 | print('input type = {}, equation = {}, answer = {}' 117 | .format(inputs_type.name, '', '')) 118 | 119 | 120 | def _parse_args(): 121 | parser = argparse.ArgumentParser(description="input") 122 | parser.add_argument('inputs', type=str, 123 | help='') 124 | args = parser.parse_args() 125 | return args 126 | 127 | 128 | if __name__ == '__main__': 129 | args = _parse_args() 130 | try: 131 | main(args) 132 | except KeyboardInterrupt: 133 | pass 134 | except BaseException: 135 | type, value, tb = sys.exc_info() 136 | traceback.print_exc() 137 | pdb.post_mortem(tb) 138 | -------------------------------------------------------------------------------- /src/callbacks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import pdb 4 | import json 5 | 6 | 7 | class Callback: 8 | def __init__(): 9 | pass 10 | 11 | def on_epoch_end(log_train, log_valid, model): 12 | pass 13 | 14 | 15 | class MetricsLogger(Callback): 16 | def __init__(self, log_dest): 17 | self.history = { 18 | 'train': [], 19 | 'valid': [] 20 | } 21 | self.log_dest = log_dest 22 | 23 | def on_epoch_end(self, log_train, log_valid, model): 24 | self.history['train'].append(log_train) 25 | self.history['valid'].append(log_valid) 26 | with open(self.log_dest, 'w') as f: 27 | json.dump(self.history, f, indent=' ') 28 | 29 | 30 | class ModelCheckpoint(Callback): 31 | def __init__(self, filepath, 32 | monitor='loss', 33 | verbose=0, 34 | mode='min'): 35 | self._filepath = filepath 36 | self._verbose = verbose 37 | self._monitor = monitor 38 | self._best = math.inf if mode == 'min' else - math.inf 39 | self._mode = mode 40 | 41 | def on_epoch_end(self, log_train, log_valid, model): 42 | if self._mode == 'min': 43 | score = log_valid[self._monitor] 44 | if score < self._best: 45 | self._best = score 46 | model.save(self._filepath) 47 | if self._verbose > 0: 48 | print('Best model saved (%f)' % score) 49 | 50 | elif self._mode == 'max': 51 | score = log_valid[self._monitor] 52 | if score > self._best: 53 | 
self._best = score 54 | model.save(self._filepath) 55 | if self._verbose > 0: 56 | print('Best model saved (%f)' % score) 57 | 58 | elif self._mode == 'all': 59 | model.save('{}.{}' 60 | .format(self._filepath, model._epoch)) 61 | -------------------------------------------------------------------------------- /src/define.py: -------------------------------------------------------------------------------- 1 | """ Constants used in the program. 2 | """ 3 | 4 | 5 | class OPERATIONS: 6 | NOOP = 0 7 | GEN_VAR = 1 8 | ADD = 2 9 | SUB = 3 10 | MUL = 4 11 | DIV = 5 12 | EQL = 6 13 | N_OPS = 7 14 | -------------------------------------------------------------------------------- /src/f1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pdb 3 | import sys 4 | import traceback 5 | import pickle 6 | import json 7 | import pickle 8 | import os 9 | from define import OPERATIONS 10 | from calc_5_fold import tofloat 11 | 12 | 13 | def main(args): 14 | with open(args.data_path, 'rb') as f: 15 | train = pickle.load(f)['train']._problems 16 | 17 | ops = ['noop', 'gv', '+', '-', '*', '/', '='] 18 | for sample in train: 19 | mapping = ops + [str(c) for c in sample['constants']] + ['x0'] 20 | sample['equations'] = [mapping[s] for s in sample['operations']] 21 | sample['equations'] = list(filter(lambda x: x not in ['noop', 'gv'], 22 | sample['equations'])) 23 | 24 | predicts = [] 25 | for i in range(5): 26 | predict_path = os.path.join(args.model_path, 27 | 'predict.json.fold{}.{}' 28 | .format(i, args.epoch)) 29 | with open(predict_path) as f: 30 | predicts += json.load(f) 31 | 32 | correct_indices, incorrect_indices = [], [] 33 | for i, (p, a) in enumerate(zip(predicts, train)): 34 | if p['ans'] == tofloat(a['ans']): 35 | correct_indices.append(i) 36 | else: 37 | incorrect_indices.append(i) 38 | 39 | ps, rs, f1s = [], [], [] 40 | for predict, answer in zip(predicts, train): 41 | eqp = eval(predict['equations']) 42 | eqa = answer['equations'] 43 | # eqp = collect_subtrees(eqp) 44 | # eqa = collect_subtrees(eqa) 45 | # eqp = list(filter(lambda x: x not in ['+', '-', '*', '/', '='], eqp)) 46 | # eqa = list(filter(lambda x: x not in ['+', '-', '*', '/', '='], eqa)) 47 | eqp, eqa = set(eqp), set(eqa) 48 | precision = len(eqp & eqa) / len(eqp) 49 | recall = len(eqp & eqa) / len(eqa) 50 | f1 = precision * recall / (precision + recall) * 2 51 | ps.append(precision) 52 | rs.append(recall) 53 | f1s.append(f1) 54 | 55 | print('accuracy = {}'.format(len(correct_indices) / len(predicts))) 56 | print('All, {}, {}, {}' 57 | .format(sum(ps) / len(ps), 58 | sum(rs) / len(rs), 59 | sum(f1s) / len(f1s) 60 | ) 61 | ) 62 | 63 | cps, crs, cf1s = ([ps[i] for i in correct_indices], 64 | [rs[i] for i in correct_indices], 65 | [f1s[i] for i in correct_indices]) 66 | print('Correct, {}, {}, {}' 67 | .format(sum(cps) / len(cps), 68 | sum(crs) / len(crs), 69 | sum(cf1s) / len(cf1s) 70 | ) 71 | ) 72 | ips, irs, if1s = ([ps[i] for i in incorrect_indices], 73 | [rs[i] for i in incorrect_indices], 74 | [f1s[i] for i in incorrect_indices]) 75 | print('Incorrect, {}, {}, {}' 76 | .format(sum(ips) / len(ips), 77 | sum(irs) / len(irs), 78 | sum(if1s) / len(if1s) 79 | ) 80 | ) 81 | 82 | 83 | def collect_subtrees(ops): 84 | subtrees = [] 85 | stack = [] 86 | for op in ops: 87 | if op in ['+', '-', '*', '/', '=']: 88 | opd1, opd2 = stack.pop(), stack.pop() 89 | 90 | if op in ['+', '*'] and opd2 > opd1: 91 | expr = opd1 + op + opd2 92 | 93 | expr = '({} {} {})'.format(opd2, op, opd1) 
94 | 95 | stack.append(expr) 96 | subtrees.append(expr) 97 | else: 98 | stack.append(op) 99 | subtrees.append(op) 100 | 101 | return subtrees 102 | 103 | 104 | def _parse_args(): 105 | parser = argparse.ArgumentParser(description="") 106 | parser.add_argument('data_path', type=str, 107 | help='') 108 | parser.add_argument('model_path', type=str, 109 | help='') 110 | parser.add_argument('epoch', type=int) 111 | args = parser.parse_args() 112 | return args 113 | 114 | 115 | if __name__ == '__main__': 116 | args = _parse_args() 117 | try: 118 | main(args) 119 | except KeyboardInterrupt: 120 | pass 121 | except BaseException: 122 | type, value, tb = sys.exc_info() 123 | traceback.print_exc() 124 | pdb.post_mortem(tb) 125 | -------------------------------------------------------------------------------- /src/make_train_valid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pdb 4 | import sys 5 | import traceback 6 | import pickle 7 | import json 8 | 9 | 10 | def main(args): 11 | 12 | if args.dataset == 'Dolphin18k': 13 | from utils import Preprocessor as Preprocessor 14 | elif args.dataset == 'Math23k': 15 | from utils import Math23kPreprocessor as Preprocessor 16 | else: 17 | logging.error('Not compitable dataset!') 18 | return 19 | 20 | if args.index is not None: 21 | with open(args.index) as f: 22 | shuffled_index = json.load(f) 23 | else: 24 | shuffled_index = None 25 | 26 | preprocessor = Preprocessor(args.embedding_path) 27 | 28 | train, valid = preprocessor.get_train_valid_dataset(args.data_path, 29 | args.valid_ratio, 30 | index=shuffled_index, 31 | char_based=args.char_based) 32 | 33 | with open(args.output, 'wb') as f: 34 | pickle.dump({'train': train, 35 | 'valid': valid, 36 | 'preprocessor': preprocessor}, f) 37 | 38 | 39 | def _parse_args(): 40 | parser = argparse.ArgumentParser( 41 | description="Preprocess and generate preprocessed pickle.") 42 | parser.add_argument('data_path', type=str, 43 | help='Path to the data.') 44 | parser.add_argument('embedding_path', type=str, 45 | help='Path to the embedding.') 46 | parser.add_argument('output', type=str, 47 | help='Path to the output pickle file.') 48 | parser.add_argument('--dataset', type=str, default='Math23k', 49 | help='[Math23k|Dolphin18k]') 50 | parser.add_argument('--valid_ratio', type=float, default=0.2, 51 | help='Ratio of data used as validation set.') 52 | parser.add_argument('--index', type=str, default=None, 53 | help='JSON file that stores shuffled index.') 54 | parser.add_argument('--char_based', default=False, action='store_true', 55 | help='If segment the text based on char.') 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | if __name__ == '__main__': 61 | logging.basicConfig(level=logging.INFO) 62 | args = _parse_args() 63 | try: 64 | main(args) 65 | except KeyboardInterrupt: 66 | pass 67 | except BaseException: 68 | type, value, tb = sys.exc_info() 69 | traceback.print_exc() 70 | pdb.post_mortem(tb) 71 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .neural_solver_machine_v1 import * 2 | -------------------------------------------------------------------------------- /src/modules/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import pdb 4 | 5 | 6 | class Encoder(torch.nn.Module): 7 | 
""" Simple RNN encoder. 8 | 9 | Args: 10 | dim_embed (int): Dimension of input embedding. 11 | dim_hidden (int): Dimension of encoder RNN. 12 | dim_last (int): Dimension of the last state will be transformed to. 13 | dropout_rate (float): Dropout rate. 14 | """ 15 | def __init__(self, dim_embed, dim_hidden, dim_last, dropout_rate): 16 | super(Encoder, self).__init__() 17 | self.rnn = torch.nn.LSTM(dim_embed, 18 | dim_hidden, 19 | 1, 20 | bidirectional=True, 21 | batch_first=True) 22 | self.mlp1 = torch.nn.Sequential( 23 | torch.nn.Linear(dim_hidden * 2, dim_last), 24 | torch.nn.Dropout(dropout_rate), 25 | torch.nn.Tanh()) 26 | self.mlp2 = torch.nn.Sequential( 27 | torch.nn.Linear(dim_hidden * 2, dim_last), 28 | torch.nn.Dropout(dropout_rate), 29 | torch.nn.Tanh()) 30 | 31 | def forward(self, inputs, lengths): 32 | """ 33 | 34 | Args: 35 | inputs (tensor): Indices of words. The shape is `B x T x 1`. 36 | length (list of int): Length of inputs. 37 | 38 | Return: 39 | outputs (tensor): Encoded sequence. The shape is 40 | `B x T x dim_hidden`. 41 | """ 42 | packed = torch.nn.utils.rnn.pack_padded_sequence( 43 | inputs, lengths, batch_first=True) 44 | hidden_state = None 45 | outputs, hidden_state = self.rnn(packed, hidden_state) 46 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, 47 | batch_first=True) 48 | 49 | # reshape (2, batch, dim_hidden) to (batch, dim_hidden) 50 | hidden_state = \ 51 | (hidden_state[0].transpose(1, 0).contiguous() 52 | .view(hidden_state[0].size(1), -1), 53 | hidden_state[1].transpose(1, 0).contiguous() 54 | .view(hidden_state[1].size(1), -1)) 55 | hidden_state = \ 56 | (self.mlp1(hidden_state[0]).unsqueeze(0), 57 | self.mlp2(hidden_state[1]).unsqueeze(0)) 58 | 59 | return outputs, hidden_state 60 | 61 | 62 | class AttnEncoder(torch.nn.Module): 63 | """ Simple RNN encoder with attention which also extract variable embedding. 64 | 65 | Args: 66 | dim_embed (int): Dimension of input embedding. 67 | dim_hidden (int): Dimension of encoder RNN. 68 | dim_last (int): Dimension of the last state will be transformed to. 69 | dropout_rate (float): Dropout rate. 70 | """ 71 | def __init__(self, dim_embed, dim_hidden, dim_last, dropout_rate, 72 | dim_attn_hidden=256): 73 | super(AttnEncoder, self).__init__() 74 | self.rnn = torch.nn.LSTM(dim_embed, 75 | dim_hidden, 76 | 1, 77 | bidirectional=True, 78 | batch_first=True) 79 | self.mlp1 = torch.nn.Sequential( 80 | torch.nn.Linear(dim_hidden * 2, dim_last), 81 | torch.nn.Dropout(dropout_rate), 82 | torch.nn.Tanh()) 83 | self.mlp2 = torch.nn.Sequential( 84 | torch.nn.Linear(dim_hidden * 2, dim_last), 85 | torch.nn.Dropout(dropout_rate), 86 | torch.nn.Tanh()) 87 | self.attn = Attention(dim_hidden * 2, dim_hidden * 2, 88 | dim_attn_hidden) 89 | self.embedding_one = torch.nn.Parameter( 90 | torch.normal(torch.zeros(2 * dim_hidden), 0.01)) 91 | self.embedding_pi = torch.nn.Parameter( 92 | torch.normal(torch.zeros(2 * dim_hidden), 0.01)) 93 | self.register_buffer('padding', 94 | torch.zeros(dim_hidden * 2)) 95 | self.embeddings = torch.nn.Parameter( 96 | torch.normal(torch.zeros(20, 2 * dim_hidden), 0.01)) 97 | 98 | def forward(self, inputs, lengths, constant_indices): 99 | """ 100 | 101 | Args: 102 | inputs (tensor): Indices of words. The shape is `B x T x 1`. 103 | length (list of int): Length of inputs. 104 | constant_indices (list of int): Each list contains list 105 | 106 | Return: 107 | outputs (tensor): Encoded sequence. The shape is 108 | `B x T x dim_hidden`. 
109 | """ 110 | packed = torch.nn.utils.rnn.pack_padded_sequence( 111 | inputs, lengths, batch_first=True) 112 | hidden_state = None 113 | outputs, hidden_state = self.rnn(packed, hidden_state) 114 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs, 115 | batch_first=True) 116 | 117 | # reshape (2, batch, dim_hidden) to (batch, dim_hidden) 118 | hidden_state = \ 119 | (hidden_state[0].transpose(1, 0).contiguous() 120 | .view(hidden_state[0].size(1), -1), 121 | hidden_state[1].transpose(1, 0).contiguous() 122 | .view(hidden_state[1].size(1), -1)) 123 | hidden_state = \ 124 | (self.mlp1(hidden_state[0]).unsqueeze(0), 125 | self.mlp2(hidden_state[1]).unsqueeze(0)) 126 | 127 | batch_size = outputs.size(0) 128 | operands = [[self.embedding_one, self.embedding_pi] + 129 | [outputs[b][i] 130 | for i in constant_indices[b]] 131 | for b in range(batch_size)] 132 | # operands = [[self.embedding_one, self.embedding_pi] + 133 | # [self.embeddings[i] 134 | # for i in range(len(constant_indices[b]))] 135 | # for b in range(batch_size)] 136 | # n_operands, operands = pad_and_cat(operands, self.padding) 137 | 138 | # attns = [] 139 | # for i in range(operands.size(1)): 140 | # attn = self.attn(outputs, operands[:, i], lengths) 141 | # attns.append(attn) 142 | 143 | # operands = [[self.embedding_one, self.embedding_pi] 144 | # + [attns[i][b] 145 | # for i in range(len(constant_indices[b]))] 146 | # for b in range(batch_size)] 147 | 148 | return outputs, hidden_state, operands 149 | 150 | 151 | class GenVar(torch.nn.Module): 152 | """ Module to generate variable embedding. 153 | 154 | Args: 155 | dim_encoder_state (int): Dimension of the last cell state of encoder 156 | RNN (output of Encoder module). 157 | dim_context (int): Dimension of RNN in GenVar module. 158 | dim_attn_hidden (int): Dimension of hidden layer in attention. 159 | dim_mlp_hiddens (int): Dimension of hidden layers in the MLP 160 | that transform encoder state to query of attention. 161 | dropout_rate (int): Dropout rate for attention and MLP. 162 | """ 163 | def __init__(self, dim_encoder_state, dim_context, 164 | dim_attn_hidden=256, dropout_rate=0.5): 165 | super(GenVar, self).__init__() 166 | self.attention = Attention( 167 | dim_context, dim_encoder_state, 168 | dim_attn_hidden, dropout_rate) 169 | 170 | def forward(self, encoder_state, context, context_lens): 171 | """ Generate embedding for an unknown variable. 172 | 173 | Args: 174 | encoder_state (FloatTensor): Last cell state of the encoder 175 | (output of Encoder module). 176 | context (FloatTensor): Encoded context, with size 177 | (batch_size, text_len, dim_hidden). 
178 | 179 | Return: 180 | var_embedding (FloatTensor): Embedding of an unknown variable, 181 | with size (batch_size, dim_context) 182 | """ 183 | attn = self.attention(context, encoder_state.squeeze(0), context_lens) 184 | return attn 185 | 186 | 187 | class Transformer(torch.nn.Module): 188 | def __init__(self, dim_hidden): 189 | super(Transformer, self).__init__() 190 | self.mlp = torch.nn.Sequential( 191 | torch.nn.Linear(2 * dim_hidden, dim_hidden), 192 | torch.nn.ReLU(), 193 | torch.nn.Linear(dim_hidden, dim_hidden), 194 | torch.nn.Tanh() 195 | ) 196 | self.ret = torch.nn.Parameter(torch.zeros(dim_hidden)) 197 | torch.nn.init.normal_(self.ret.data) 198 | 199 | def forward(self, top2): 200 | return self.mlp(top2) 201 | # return torch.stack([self.ret] * top2.size(0), 0) 202 | 203 | 204 | class Attention(torch.nn.Module): 205 | """ Calculate attention 206 | 207 | Args: 208 | dim_value (int): Dimension of value. 209 | dim_query (int): Dimension of query. 210 | dim_hidden (int): Dimension of hidden layer in attention calculation. 211 | """ 212 | def __init__(self, dim_value, dim_query, 213 | dim_hidden=256, dropout_rate=0.5): 214 | super(Attention, self).__init__() 215 | self.relevant_score = \ 216 | MaskedRelevantScore(dim_value, dim_query, dim_hidden) 217 | 218 | def forward(self, value, query, lens): 219 | """ Generate variable embedding with attention. 220 | 221 | Args: 222 | query (FloatTensor): Current hidden state, with size 223 | (batch_size, dim_query). 224 | value (FloatTensor): Sequence to be attented, with size 225 | (batch_size, seq_len, dim_value). 226 | lens (list of int): Lengths of values in a batch. 227 | 228 | Return: 229 | FloatTensor: Calculated attention, with size 230 | (batch_size, dim_value). 231 | """ 232 | relevant_scores = self.relevant_score(value, query, lens) 233 | e_relevant_scores = torch.exp(relevant_scores) 234 | weights = e_relevant_scores / e_relevant_scores.sum(-1, keepdim=True) 235 | attention = (weights.unsqueeze(-1) * value).sum(1) 236 | return attention 237 | 238 | 239 | class MaskedRelevantScore(torch.nn.Module): 240 | """ Relevant score masked by sequence lengths. 241 | 242 | Args: 243 | dim_value (int): Dimension of value. 244 | dim_query (int): Dimension of query. 245 | dim_hidden (int): Dimension of hidden layer in attention calculation. 246 | """ 247 | def __init__(self, dim_value, dim_query, dim_hidden=256, 248 | dropout_rate=0.0): 249 | super(MaskedRelevantScore, self).__init__() 250 | self.dropout = torch.nn.Dropout(dropout_rate) 251 | self.relevant_score = RelevantScore(dim_value, dim_query, 252 | dim_hidden, 253 | dropout_rate) 254 | 255 | def forward(self, value, query, lens): 256 | """ Choose candidate from candidates. 257 | 258 | Args: 259 | query (FloatTensor): Current hidden state, with size 260 | (batch_size, dim_query). 261 | value (FloatTensor): Sequence to be attented, with size 262 | (batch_size, seq_len, dim_value). 263 | lens (list of int): Lengths of values in a batch. 264 | 265 | Return: 266 | tensor: Activation for each operand, with size 267 | (batch, max([len(os) for os in operands])). 
268 | """ 269 | relevant_scores = self.relevant_score(value, query) 270 | 271 | # make mask to mask out padding embeddings 272 | mask = torch.zeros_like(relevant_scores) 273 | for b, n_c in enumerate(lens): 274 | mask[b, n_c:] = -math.inf 275 | 276 | # apply mask 277 | relevant_scores += mask 278 | 279 | return relevant_scores 280 | 281 | 282 | class RelevantScore(torch.nn.Module): 283 | def __init__(self, dim_value, dim_query, hidden1, dropout_rate=0): 284 | super(RelevantScore, self).__init__() 285 | self.lW1 = torch.nn.Linear(dim_value, hidden1, bias=False) 286 | self.lW2 = torch.nn.Linear(dim_query, hidden1, bias=False) 287 | self.b = torch.nn.Parameter( 288 | torch.normal(torch.zeros(1, 1, hidden1), 0.01)) 289 | self.tanh = torch.nn.Tanh() 290 | self.lw = torch.nn.Linear(hidden1, 1, bias=False) 291 | self.dropout = torch.nn.Dropout(dropout_rate) 292 | 293 | def forward(self, value, query): 294 | """ 295 | Args: 296 | value (FloatTensor): (batch, seq_len, dim_value). 297 | query (FloatTensor): (batch, dim_query). 298 | """ 299 | u = self.tanh(self.dropout( 300 | self.lW1(value) 301 | + self.lW2(query).unsqueeze(1) 302 | + self.b)) 303 | # u.size() == (batch, seq_len, dim_hidden) 304 | return self.lw(u).squeeze(-1) 305 | 306 | 307 | def pad_and_cat(tensors, padding): 308 | """ Pad lists to have same number of elements, and concatenate 309 | those elements to a 3d tensor. 310 | 311 | Args: 312 | tensors (list of list of Tensors): Each list contains 313 | list of operand embeddings. Each operand embedding is of 314 | size (dim_element,). 315 | padding (Tensor): 316 | Element used to pad lists, with size (dim_element,). 317 | 318 | Return: 319 | n_tensors (list of int): Length of lists in tensors. 320 | tensors (Tensor): Concatenated tensor after padding the list. 321 | """ 322 | n_tensors = [len(ts) for ts in tensors] 323 | pad_size = max(n_tensors) 324 | 325 | # pad to has same number of operands for each problem 326 | tensors = [ts + (pad_size - len(ts)) * [padding] 327 | for ts in tensors] 328 | 329 | # tensors.size() = (batch_size, pad_size, dim_hidden) 330 | tensors = torch.stack([torch.stack(t) 331 | for t in tensors], dim=0) 332 | 333 | return n_tensors, tensors 334 | -------------------------------------------------------------------------------- /src/modules/neural_solver_machine_v1.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | from define import OPERATIONS 4 | from .common import (AttnEncoder, Transformer, Attention, 5 | MaskedRelevantScore, pad_and_cat) 6 | 7 | 8 | class NeuralSolverMachineV1(torch.nn.Module): 9 | """ Neural Math Word Problem Solver Machine Version 1. 10 | 11 | Args: 12 | dim_embed (int): Dimension of text embeddings. 13 | dim_hidden (int): Dimension of encoder decoder hidden state. 
14 | """ 15 | def __init__(self, dim_embed=300, dim_hidden=300, dropout_rate=0.5): 16 | super(NeuralSolverMachineV1, self).__init__() 17 | self.encoder = AttnEncoder(dim_embed, 18 | dim_hidden, 19 | dim_hidden, 20 | dropout_rate) 21 | self.decoder = Decoder(dim_hidden, dropout_rate) 22 | self.embedding_one = torch.nn.Parameter( 23 | torch.normal(torch.zeros(2 * dim_hidden), 0.01)) 24 | self.embedding_pi = torch.nn.Parameter( 25 | torch.normal(torch.zeros(2 * dim_hidden), 0.01)) 26 | 27 | 28 | class Decoder(torch.nn.Module): 29 | def __init__(self, dim_hidden=300, dropout_rate=0.5): 30 | super(Decoder, self).__init__() 31 | self.transformer_add = Transformer(2 * dim_hidden) 32 | self.transformer_sub = Transformer(2 * dim_hidden) 33 | self.transformer_mul = Transformer(2 * dim_hidden) 34 | self.transformer_div = Transformer(2 * dim_hidden) 35 | self.transformers = { 36 | OPERATIONS.ADD: self.transformer_add, 37 | OPERATIONS.SUB: self.transformer_sub, 38 | OPERATIONS.MUL: self.transformer_mul, 39 | OPERATIONS.DIV: self.transformer_div, 40 | OPERATIONS.EQL: None} 41 | self.gen_var = Attention(2 * dim_hidden, 42 | dim_hidden, 43 | dropout_rate=0.0) 44 | self.attention = Attention(2 * dim_hidden, 45 | dim_hidden, 46 | dropout_rate=dropout_rate) 47 | self.choose_arg = MaskedRelevantScore( 48 | dim_hidden * 2, 49 | dim_hidden * 7, 50 | dropout_rate=dropout_rate) 51 | self.arg_gate = torch.nn.Linear( 52 | dim_hidden * 7, 53 | 3, 54 | torch.nn.Sigmoid() 55 | ) 56 | self.rnn = torch.nn.LSTM(2 * dim_hidden, 57 | dim_hidden, 58 | 1, 59 | batch_first=True) 60 | self.op_selector = torch.nn.Sequential( 61 | torch.nn.Linear(dim_hidden * 7, 256), 62 | torch.nn.ReLU(), 63 | torch.nn.Dropout(dropout_rate), 64 | torch.nn.Linear(256, 8)) 65 | self.op_gate = torch.nn.Linear( 66 | dim_hidden * 7, 67 | 3, 68 | torch.nn.Sigmoid() 69 | ) 70 | self.dropout = torch.nn.Dropout(dropout_rate) 71 | self.register_buffer('noop_padding_return', 72 | torch.zeros(dim_hidden * 2)) 73 | self.register_buffer('padding_embedding', 74 | torch.zeros(dim_hidden * 2)) 75 | 76 | def forward(self, context, text_len, operands, stacks, 77 | prev_op, prev_output, prev_state): 78 | """ 79 | Args: 80 | context (FloatTensor): Encoded context, with size 81 | (batch_size, text_len, dim_hidden). 82 | text_len (LongTensor): Text length for each problem in the batch. 83 | operands (list of FloatTensor): List of operands embeddings for 84 | each problem in the batch. Each element in the list is of size 85 | (n_operands, dim_hidden). 86 | stacks (list of StackMachine): List of stack machines used for each 87 | problem. 88 | prev_op (LongTensor): Previous operation, with size (batch, 1). 89 | prev_arg (LongTensor): Previous argument indices, with size 90 | (batch, 1). Can be None for the first step. 91 | prev_output (FloatTensor): Previous decoder RNN outputs, with size 92 | (batch, dim_hidden). Can be None for the first step. 93 | prev_state (FloatTensor): Previous decoder RNN state, with size 94 | (batch, dim_hidden). Can be None for the first step. 95 | 96 | Returns: 97 | op_logits (FloatTensor): Logits of operation selection. 98 | arg_logits (FloatTensor): Logits of argument choosing. 99 | outputs (FloatTensor): Outputs of decoder RNN. 100 | state (FloatTensor): Hidden state of decoder RNN. 
101 | """ 102 | batch_size = context.size(0) 103 | 104 | # collect stack states 105 | stack_states = \ 106 | torch.stack([stack.get_top2().view(-1,) for stack in stacks], 107 | dim=0) 108 | 109 | # skip the first step (all NOOP) 110 | if prev_output is not None: 111 | # result calculated batch-wise 112 | batch_result = { 113 | OPERATIONS.GEN_VAR: self.gen_var( 114 | context, prev_output, text_len), 115 | OPERATIONS.ADD: self.transformer_add(stack_states), 116 | OPERATIONS.SUB: self.transformer_sub(stack_states), 117 | OPERATIONS.MUL: self.transformer_mul(stack_states), 118 | OPERATIONS.DIV: self.transformer_div(stack_states) 119 | } 120 | 121 | prev_returns = [] 122 | # apply previous op on stacks 123 | for b in range(batch_size): 124 | # no op 125 | if prev_op[b].item() == OPERATIONS.NOOP: 126 | ret = self.noop_padding_return 127 | 128 | # generate variable 129 | elif prev_op[b].item() == OPERATIONS.GEN_VAR: 130 | variable = batch_result[OPERATIONS.GEN_VAR][b] 131 | operands[b].append(variable) 132 | stacks[b].add_variable(variable) 133 | ret = variable 134 | 135 | # OPERATIONS.ADD, SUB, MUL, DIV 136 | elif prev_op[b].item() in [OPERATIONS.ADD, OPERATIONS.SUB, 137 | OPERATIONS.MUL, OPERATIONS.DIV]: 138 | transformed = batch_result[prev_op[b].item()][b] 139 | ret = stacks[b].apply( 140 | prev_op[b].item(), 141 | transformed) 142 | 143 | elif prev_op[b].item() == OPERATIONS.EQL: 144 | ret = stacks[b].apply(prev_op[b].item(), None) 145 | 146 | # push operand 147 | else: 148 | stacks[b].push(prev_op[b].item() - OPERATIONS.N_OPS) 149 | ret = operands[b][prev_op[b].item() - OPERATIONS.N_OPS] 150 | prev_returns.append(ret) 151 | 152 | # collect stack states (after applied op) 153 | stack_states = \ 154 | torch.stack([stack.get_top2().view(-1,) for stack in stacks], 155 | dim=0) 156 | 157 | # collect previous returns 158 | prev_returns = torch.stack(prev_returns) 159 | prev_returns = self.dropout(prev_returns) 160 | 161 | # decode 162 | outputs, hidden_state = self.rnn(prev_returns.unsqueeze(1), 163 | prev_state) 164 | outputs = outputs.squeeze(1) 165 | 166 | # attention 167 | attention = self.attention(context, outputs, text_len) 168 | 169 | # collect information for op selector 170 | gate_in = torch.cat([outputs, stack_states, attention], -1) 171 | op_gate_in = self.dropout(gate_in) 172 | op_gate = self.op_gate(op_gate_in) 173 | arg_gate_in = self.dropout(gate_in) 174 | arg_gate = self.arg_gate(arg_gate_in) 175 | op_in = torch.cat([op_gate[:, 0:1] * outputs, 176 | op_gate[:, 1:2] * stack_states, 177 | op_gate[:, 2:3] * attention], -1) 178 | arg_in = torch.cat([arg_gate[:, 0:1] * outputs, 179 | arg_gate[:, 1:2] * stack_states, 180 | arg_gate[:, 2:3] * attention], -1) 181 | # op_in = arg_in = torch.cat([outputs, stack_states, attention], -1) 182 | 183 | op_logits = self.op_selector(op_in) 184 | 185 | n_operands, cated_operands = \ 186 | pad_and_cat(operands, self.padding_embedding) 187 | arg_logits = self.choose_arg( 188 | cated_operands, arg_in, n_operands) 189 | 190 | return op_logits, arg_logits, outputs, hidden_state 191 | -------------------------------------------------------------------------------- /src/permute_stack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from define import OPERATIONS 3 | 4 | 5 | class TreeIterator: 6 | def __init__(self, root): 7 | self.root = root 8 | self.queue = [self.root] 9 | 10 | def __iter__(self): 11 | return self 12 | 13 | def __next__(self): 14 | if len(self.queue) == 0: 15 | raise StopIteration 
16 | else: 17 | node = self.queue.pop() 18 | op, left, right = node 19 | if type(left) is list: 20 | self.queue.insert(0, left) 21 | if type(right) is list: 22 | self.queue.insert(0, right) 23 | 24 | return node 25 | 26 | 27 | def build_tree(stack_ops): 28 | stack = [] 29 | for op in stack_ops: 30 | if op in [OPERATIONS.NOOP]: 31 | continue 32 | if op >= OPERATIONS.N_OPS: 33 | stack.append(op) 34 | else: 35 | right = stack.pop() 36 | left = stack.pop() 37 | stack.append([op, left, right]) 38 | 39 | return stack[0] 40 | 41 | 42 | def tree_to_op(node): 43 | if type(node) is not list: 44 | return [node] 45 | else: 46 | op, left, right = node 47 | return tree_to_op(left) + tree_to_op(right) + [op] 48 | 49 | 50 | OP_INV_MAP = { 51 | OPERATIONS.ADD: OPERATIONS.SUB, 52 | OPERATIONS.SUB: OPERATIONS.ADD, 53 | OPERATIONS.MUL: OPERATIONS.DIV, 54 | OPERATIONS.DIV: OPERATIONS.MUL 55 | } 56 | 57 | 58 | def permute_stack_ops(stack_ops, revert_prob=0.25, transpose_prob=0.5): 59 | rands = torch.rand(len(stack_ops)) 60 | tree = build_tree(stack_ops[1:]) 61 | print(rands) 62 | 63 | # revert operands for ADD and MUL 64 | tree_it = TreeIterator(tree) 65 | for node, rand in zip(tree_it, rands): 66 | if rand < revert_prob: 67 | op, left, right = node 68 | if op in [OPERATIONS.ADD, OPERATIONS.MUL]: 69 | node[1], node[2] = right, left 70 | 71 | # transposition 72 | eq, left, right = tree 73 | print('right:', right) 74 | if rands[-1] < transpose_prob and type(right) is list: 75 | print('do trans') 76 | rop, rright, rleft = right 77 | left = [OP_INV_MAP[rop], left, rleft] 78 | right = rright 79 | tree = [eq, left, right] 80 | 81 | return [OPERATIONS.GEN_VAR] + tree_to_op(tree) 82 | 83 | 84 | class PermuteStackOps(object): 85 | def __init__(self, revert_prob=0.25, transpose_prob=0.5): 86 | self.revert_prob = revert_prob 87 | self.transpose_prob = transpose_prob 88 | 89 | def __call__(self, sample): 90 | sample['operations'] = permute_stack_ops(sample['operations'], 91 | self.revert_prob, 92 | self.transpose_prob) 93 | return sample 94 | -------------------------------------------------------------------------------- /src/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pdb 4 | import pickle 5 | import sys 6 | import traceback 7 | import json 8 | 9 | 10 | def main(args): 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | # preprocessor = Preprocessor(args.embedding_path) 14 | # train, valid = preprocessor.get_train_valid_dataset(args.data_path) 15 | 16 | with open(args.pickle_path, 'rb') as f: 17 | data = pickle.load(f) 18 | preprocessor = data['preprocessor'] 19 | to_test = data[args.to_test] 20 | 21 | if args.arch == 'NSMv1': 22 | from torch_solver import TorchSolver 23 | solver = TorchSolver( 24 | preprocessor.get_word_dim(), 25 | args.dim_hidden, 26 | batch_size=args.batch_size, 27 | n_epochs=10000, 28 | device=args.device) 29 | 30 | if args.arch == 'NSMv2': 31 | from torch_solver_v2 import TorchSolverV2 32 | solver = TorchSolverV2( 33 | preprocessor.get_word_dim(), 34 | args.dim_hidden, 35 | batch_size=args.batch_size, 36 | n_epochs=10000, 37 | device=args.device) 38 | 39 | if args.arch == 'NSMv3': 40 | from torch_solver_v3 import TorchSolverV3 41 | solver = TorchSolverV3( 42 | preprocessor.get_word_dim(), 43 | args.dim_hidden, 44 | batch_size=args.batch_size, 45 | n_epochs=10000, 46 | device=args.device) 47 | 48 | elif args.arch == 'seq2seq': 49 | from torch_seq2seq import TorchSeq2Seq 50 | solver = TorchSeq2Seq( 51 
| preprocessor.get_vocabulary_size(), 52 | preprocessor.get_word_dim(), 53 | args.dim_hidden, 54 | embedding=preprocessor.get_embedding(), 55 | batch_size=args.batch_size, 56 | n_epochs=10000, 57 | device=args.device) 58 | 59 | solver.load(args.model_path) 60 | 61 | ys_ = solver.predict_dataset(to_test) 62 | for i in range(len(ys_)): 63 | ys_[i]['index'] = i 64 | try: 65 | ys_[i]['equations'] = str(ys_[i]['equations']) 66 | ys_[i]['ans'] = float(list(ys_[i]['ans'].values())[0]) 67 | except BaseException: 68 | ys_[i]['ans'] = None 69 | 70 | with open(args.output, 'w') as f: 71 | json.dump(ys_, f, indent=' ') 72 | 73 | 74 | def _parse_args(): 75 | parser = argparse.ArgumentParser( 76 | description="Script to train the MWP solver.") 77 | # parser.add_argument('data_path', type=str, 78 | # help='Path to the data.') 79 | # parser.add_argument('embedding_path', type=str, 80 | # help='Path to the embedding.') 81 | parser.add_argument('pickle_path', type=str, 82 | help='Path to the train valid pickle.') 83 | parser.add_argument('model_path', type=str, 84 | help='Path to the model checkpoint.') 85 | parser.add_argument('output', type=str, 86 | help='Dest to dump prediction.') 87 | parser.add_argument('--dim_hidden', type=int, default=256, 88 | help='Hidden state dimension of the encoder.') 89 | parser.add_argument('--batch_size', type=int, default=32, 90 | help='Batch size.') 91 | parser.add_argument('--device', default=None, 92 | help='Device used to train.') 93 | parser.add_argument('--to_test', type=str, 94 | default='valid', help='To dump train or valid.') 95 | parser.add_argument('--arch', type=str, 96 | default='NSMv1', help='To dump train or valid.') 97 | args = parser.parse_args() 98 | return args 99 | 100 | 101 | class DumpHook: 102 | def __init__(self): 103 | self.outputs = [] 104 | self.batch_outputs = [] 105 | 106 | def forward_hook(self, module, inputs, outputs): 107 | self.batch_outputs.append(outputs) 108 | 109 | def flush_batch(self): 110 | self.outputs.append(self.batch_outputs) 111 | self.batch_outputs = [] 112 | 113 | 114 | if __name__ == '__main__': 115 | args = _parse_args() 116 | try: 117 | main(args) 118 | except KeyboardInterrupt: 119 | pass 120 | except BaseException: 121 | type, value, tb = sys.exc_info() 122 | traceback.print_exc() 123 | pdb.post_mortem(tb) 124 | -------------------------------------------------------------------------------- /src/predict_5_fold.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import pdb 4 | import pickle 5 | import sys 6 | import traceback 7 | import json 8 | import torch 9 | from utils import MWPDataset 10 | from calc_score import tofloat 11 | 12 | 13 | def main(args): 14 | logging.basicConfig(level=logging.INFO) 15 | 16 | # preprocessor = Preprocessor(args.embedding_path) 17 | # train, valid = preprocessor.get_train_valid_dataset(args.data_path) 18 | 19 | with open(args.pickle_path, 'rb') as f: 20 | data = pickle.load(f) 21 | preprocessor = data['preprocessor'] 22 | problems = data['train']._problems 23 | 24 | if args.arch == 'NSMv1': 25 | from torch_solver import TorchSolver 26 | solver = TorchSolver( 27 | preprocessor.get_word_dim(), 28 | args.dim_hidden, 29 | batch_size=args.batch_size, 30 | n_epochs=10000, 31 | device=args.device) 32 | 33 | if args.arch == 'NSMv2': 34 | from torch_solver_v2 import TorchSolverV2 35 | solver = TorchSolverV2( 36 | preprocessor.get_word_dim(), 37 | args.dim_hidden, 38 | batch_size=args.batch_size, 39 | n_epochs=10000, 40 | 
device=args.device) 41 | 42 | elif args.arch == 'seq2seq': 43 | from torch_seq2seq import TorchSeq2Seq 44 | solver = TorchSeq2Seq( 45 | preprocessor.get_vocabulary_size(), 46 | preprocessor.get_word_dim(), 47 | args.dim_hidden, 48 | embedding=preprocessor.get_embedding(), 49 | batch_size=args.batch_size, 50 | n_epochs=10000, 51 | device=args.device) 52 | 53 | for fold in range(5): 54 | # load model 55 | solver.load('{}.fold{}.{}' 56 | .format(args.model_path, fold, args.epoch)) 57 | 58 | # make valid dataset 59 | fold_indices = [int(len(problems) * 0.2) * i for i in range(6)] 60 | start = fold_indices[fold] 61 | end = fold_indices[fold + 1] 62 | valid = MWPDataset(problems[start:end], 63 | preprocessor.indices_to_embeddings) 64 | 65 | # make prediction 66 | ys_ = solver.predict_dataset(valid) 67 | for i in range(len(ys_)): 68 | try: 69 | ys_[i]['equations'] = str(ys_[i]['equations']) 70 | ys_[i]['ans'] = float(list(ys_[i]['ans'].values())[0]) 71 | except BaseException: 72 | ys_[i]['ans'] = 0.0 73 | 74 | # dump prediction 75 | output_filename = '{}.fold{}.{}'.format(args.output, 76 | fold, 77 | args.epoch) 78 | with open(output_filename, 'w') as f: 79 | json.dump(ys_, f, indent=' ') 80 | 81 | # make answer list 82 | ys = torch.tensor([tofloat(p['ans']) for p in valid]) 83 | ys_ = torch.tensor([y['ans'] for y in ys_]) 84 | accuracy = (ys == ys_).float().mean().item() 85 | print('Accuracy = {}'.format(accuracy)) 86 | 87 | 88 | def _parse_args(): 89 | parser = argparse.ArgumentParser( 90 | description="Script to train the MWP solver.") 91 | # parser.add_argument('data_path', type=str, 92 | # help='Path to the data.') 93 | # parser.add_argument('embedding_path', type=str, 94 | # help='Path to the embedding.') 95 | parser.add_argument('pickle_path', type=str, 96 | help='Path to the train valid pickle.') 97 | parser.add_argument('model_path', type=str, 98 | help='Path to the model checkpoint. 
(Without .*)') 99 | parser.add_argument('output', type=str, 100 | help='Dest to dump prediction.') 101 | parser.add_argument('--dim_hidden', type=int, default=256, 102 | help='Hidden state dimension of the encoder.') 103 | parser.add_argument('--batch_size', type=int, default=32, 104 | help='Batch size.') 105 | parser.add_argument('--device', default=None, 106 | help='Device used to train.') 107 | parser.add_argument('--to_test', type=str, 108 | default='valid', help='To dump train or valid.') 109 | parser.add_argument('--arch', type=str, 110 | default='NSMv1', help='To dump train or valid.') 111 | parser.add_argument('--epoch', type=int, 112 | default=14, help='Index of the epoch to use.') 113 | args = parser.parse_args() 114 | return args 115 | 116 | 117 | class DumpHook: 118 | def __init__(self): 119 | self.outputs = [] 120 | self.batch_outputs = [] 121 | 122 | def forward_hook(self, module, inputs, outputs): 123 | self.batch_outputs.append(outputs) 124 | 125 | def flush_batch(self): 126 | self.outputs.append(self.batch_outputs) 127 | self.batch_outputs = [] 128 | 129 | 130 | if __name__ == '__main__': 131 | args = _parse_args() 132 | try: 133 | main(args) 134 | except KeyboardInterrupt: 135 | pass 136 | except BaseException: 137 | type, value, tb = sys.exc_info() 138 | traceback.print_exc() 139 | pdb.post_mortem(tb) 140 | -------------------------------------------------------------------------------- /src/pytorch_base.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import torch 4 | import torch.utils.data.dataloader 5 | from torch.autograd import Variable 6 | from tqdm import tqdm 7 | import pdb 8 | 9 | 10 | class TorchBase(): 11 | def _run_iter(self, batch, training): 12 | pass 13 | 14 | def _predict_batch(self, batch): 15 | pass 16 | 17 | def _run_epoch(self, dataloader, training): 18 | # set model training/evaluation mode 19 | self._model.train(training) 20 | 21 | # run batches for train 22 | loss = 0 23 | 24 | # init metric_scores 25 | # metric_scores = {} 26 | # for metric in self._metrics: 27 | # metric_scores[metric] = 0 28 | 29 | for batch in tqdm(dataloader): 30 | outputs, batch_loss = \ 31 | self._run_iter(batch, training) 32 | 33 | if training: 34 | self._optimizer.zero_grad() 35 | batch_loss.backward() 36 | self._optimizer.step() 37 | 38 | loss += batch_loss.item() 39 | # for metric, func in self._metrics.items(): 40 | # metric_scores[metric] += func( 41 | 42 | # calculate averate loss 43 | loss /= (len(dataloader) + 1e-6) 44 | 45 | epoch_log = {} 46 | epoch_log['loss'] = float(loss) 47 | print('loss=%f\n' % loss) 48 | return epoch_log 49 | 50 | def __init__(self, 51 | learning_rate=1e-3, batch_size=10, 52 | n_epochs=10, valid=None, 53 | reg_constant=0.0, 54 | device=None): 55 | 56 | self._learning_rate = learning_rate 57 | self._batch_size = batch_size 58 | self._n_epochs = n_epochs 59 | self._valid = valid 60 | self._reg_constant = reg_constant 61 | self._epoch = 0 62 | if device is not None: 63 | self._device = torch.device(device) 64 | else: 65 | self._device = torch.device('cuda:0' if torch.cuda.is_available() 66 | else 'cpu') 67 | 68 | def fit_dataset(self, data, callbacks=[]): 69 | # Start the training loop. 
70 | while self._epoch < self._n_epochs: 71 | 72 | # train and evaluate train score 73 | print('training %i' % self._epoch) 74 | dataloader = torch.utils.data.DataLoader( 75 | data, 76 | batch_size=self._batch_size, 77 | shuffle=True, 78 | collate_fn=skip_list_collate, 79 | num_workers=0) 80 | # train epoch 81 | log_train = self._run_epoch(dataloader, True) 82 | 83 | # evaluate valid score 84 | if self._valid is not None and len(self._valid) > 0: 85 | print('evaluating %i' % self._epoch) 86 | dataloader = torch.utils.data.DataLoader( 87 | self._valid, 88 | batch_size=self._batch_size, 89 | shuffle=True, 90 | collate_fn=skip_list_collate, 91 | num_workers=1) 92 | # evaluate model 93 | log_valid = self._run_epoch(dataloader, False) 94 | else: 95 | log_valid = {} 96 | 97 | for callback in callbacks: 98 | callback.on_epoch_end(log_train, log_valid, self) 99 | 100 | self._epoch += 1 101 | 102 | def predict_dataset(self, data, 103 | batch_size=None, 104 | predict_fn=None, 105 | progress_bar=True): 106 | if batch_size is None: 107 | batch_size = self._batch_size 108 | if predict_fn is None: 109 | predict_fn = self._predict_batch 110 | 111 | # set model to eval mode 112 | self._model.eval() 113 | 114 | # make dataloader 115 | dataloader = torch.utils.data.DataLoader( 116 | data, 117 | batch_size=batch_size, 118 | shuffle=False, 119 | collate_fn=skip_list_collate, 120 | num_workers=1) 121 | 122 | ys_ = [] 123 | dataloader = tqdm(dataloader) if progress_bar else dataloader 124 | for batch in dataloader: 125 | with torch.no_grad(): 126 | batch_y_ = predict_fn(batch) 127 | ys_ += batch_y_ 128 | 129 | return ys_ 130 | 131 | def save(self, path): 132 | torch.save({ 133 | 'epoch': self._epoch + 1, 134 | 'model': self._model.state_dict(), 135 | # 'optimizer': self._optimizer.state_dict() 136 | }, path) 137 | 138 | def load(self, path): 139 | checkpoint = torch.load(path) 140 | self._epoch = checkpoint['epoch'] 141 | self._model.load_state_dict(checkpoint['model']) 142 | # self._optimizer.load_state_dict(checkpoint['optimizer']) 143 | 144 | 145 | numpy_type_map = { 146 | 'float64': torch.DoubleTensor, 147 | 'float32': torch.FloatTensor, 148 | 'float16': torch.HalfTensor, 149 | 'int64': torch.LongTensor, 150 | 'int32': torch.IntTensor, 151 | 'int16': torch.ShortTensor, 152 | 'int8': torch.CharTensor, 153 | 'uint8': torch.ByteTensor, 154 | } 155 | 156 | 157 | def skip_list_collate(batch): 158 | """ 159 | Puts each data field into a tensor with outer dimension batch size. 160 | Do not collect list recursively. 
161 | """ 162 | if torch.is_tensor(batch[0]): 163 | out = None 164 | if torch.utils.data.dataloader._use_shared_memory: 165 | # If we're in a background process, concatenate directly into a 166 | # shared memory tensor to avoid an extra copy 167 | numel = sum([x.numel() for x in batch]) 168 | storage = batch[0].storage()._new_shared(numel) 169 | out = batch[0].new(storage) 170 | return torch.stack(batch, 0, out=out) 171 | elif type(batch[0]).__module__ == 'numpy': 172 | elem = batch[0] 173 | if type(elem).__name__ == 'ndarray': 174 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 175 | if elem.shape == (): # scalars 176 | py_type = float if elem.dtype.name.startswith('float') else int 177 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 178 | elif isinstance(batch[0], int): 179 | return torch.LongTensor(batch) 180 | elif isinstance(batch[0], float): 181 | return torch.DoubleTensor(batch) 182 | elif isinstance(batch[0], (str, bytes)): 183 | return batch 184 | elif isinstance(batch[0], collections.Mapping): 185 | return {key: skip_list_collate([d[key] for d in batch]) for key in batch[0]} 186 | elif isinstance(batch[0], collections.Sequence): 187 | # do not collate list recursively 188 | return batch 189 | 190 | raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}" 191 | .format(type(batch[0])))) 192 | -------------------------------------------------------------------------------- /src/stack_machine.py: -------------------------------------------------------------------------------- 1 | import sympy 2 | import torch 3 | import logging 4 | from define import OPERATIONS 5 | 6 | 7 | class StackMachine: 8 | """ 9 | 10 | Args: 11 | constants (list): Value of numbers. 12 | embeddings (tensor): Tensor of shape [len(constants), dim_embedding]. 13 | Embedding of the constants. 14 | bottom_embedding (teonsor): Tensor of shape (dim_embedding,). The 15 | embeding to return when stack is empty. 16 | """ 17 | def __init__(self, constants, embeddings, bottom_embedding, dry_run=False): 18 | self._operands = list(constants) 19 | self._embeddings = [embedding for embedding in embeddings] 20 | 21 | # number of unknown variables 22 | self._n_nuknown = 0 23 | 24 | # stack which stores (val, embed) tuples 25 | self._stack = [] 26 | 27 | # equations got from applying `=` on the stack 28 | self._equations = [] 29 | self.stack_log = [] 30 | 31 | # functions operate on value 32 | self._val_funcs = { 33 | OPERATIONS.ADD: sympy.Add, 34 | OPERATIONS.SUB: lambda a, b: sympy.Add(-a, b), 35 | OPERATIONS.MUL: sympy.Mul, 36 | OPERATIONS.DIV: lambda a, b: sympy.Mul(1/a, b) 37 | } 38 | self._op_chars = { 39 | OPERATIONS.ADD: '+', 40 | OPERATIONS.SUB: '-', 41 | OPERATIONS.MUL: '*', 42 | OPERATIONS.DIV: '/', 43 | OPERATIONS.EQL: '=' 44 | } 45 | 46 | self._bottom_embed = bottom_embedding 47 | 48 | if dry_run: 49 | self.apply = self.apply_embed_only 50 | 51 | def add_variable(self, embedding): 52 | """ Tell the stack machine to increase the number of nuknown variables 53 | by 1. 54 | 55 | Args: 56 | embedding (tensor): Tensor of shape (dim_embedding). Embedding 57 | of the unknown varialbe. 58 | """ 59 | var = sympy.Symbol('x{}'.format(self._n_nuknown)) 60 | self._operands.append(var) 61 | self._embeddings.append(embedding) 62 | self._n_nuknown += 1 63 | 64 | def push(self, operand_index): 65 | """ Push var to stack. 66 | 67 | Args: 68 | operand_index (int): Index of the operand. If 69 | index >= number of constants, then it implies a variable is 70 | pushed. 
71 | Return: 72 | tensor: Simply return the pushed embedding. 73 | """ 74 | self._stack.append((self._operands[operand_index], 75 | self._embeddings[operand_index])) 76 | self.stack_log.append(str(self._operands[operand_index])) 77 | return self._embeddings[operand_index] 78 | 79 | def apply(self, operation, embed_res): 80 | """ Apply operator on stack. 81 | 82 | Args: 83 | operator (OPERATION): One of 84 | - OPERATIONS.ADD 85 | - OPERATIONS.SUB 86 | - OPERATIONS.MUL 87 | - OPERATIONS.DIV 88 | - OPERATIONS.EQL 89 | embed_res (FloatTensor): Resulted embedding after transformation, 90 | with size (dim_embedding,). 91 | Return: 92 | tensor: embeding on the top of the stack. 93 | """ 94 | val1, embed1 = self._stack.pop() 95 | val2, embed2 = self._stack.pop() 96 | if operation != OPERATIONS.EQL: 97 | try: 98 | # calcuate values in the equation 99 | val_res = self._val_funcs[operation](val1, val2) 100 | # transform embedding 101 | self._stack.append((val_res, embed_res)) 102 | except ZeroDivisionError: 103 | logging.warn('WARNING: zero division error, skip operation') 104 | else: 105 | self._equations.append(val1 - val2) 106 | # pass 107 | self.stack_log.append(self._op_chars[operation]) 108 | 109 | if len(self._stack) > 0: 110 | return self._stack[-1][1] 111 | else: 112 | return self._bottom_embed 113 | 114 | def apply_embed_only(self, operation, embed_res): 115 | """ Apply operator on stack with embedding operation only. 116 | 117 | Args: 118 | operator (OPERATION): One of 119 | - OPERATIONS.ADD 120 | - OPERATIONS.SUB 121 | - OPERATIONS.MUL 122 | - OPERATIONS.DIV 123 | - OPERATIONS.EQL 124 | embed_res (FloatTensor): Resulted embedding after transformation, 125 | with size (dim_embedding,). 126 | Return: 127 | tensor: embeding on the top of the stack. 128 | """ 129 | val1, embed1 = self._stack.pop() 130 | val2, embed2 = self._stack.pop() 131 | if operation != OPERATIONS.EQL: 132 | # calcuate values in the equation 133 | val_res = None 134 | # transform embedding 135 | self._stack.append((val_res, embed_res)) 136 | 137 | if len(self._stack) > 0: 138 | return self._stack[-1][1] 139 | else: 140 | return self._bottom_embed 141 | 142 | def get_solution(self): 143 | """ Get solution. If the problem has not been solved, return None. 144 | 145 | Return: 146 | list: If the problem has been solved, return result from 147 | sympy.solve. If not, return None. 148 | """ 149 | if self._n_nuknown == 0: 150 | return None 151 | 152 | root = sympy.solve(self._equations) 153 | for i in range(self._n_nuknown): 154 | if self._operands[-i - 1] not in root: 155 | return None 156 | 157 | return root 158 | 159 | def get_top2(self): 160 | """ Get the top 2 embeddings of the stack. 161 | 162 | Return: 163 | tensor: Return tensor of shape (2, embed_dim) 164 | """ 165 | if len(self._stack) >= 2: 166 | return torch.stack([self._stack[-1][1], 167 | self._stack[-2][1]], dim=0) 168 | elif len(self._stack) == 1: 169 | return torch.stack([self._stack[-1][1], 170 | self._bottom_embed], dim=0) 171 | else: 172 | return torch.stack([self._bottom_embed, 173 | self._bottom_embed], dim=0) 174 | 175 | def get_height(self): 176 | """ Get the height of the stack. 
177 | 178 | Return: 179 | int: height 180 | """ 181 | return len(self._stack) 182 | 183 | def get_stack(self): 184 | return [self._bottom_embed] + [s[1] for s in self._stack] 185 | -------------------------------------------------------------------------------- /src/text_num_utils.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import re 3 | 4 | 5 | _to_9 = '(zero|one|two|three|four|five|six|seven|eight|nine)' 6 | xty = '(twenty|thirty|forty|fifty|sixty|seventy|eighty|ninety)' 7 | _to_19 = '(ten|eleven|twelve|thirteen|fourteen|fifteen' \ 8 | '|sixteen|seventeen|eighteen|nineteen|{})'.format(_to_9) 9 | _to_99 = '(({xty}[ -]{to_9})|{xty}|{to_19})'.format( 10 | to_19=_to_19, 11 | to_9=_to_9, 12 | xty=xty) 13 | _to_999 = '({to_9} hundred( (and )?{to_99})?|{to_99})'.format( 14 | to_9=_to_9, to_99=_to_99) 15 | _to_999999 = '({to_999} thousand( (and)? {to_999})?|{to_999})'.format( 16 | to_999=_to_999) 17 | _to_9x9 = '({to_999999} million( (and)? {to_999999})?|{to_999999})'.format( 18 | to_999999=_to_999999) 19 | _to_9x12 = '({to_9x9} billion( (and)? {to_9x9})?|{to_9x9})'.format( 20 | to_9x9=_to_9x9) 21 | 22 | _fraction = '({to_19}-(second|third|fourth|fifth|sixth|seventh|eighth|ninth|' \ 23 | 'tenth|eleventh|twelfth|thirteenth|fourteenth|fifteenth|' \ 24 | 'sixteenth|seventeenth|eighteenth|nineteenth|twentyth)|' \ 25 | 'half|quarter)'.format( 26 | to_19=_to_19) 27 | 28 | _numbers = '(({to_9x12} and )?{fraction}|{to_9x12})'.format( 29 | to_9x12=_to_9x12, fraction=_fraction) 30 | 31 | fraction_pattern = re.compile(_fraction) 32 | number_pattern = re.compile(_numbers) 33 | 34 | 35 | def text2num(text): 36 | """ Convert text to number. 37 | """ 38 | base = { 39 | 'one': 1, 40 | 'two': 2, 41 | 'three': 3, 42 | 'four': 4, 43 | 'five': 5, 44 | 'six': 6, 45 | 'seven': 7, 46 | 'eight': 8, 47 | 'nine': 9, 48 | 'ten': 10, 49 | 'eleven': 11, 50 | 'twelve': 12, 51 | 'thirteen': 13, 52 | 'fourteen': 14, 53 | 'fifteen': 15, 54 | 'sixteen': 16, 55 | 'seventeen': 17, 56 | 'eighteen': 18, 57 | 'nineteen': 19, 58 | 'twenty': 20, 59 | 'thirty': 30, 60 | 'forty': 40, 61 | 'fifty': 50, 62 | 'sixty': 60, 63 | 'seventy': 70, 64 | 'eighty': 80, 65 | 'ninety': 90, 66 | 'twice': 2, 67 | 'half': 0.5, 68 | 'quarter': 0.25} 69 | 70 | scale = { 71 | 'thousand': 1000, 72 | 'million': 1000000, 73 | 'billion': 1000000000} 74 | 75 | order = { 76 | 'second': 2, 77 | 'thirds': 3, 78 | 'fourths': 4, 79 | 'fifths': 5, 80 | 'sixths': 6, 81 | 'sevenths': 7, 82 | 'eighths': 8, 83 | 'nineths': 9, 84 | 'tenths': 10, 85 | 'elevenths': 11, 86 | 'twelfths': 12, 87 | 'thirteenths': 13, 88 | 'fourteenths': 14, 89 | 'fifteenths': 15, 90 | 'sixteenths': 16, 91 | 'seventeenths': 17, 92 | 'eighteenths': 18, 93 | 'nineteenths': 19, 94 | 'twentyths': 20, 95 | 'third': 3, 96 | 'fourth': 4, 97 | 'fifth': 5, 98 | 'sixth': 6, 99 | 'seventh': 7, 100 | 'eighth': 8, 101 | 'nineth': 9, 102 | 'tenth': 10, 103 | 'eleventh': 11, 104 | 'twelfth': 12, 105 | 'thirteenth': 13, 106 | 'fourteenth': 14, 107 | 'fifteenth': 15, 108 | 'sixteenth': 16, 109 | 'seventeenth': 17, 110 | 'eighteenth': 18, 111 | 'nineteenth': 19, 112 | 'twentyth': 20} 113 | 114 | tokens = [] 115 | for token in text.split(' '): 116 | if token == 'and': 117 | continue 118 | elif '-' in token: 119 | if token.split('-')[-1] in order: 120 | tokens.append(token) 121 | else: 122 | tokens += token.split('-') 123 | else: 124 | tokens.append(token) 125 | 126 | result = 0 127 | leading = 0 128 | for token in tokens: 129 | if token in base: 130 | 
leading += base[token] 131 | elif token == 'hundred': 132 | leading *= 100 133 | elif token in scale: 134 | result += leading * scale[token] 135 | leading = 0 136 | elif token in order: 137 | result += leading / order[token] 138 | leading = 0 139 | elif '-' in token: 140 | numerator, denominator = token.split('-') 141 | result += base[numerator] / order[denominator] 142 | result += leading 143 | 144 | return result 145 | 146 | 147 | def isfloat(s): 148 | try: 149 | float(s) 150 | return True 151 | except ValueError: 152 | return False 153 | -------------------------------------------------------------------------------- /src/torch_solver.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pdb 3 | import torch 4 | from define import OPERATIONS 5 | from pytorch_base import TorchBase 6 | from modules.neural_solver_machine_v1 import NeuralSolverMachineV1 7 | from stack_machine import StackMachine 8 | 9 | 10 | class TorchSolver(TorchBase): 11 | """ 12 | 13 | Args: 14 | dim_embed (int): Number of dimensions of word embeddings. 15 | dim_hidden (int): Number of dimensions of intermediate 16 | information embeddings. 17 | """ 18 | def __init__(self, dim_embed, dim_hidden, 19 | decoder_use_state=True, **kwargs): 20 | super(TorchSolver, self).__init__(**kwargs) 21 | self._dim_embed = dim_embed 22 | self._dim_hidden = dim_hidden 23 | self.use_state = decoder_use_state 24 | self._model = NeuralSolverMachineV1(dim_embed, dim_hidden, 0.1) 25 | 26 | # make class weights to ignore loss for padding operations 27 | class_weights = torch.ones(8) 28 | class_weights[OPERATIONS.NOOP] = 0 29 | 30 | # use cuda 31 | class_weights = class_weights.to(self._device) 32 | self._model = self._model.to(self._device) 33 | 34 | # make loss 35 | self._op_loss = torch.nn.CrossEntropyLoss(class_weights, 36 | size_average=False, 37 | reduce=False) 38 | self._arg_loss = torch.nn.CrossEntropyLoss() 39 | 40 | # make optimizer 41 | self._optimizer = torch.optim.Adam(self._model.parameters(), 42 | lr=self._learning_rate) 43 | 44 | def _run_iter(self, batch, training): 45 | order = torch.sort(batch['text_len'] * -1)[1] 46 | for k in batch: 47 | if type(batch[k]) is list: 48 | batch[k] = [batch[k][i] for i in order] 49 | else: 50 | batch[k] = batch[k][order] 51 | batch_size = len(order) 52 | 53 | # zero embedding for the stack bottom 54 | bottom = torch.zeros(self._dim_hidden * 2) 55 | bottom.requires_grad = False 56 | 57 | # deal with device 58 | text, ops, bottom = \ 59 | batch['text'].to(self._device), \ 60 | batch['operations'].to(self._device), \ 61 | bottom.to(self._device) 62 | 63 | # encode 64 | context, state, operands = \ 65 | self._model.encoder.forward(text, batch['text_len'], 66 | batch['constant_indices']) 67 | 68 | # extract constant embeddings 69 | # operands = [[self._model.embedding_one, self._model.embedding_pi] + 70 | # [context[b][i] 71 | # for i in batch['constant_indices'][b]] 72 | # for b in range(batch_size)] 73 | 74 | # initialize stacks 75 | stacks = [StackMachine(batch['constants'][b], operands[b], bottom, 76 | dry_run=True) 77 | for b in range(batch_size)] 78 | 79 | loss = torch.zeros(batch_size).to(self._device) 80 | prev_op = torch.zeros(batch_size).to(self._device) 81 | prev_output = None 82 | 83 | if self.use_state: 84 | prev_state = state 85 | else: 86 | prev_state = None 87 | for t in range(max(batch['op_len'])): 88 | # step one 89 | op_logits, arg_logits, prev_output, prev_state = \ 90 | self._model.decoder( 91 | context, batch['text_len'], 
operands, stacks, 92 | prev_op, prev_output, prev_state) 93 | 94 | # accumulate op loss 95 | op_target = torch.tensor(ops[:, t]) 96 | op_target[op_target >= OPERATIONS.N_OPS] = OPERATIONS.N_OPS 97 | op_target.require_grad = False 98 | loss += self._op_loss(op_logits, torch.tensor(op_target)) 99 | 100 | # accumulate arg loss 101 | for b in range(batch_size): 102 | if ops[b, t] < OPERATIONS.N_OPS: 103 | continue 104 | 105 | loss[b] += self._arg_loss( 106 | arg_logits[b].unsqueeze(0), 107 | ops[b, t].unsqueeze(0) - OPERATIONS.N_OPS) 108 | 109 | prev_op = ops[:, t] 110 | 111 | # if training: 112 | # weights = 1 / torch.tensor(batch['op_len']).to(self._device).float() 113 | # else: 114 | weights = 1 115 | 116 | loss = (loss * weights).mean() 117 | predicts = [stack.get_solution() for stack in stacks] 118 | 119 | return predicts, loss 120 | 121 | def _predict_batch(self, batch, max_len=30): 122 | order = torch.sort(batch['text_len'] * -1)[1] 123 | for k in batch: 124 | if type(batch[k]) is list: 125 | batch[k] = [batch[k][i] for i in order] 126 | else: 127 | batch[k] = batch[k][order] 128 | batch_size = len(order) 129 | 130 | # for constants, cindices, operations in zip(batch['constants'], 131 | # batch['constant_indices'], 132 | # batch['operations']): 133 | # used = set() 134 | # for op in operations: 135 | # if op >= OPERATIONS.N_OPS: 136 | # used.add(op - OPERATIONS.N_OPS) 137 | 138 | # for i in range(len(cindices) - 1, -1, -1): 139 | # if i not in used: 140 | # del constants[i + 2] 141 | # del cindices[i] 142 | 143 | # zero embedding for the stack bottom 144 | bottom = torch.zeros(self._dim_hidden * 2) 145 | bottom.requires_grad = False 146 | 147 | # deal with device 148 | text, bottom = \ 149 | batch['text'].to(self._device), \ 150 | bottom.to(self._device) 151 | 152 | # encode 153 | context, state, operands = \ 154 | self._model.encoder.forward(text, batch['text_len'], 155 | batch['constant_indices']) 156 | 157 | # extract constant embeddings 158 | # operands = [[self._model.embedding_one, self._model.embedding_pi] 159 | # + [context[b][i] 160 | # for i in batch['constant_indices'][b]] 161 | # for b in range(batch_size)] 162 | 163 | # initialize stacks 164 | stacks = [StackMachine(batch['constants'][b], operands[b], bottom) 165 | for b in range(batch_size)] 166 | 167 | loss = torch.zeros(batch_size).to(self._device) 168 | prev_op = torch.zeros(batch_size).to(self._device) 169 | prev_output = None 170 | prev_state = state 171 | finished = [False] * batch_size 172 | for t in range(40): 173 | op_logits, arg_logits, prev_output, prev_state = \ 174 | self._model.decoder( 175 | context, batch['text_len'], operands, stacks, 176 | prev_op, prev_output, prev_state) 177 | 178 | n_finished = 0 179 | for b in range(batch_size): 180 | if stacks[b].get_solution() is not None: 181 | finished[b] = True 182 | 183 | if finished[b]: 184 | op_logits[b, OPERATIONS.NOOP] = math.inf 185 | n_finished += 1 186 | 187 | if stacks[b].get_height() < 2: 188 | op_logits[b, OPERATIONS.ADD] = -math.inf 189 | op_logits[b, OPERATIONS.SUB] = -math.inf 190 | op_logits[b, OPERATIONS.MUL] = -math.inf 191 | op_logits[b, OPERATIONS.DIV] = -math.inf 192 | op_logits[b, OPERATIONS.EQL] = -math.inf 193 | 194 | op_loss, prev_op = torch.log( 195 | torch.nn.functional.softmax(op_logits, -1) 196 | ).max(-1) 197 | arg_loss, prev_arg = torch.log( 198 | torch.nn.functional.softmax(arg_logits, -1) 199 | ).max(-1) 200 | 201 | for b in range(batch_size): 202 | if prev_op[b] == OPERATIONS.N_OPS: 203 | prev_op[b] += prev_arg[b] 204 | loss[b] += 
arg_loss[b] 205 | 206 | if prev_op[b] != OPERATIONS.NOOP: 207 | loss[b] += op_loss[b] 208 | 209 | if n_finished == batch_size: 210 | break 211 | 212 | predicts = [None] * batch_size 213 | for i, o in enumerate(order): 214 | predicts[o] = { 215 | 'ans': stacks[i].get_solution(), 216 | 'equations': stacks[i].stack_log, 217 | 'confidence': loss[i].item() 218 | } 219 | 220 | return predicts 221 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import logging 4 | import pdb 5 | import pickle 6 | import sys 7 | import traceback 8 | from utils import Preprocessor 9 | from callbacks import ModelCheckpoint, MetricsLogger 10 | from permute_stack import PermuteStackOps 11 | 12 | 13 | def main(args): 14 | # preprocessor = Preprocessor(args.embedding_path) 15 | # train, valid = preprocessor.get_train_valid_dataset(args.data_path) 16 | 17 | with open(args.pickle_path, 'rb') as f: 18 | data = pickle.load(f) 19 | preprocessor = data['preprocessor'] 20 | train, valid = data['train'], data['valid'] 21 | 22 | if args.arch == 'NSMv1': 23 | from torch_solver import TorchSolver 24 | solver = TorchSolver( 25 | preprocessor.get_word_dim(), 26 | args.dim_hidden, 27 | valid=valid, 28 | batch_size=args.batch_size, 29 | n_epochs=args.n_epochs, 30 | learning_rate=args.learning_rate, 31 | device=args.device, 32 | decoder_use_state=args.decoder_use_state) 33 | 34 | # load model 35 | if args.load is not None: 36 | solver.load(args.load) 37 | 38 | if not args.five_fold: 39 | model_checkpoint = ModelCheckpoint(args.model_path, 40 | 'loss', 1, 'all') 41 | metrics_logger = MetricsLogger(args.log_path) 42 | solver.fit_dataset(train, [model_checkpoint, metrics_logger]) 43 | else: 44 | from utils import MWPDataset 45 | problems = train._problems 46 | fold_indices = [int(len(problems) * 0.2) * i for i in range(6)] 47 | for fold in range(5): 48 | train = [] 49 | for j in range(5): 50 | if j != fold: 51 | start = fold_indices[j] 52 | end = fold_indices[j + 1] 53 | train += problems[start:end] 54 | 55 | transform = \ 56 | PermuteStackOps(args.revert_prob, args.transpose_prob) \ 57 | if args.permute else None 58 | train = MWPDataset(train, preprocessor.indices_to_embeddings) 59 | logging.info('Start training fold {}'.format(fold)) 60 | model_checkpoint = ModelCheckpoint( 61 | '{}.fold{}'.format(args.model_path, fold), 62 | 'loss', 1, 'all') 63 | metrics_logger = MetricsLogger( 64 | '{}.fold{}'.format(args.log_path, fold)) 65 | solver = TorchSolver( 66 | preprocessor.get_word_dim(), 67 | args.dim_hidden, 68 | valid=valid, 69 | batch_size=args.batch_size, 70 | n_epochs=args.n_epochs, 71 | learning_rate=args.learning_rate, 72 | device=args.device) 73 | solver.fit_dataset(train, [model_checkpoint, metrics_logger]) 74 | 75 | 76 | def _parse_args(): 77 | parser = argparse.ArgumentParser( 78 | description="Script to train the MWP solver.") 79 | # parser.add_argument('data_path', type=str, 80 | # help='Path to the data.') 81 | # parser.add_argument('embedding_path', type=str, 82 | # help='Path to the embedding.') 83 | parser.add_argument('pickle_path', type=str, 84 | help='Path to the train valid pickle.') 85 | parser.add_argument('model_path', type=str, 86 | help='Path to the model checkpoint.') 87 | parser.add_argument('--log_path', type=str, default='./log.json', 88 | help='Path to the log file.') 89 | parser.add_argument('--dim_hidden', type=int, default=256, 90 | 
help='Hidden state dimension of the encoder.') 91 | parser.add_argument('--batch_size', type=int, default=32, 92 | help='Batch size.') 93 | parser.add_argument('--learning_rate', type=float, default=1e-3, 94 | help='Learning rate to use.') 95 | parser.add_argument('--n_epochs', type=int, default=30, 96 | help='Number of epochs to run.') 97 | parser.add_argument('--device', default=None, 98 | help='Device used to train. Can be cpu or cuda:0,' 99 | ' cuda:1, etc.') 100 | parser.add_argument('--load', default=None, type=str, 101 | help='Model to load.') 102 | parser.add_argument('--arch', default='NSMv1', type=str, 103 | help='Model architecture.') 104 | parser.add_argument('--five_fold', default=False, 105 | help='Wheather or not doing 5 fold cross validation', 106 | action='store_true') 107 | parser.add_argument('--decoder_use_state', default=False, 108 | help='', 109 | action='store_true') 110 | parser.add_argument('--permute', default=False, 111 | help='', 112 | action='store_true') 113 | parser.add_argument('--revert_prob', default=0.5, type=float) 114 | parser.add_argument('--transpose_prob', default=0.5, type=float) 115 | args = parser.parse_args() 116 | return args 117 | 118 | 119 | if __name__ == '__main__': 120 | args = _parse_args() 121 | logging.basicConfig(format='%(asctime)s | %(levelname)s | %(message)s', 122 | level=logging.INFO, datefmt='%Y-%m-%d %H:%M:%S') 123 | try: 124 | main(args) 125 | except KeyboardInterrupt: 126 | pass 127 | except BaseException: 128 | type, value, tb = sys.exc_info() 129 | traceback.print_exc() 130 | pdb.post_mortem(tb) 131 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import random 4 | import re 5 | import torch 6 | from torch.utils.data import Dataset 7 | from define import OPERATIONS 8 | from text_num_utils import isfloat, number_pattern, text2num 9 | import pdb 10 | 11 | 12 | class MWPDataset(Dataset): 13 | """ Dataset of math word problems. 14 | 15 | Args: 16 | problems (list): A list containing objects that have 17 | - text (list): List of indices of the words in a math word problem. 18 | - constants (list): Value of the constant in text. 19 | - constant_indices (list): Indices of the constant in text. 20 | - operations (list): List of OPERATIONS to use to solve a math word 21 | problem. 22 | - text_len (int): Actural length of the text. 23 | - ans (list): List of solutions. 24 | Note that the lists `text` and `operations` should be 25 | padded. 26 | encode_fn (function): A function that converts list of word indices to 27 | tensor consisting of word embeddings. 28 | """ 29 | def __init__(self, problems, encode_fn): 30 | self._encode_fn = encode_fn 31 | self._problems = problems 32 | 33 | def __len__(self): 34 | return len(self._problems) 35 | 36 | def __getitem__(self, index): 37 | # copy problem from list 38 | problem = dict(self._problems[index]) 39 | problem['indice'] = problem['text'] 40 | problem['text'] = self._encode_fn(problem['text']) 41 | return problem 42 | 43 | 44 | class Preprocessor: 45 | """ 46 | 47 | Args: 48 | embedding_path (str): Path to the embedding to use. 
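    Example (an illustrative sketch; both file paths are placeholders,
    not files shipped with this repository):
        preprocessor = Preprocessor('embeddings.vec')
        train, valid = preprocessor.get_train_valid_dataset(
            'dolphin18k.json', valid_ratio=0.2)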
49 | """ 50 | def __init__(self, embedding_path, max_text_len=150, max_op_len=50): 51 | logging.info('loading embedding...') 52 | self.num_token = '' 53 | self._word_dict, self._embeddings = load_embeddings(embedding_path) 54 | self._gen_num_embedding() 55 | self._max_text_len = max_text_len 56 | self._max_op_len = max_op_len 57 | 58 | def get_train_valid_dataset(self, data_path, valid_ratio=0.2, 59 | index=None, char_based=False): 60 | """ Load data and return MWPDataset objects for training and validating. 61 | 62 | Args: 63 | data_path (str): Path to the data. 64 | valid_ratio (float): Ratio of the data to used as valid data. 65 | """ 66 | logging.info('loading dataset...') 67 | with open(data_path) as fp_data: 68 | raw_problems = json.load(fp_data) 69 | 70 | logging.info('preprocessing data...') 71 | processed = [] 72 | n_fail = 0 73 | for problem in raw_problems: 74 | try: 75 | processed_problem = self.preprocess_problem(problem, 76 | char_based=char_based) 77 | if (processed_problem['op_len'] > 0 and 78 | processed_problem['op_len'] < 25): 79 | processed.append(processed_problem) 80 | except (ValueError, 81 | ZeroDivisionError, 82 | EquationParsingException) as err: 83 | n_fail += 1 84 | text_key = 'segmented_text' \ 85 | if 'segmented_text' in problem else 'text' 86 | equation_key = 'equation' \ 87 | if 'equation' in problem else 'equations' 88 | logging.warn('Fail to parse:\n' 89 | 'error: {}\n' 90 | 'text: {}\n' 91 | 'equation: {}\n'.format( 92 | err, 93 | problem[text_key], 94 | problem[equation_key])) 95 | 96 | if n_fail > 0: 97 | logging.warn('Fail to parse {} problems!'.format(n_fail)) 98 | 99 | logging.info('Parsed {} problems.'.format(len(processed) - n_fail)) 100 | 101 | if index is None: 102 | random.shuffle(processed) 103 | else: 104 | processed = [processed[i] for i in index] 105 | n_valid = int(len(processed) * valid_ratio) 106 | 107 | return (MWPDataset(processed[n_valid:], self.indices_to_embeddings), 108 | MWPDataset(processed[:n_valid], self.indices_to_embeddings)) 109 | 110 | def preprocess_problem(self, problem, pad=True, char_based=False): 111 | """ Preprocess problem to convert a problem to the form required by solver. 112 | 113 | Args: 114 | problem (dict): A dictionary containing (loaded from Dolphin18k) 115 | - text (str): Text part in a math word problem. 116 | - ans (str): Optional, required only when training. 117 | - equations (str): Optional, required only when training. 118 | - id (str): Problem ID. 119 | 120 | Return: 121 | dict: A dictionary containing 122 | - text (list): List of indices of the words in a math word 123 | problem. It is padded to `max_text_len`. 124 | - text_len (int): Actural length of the text. 125 | - constants (list): Value of the constant in text. 126 | - constant_indices (list): Indices of the constant in text. 127 | - operations (tensor): OPERATIONS to use to solve a math word 128 | problem. It is padded to `max_op_len`. 129 | - op_len (int): Actural number of operations. 130 | - ans (list): List of solutions. 131 | - id (str): Problem ID. 
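        Example (illustrative): if the problem text mentions the numbers
        5 and 0.25, `constants` becomes [1, 3.14, 5.0, 0.25] (1 and 3.14
        are always included first) and `constant_indices` holds the token
        positions at which 5 and 0.25 occurred.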
132 | """ 133 | processed = {} 134 | processed['ans'] = problem['ans'] 135 | 136 | # replace numbers in words with digits 137 | text = replace_number_with_digits(problem['text'].lower()) 138 | tokens = sentence_to_tokens(text, char_based) 139 | 140 | processed['id'] = problem['id'] 141 | 142 | # extract number tokens 143 | processed['constants'] = [1, 3.14] 144 | processed['constant_indices'] = [] 145 | for i, token in enumerate(tokens): 146 | if isfloat(token): 147 | processed['constants'].append(float(token)) 148 | processed['constant_indices'].append( 149 | min(i, self._max_text_len - 1)) 150 | tokens[i] = self.num_token 151 | 152 | # get actural length before padding 153 | processed['text_len'] = min(self._max_text_len, len(tokens)) 154 | 155 | # pad with '' 156 | processed['text'] = self.tokens_to_indices(tokens) 157 | processed['text'] = pad_to_len(processed['text'], 158 | self._word_dict[''], 159 | self._max_text_len) 160 | 161 | # construct ground truth stack operations if 'euqations' is provided 162 | if 'equations' in problem: 163 | processed['operations'] = \ 164 | self.construct_stack_ops(problem['equations'], 165 | processed['constants'], 166 | processed['constant_indices']) 167 | processed['op_len'] = min(self._max_op_len, 168 | len(processed['operations'])) 169 | processed['operations'] = pad_to_len( 170 | processed['operations'], OPERATIONS.NOOP, self._max_op_len) 171 | processed['operations'] = \ 172 | torch.Tensor(processed['operations']).long() 173 | 174 | return processed 175 | 176 | def construct_stack_ops(self, unk_equations, constants, constant_indices): 177 | """ Construct stack operations that build the given equations. 178 | 179 | Args: 180 | unk_equations (str): `equations` attribute in Dolphin18k dataset. 181 | constants (list): Values of the constants in the text. 182 | constant_indices (list): Location (indices) of the constant in the 183 | text. 184 | 185 | Return: 186 | - operations (list): List of OPERATIONS to use to solve a math word 187 | problem. 
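        Example (illustrative, hand-traced on this implementation): with
        unk_equations='\r\nequ: x=1+3.14', constants=[1, 3.14] and
        constant_indices=[], the returned list encodes
        [GEN_VAR, push(x), push(1), push(3.14), ADD, EQL], where each
        push is operands.index(value) + OPERATIONS.N_OPS.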
188 | """ 189 | # split equations string 190 | _, *equations = unk_equations.split('\r\nequ: ') 191 | 192 | # find all unknown variables that appear in equations 193 | # (`unkn` part of `euqations` attribute in the dataset may not 194 | # contain all unknow vars in the `equ` part) 195 | unknowns = [] 196 | for match in re.finditer('[a-z]', ' '.join(equations)): 197 | if match.group() not in unknowns: 198 | unknowns.append(match.group()) 199 | 200 | # accumulator 201 | operations = [] 202 | 203 | # generate variable based on number of unknowns 204 | for _ in range(len(unknowns)): 205 | operations.append(OPERATIONS.GEN_VAR) 206 | 207 | # prepare list of operands 208 | operands = constants + unknowns 209 | 210 | # mapping from operator token to its encoding 211 | op_map = { 212 | '+': OPERATIONS.ADD, 213 | '-': OPERATIONS.SUB, 214 | '*': OPERATIONS.MUL, 215 | '/': OPERATIONS.DIV, 216 | '=': OPERATIONS.EQL 217 | } 218 | 219 | # start parsing equations 220 | for equation in equations: 221 | 222 | # substitute fraction in equation with float 223 | for match in re.finditer(r'\(([0-9]+)/([0-9]+)\)', equation): 224 | frac = int(match.group(1)) / int(match.group(2)) 225 | if frac in operands: 226 | equation = equation.replace(match.group(), str(frac)) 227 | 228 | postfix = infix2postfix(equation) 229 | for token in postfix: 230 | # deal with operators 231 | if token in op_map: 232 | operations.append(op_map[token]) 233 | 234 | # deal with operands 235 | else: 236 | if isfloat(token): 237 | token = float(token) 238 | operations.append(operands.index(token) + OPERATIONS.N_OPS) 239 | 240 | return operations 241 | 242 | def tokens_to_indices(self, tokens): 243 | word_indices = [] 244 | for w in tokens + ['']: 245 | if w in self._word_dict: 246 | word_indices.append(self._word_dict[w]) 247 | else: 248 | word_indices.append(self._word_dict['']) 249 | 250 | return word_indices 251 | 252 | def build_rev_dict(self): 253 | self._rev_dict = [None] * len(self._word_dict) 254 | for k, v in self._word_dict.items(): 255 | self._rev_dict[v] = k 256 | 257 | def indices_to_tokens(self, indices): 258 | return [self._rev_dict[i] for i in indices] 259 | 260 | def indices_to_embeddings(self, indices): 261 | return torch.stack([self._embeddings[i] for i in indices], dim=0) 262 | 263 | def get_word_dim(self): 264 | return self._embeddings.shape[1] 265 | 266 | def get_vocabulary_size(self): 267 | return self._embeddings.shape[0] 268 | 269 | def get_embedding(self): 270 | return self._embeddings 271 | 272 | def _gen_num_embedding(self): 273 | """ Generate embedding for number token. 274 | """ 275 | self._word_dict[self.num_token] = self._embeddings.size(0) 276 | num_indices = [self._word_dict[num] 277 | for num in '1234567890'] 278 | num_embedding = torch.mean(self._embeddings[num_indices], 279 | dim=0, keepdim=True) 280 | self._embeddings = torch.cat([self._embeddings, num_embedding], dim=0) 281 | 282 | 283 | class Math23kPreprocessor(Preprocessor): 284 | def __init__(self, *args, **kwargs): 285 | super(Math23kPreprocessor, self).__init__(*args, **kwargs) 286 | 287 | def preprocess_problem(self, problem, char_based=False): 288 | """ Preprocess problem to convert a problem to the form required by solver. 289 | 290 | Args: 291 | problem (dict): A dictionary containing (loaded from Dolphin18k) 292 | - id (str): Problem ID. 293 | - segmented_text (str): Segmented text part in a math word 294 | problem. 295 | - ans (str): Optional, required only when training. 296 | - equation (str): Optional, required only when training. 
297 | 298 | Return: 299 | dict: A dictionary containing 300 | - text (list): List of indices of the words in a math word 301 | problem. It is padded to `max_text_len`. 302 | - text_len (int): Actural length of the text. 303 | - constants (list): Value of the constant in text. 304 | - constant_indices (list): Indices of the constant in text. 305 | - operations (tensor): OPERATIONS to use to solve a math word 306 | problem. It is padded to `max_op_len`. 307 | - op_len (int): Actural number of operations. 308 | - ans (list): List of solutions. 309 | """ 310 | intermediate = { 311 | 'id': problem['id'], 312 | 'text': problem['segmented_text'], 313 | 'ans': problem['ans'], 314 | 'equations': '\r\nequ: ' + problem['equation'] 315 | } 316 | 317 | for match in re.finditer(r'(\d*\.?\d+)%', intermediate['equations']): 318 | intermediate['equations'] = intermediate['equations'].replace( 319 | match.group(), 320 | str(float(match.group(1)) / 100)) 321 | intermediate['equations'] = intermediate['equations'].replace('[', '(') 322 | intermediate['equations'] = intermediate['equations'].replace(']', ')') 323 | return super(Math23kPreprocessor, self) \ 324 | .preprocess_problem(intermediate, char_based=char_based) 325 | 326 | 327 | def load_embeddings(embedding_path): 328 | word_dict = {} 329 | with open(embedding_path) as fp: 330 | embedding = [] 331 | 332 | row1 = fp.readline() 333 | # if the first row is not header 334 | if not re.match('^[0-9]+ [0-9]+$', row1): 335 | # seek to 0 336 | fp.seek(0) 337 | # otherwise ignore the header 338 | 339 | for i, line in enumerate(fp): 340 | cols = line.rstrip().split(' ') 341 | word = cols[0] 342 | word_dict[word] = i 343 | embedding.append([float(v) for v in cols[1:]]) 344 | 345 | if '' not in word_dict: 346 | word_dict[''] = len(embedding) 347 | embedding.append([0] * len(embedding[0])) 348 | 349 | if '' not in word_dict: 350 | word_dict[''] = len(embedding) 351 | embedding.append([0] * len(embedding[0])) 352 | 353 | return word_dict, torch.Tensor(embedding) 354 | 355 | 356 | def sentence_to_tokens(sentence, char_based=False): 357 | """ Normalize text and tokenize to tokens. 358 | """ 359 | if not char_based: 360 | sentence = sentence.replace('. ', ' . ') 361 | sentence = re.sub('.$', ' .', sentence) 362 | sentence = sentence.replace(', ', ' , ') 363 | sentence = sentence.replace('$', '$ ') 364 | sentence = sentence.replace('?', ' ?') 365 | sentence = sentence.replace('!', ' !') 366 | sentence = sentence.replace('"', ' "') 367 | sentence = sentence.replace('\'', ' \'') 368 | sentence = sentence.replace(';', ' ;') 369 | sentence = sentence.strip().lower().replace('\n', ' ') 370 | else: 371 | sentence = sentence.replace(' ', '') 372 | sentence = ' '.join(sentence) 373 | sentence = re.sub(r'(?<=[\d\.]) (?=[\d\.])', '', sentence) 374 | 375 | return sentence.split(' ') 376 | 377 | 378 | def pad_to_len(arr, pad, max_len): 379 | """ Pad and truncate to specific length. 380 | 381 | Args: 382 | arr (list): List to pad. 383 | pad: Element used to pad. 384 | max_len: Max langth of arr. 385 | """ 386 | padded = [pad] * max_len 387 | n_copy = min(len(arr), max_len) 388 | padded[:n_copy] = arr[:n_copy] 389 | return padded 390 | 391 | 392 | def infix2postfix(infix): 393 | """ Convert infix equation to postfix representation. 394 | 395 | Args: 396 | infix (str): Math expression. 
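    Example (illustrative, hand-traced on this implementation):
        infix2postfix('x = 3 + 4 * 2')
        # -> ['x', '3', '4', '2', '*', '+', '=']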
397 | """ 398 | infix = re.sub(r'([\d\.]+) *([a-z]+)', r'\1 * \2', infix).strip() 399 | 400 | # add spaces between numbers and operators 401 | infix = re.sub(r'([+\*/\-\(\)=])', r' \1 ', infix).strip() 402 | 403 | # remove consequitive spaces in the expression 404 | infix = re.sub(r' +', r' ', infix) 405 | 406 | # so now numbers and operators are seperated by exactly one space 407 | tokens = infix.split(' ') 408 | 409 | # deal with negative symbol, so it will not be seen as minus latter 410 | redundant_minus_indices = [] 411 | for i, token in enumerate(tokens): 412 | if token == '-' and (i == 0 or 413 | (i > 0 and tokens[i-1] in '=+-*/(')): 414 | tokens[i + 1] = str(-float(tokens[i + 1])) 415 | redundant_minus_indices.append(i) 416 | for i in redundant_minus_indices[::-1]: 417 | del tokens[i] 418 | 419 | # convert tokens to postfix 420 | postfix = [] 421 | operator_stack = [] 422 | operators = '=()+-*/' 423 | try: 424 | for token in tokens: 425 | if token not in operators: 426 | postfix.append(token) 427 | elif token == '(': 428 | operator_stack.append(token) 429 | elif token == ')': 430 | while operator_stack[-1] != '(': 431 | postfix.append(operator_stack.pop()) 432 | 433 | operator_stack.pop() 434 | else: 435 | while (len(operator_stack) > 0 and 436 | operators.index(operator_stack[-1]) 437 | >= operators.index(token)): 438 | op = operator_stack.pop() 439 | # if op not in '()': 440 | postfix.append(op) 441 | operator_stack.append(token) 442 | except BaseException as exception: 443 | raise EquationParsingException(''.join(infix), 444 | exception) 445 | 446 | while len(operator_stack) > 0: 447 | op = operator_stack.pop() 448 | if op not in '()': 449 | postfix.append(op) 450 | 451 | return postfix 452 | 453 | 454 | def replace_number_with_digits(text): 455 | text = text.replace('$', '$ ') \ 456 | .replace('/', ' / ') \ 457 | .replace('a half', '1.5') 458 | text = text.replace('twice', '2 times') 459 | text = text.replace('double', '2 times') 460 | 461 | for match in re.finditer('([0-9]+) ([0-9]+) */ *([0-9]+)', text): 462 | frac = int(match.group(1)) + int(match.group(2)) / int(match.group(3)) 463 | text = text.replace(match.group(), str(frac)) 464 | 465 | for match in re.finditer(r'\(([0-9]+) */ *([0-9]+)\)', text): 466 | frac = int(match.group(1)) / int(match.group(2)) 467 | text = text.replace(match.group(), str(frac)) 468 | 469 | for match in re.finditer(r'([0-9]+\.)?[0-9]+%', text): 470 | percent_text = match.group() 471 | float_text = str(float(percent_text[:-1]) / 100) 472 | text = text.replace(percent_text, 473 | float_text) 474 | 475 | for match in re.finditer('\\d{1,3}(,\\d{3})+', text): 476 | match_text = match.group() 477 | text = text.replace(match_text, 478 | match_text.replace(',', '')) 479 | 480 | for num_text in number_pattern.finditer(text): 481 | text = text.replace(num_text.group(), 482 | str(text2num(num_text.group()))) 483 | 484 | text = re.sub(r'(-?\d+.\d+|\d+)', r' \1 ', text) 485 | text = re.sub(r' +', ' ', text) 486 | return text 487 | 488 | 489 | class EquationParsingException(BaseException): 490 | def __init__(self, equation, exception): 491 | self.equation = equation 492 | self.exception = exception 493 | 494 | def __repr__(self): 495 | return 'Fail to parse equation "{}".\nError: {}'.format( 496 | self.equation, 497 | self.exception) 498 | --------------------------------------------------------------------------------
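The sources above form one pipeline: `make_train_valid.py` pickles a `Preprocessor` together with `MWPDataset` splits, `train.py` fits a `TorchSolver` on them, and the prediction/scoring scripts reload a checkpoint and dump per-problem answers. The sketch below shows that flow end to end; it is illustrative only, the paths and the hidden size are placeholders, and it is not a copy of `predict.py`.

```python
import pickle
from torch_solver import TorchSolver

# Load the preprocessor and dataset splits pickled by make_train_valid.py
# (placeholder path).
with open('train_valid.pkl', 'rb') as f:
    data = pickle.load(f)
preprocessor, valid = data['preprocessor'], data['valid']

# Rebuild the solver with the same hidden size used during training and
# restore a saved checkpoint (placeholder path).
solver = TorchSolver(preprocessor.get_word_dim(), 256,
                     batch_size=32, device=None)
solver.load('path/to/checkpoint')

# Each prediction is a dict with 'ans', 'equations' (the stack log) and
# 'confidence', as produced by TorchSolver._predict_batch.
predictions = solver.predict_dataset(valid)
```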