├── .gitignore ├── Bart_Program ├── README.md ├── data.py ├── predict.py ├── preprocess.py └── train.py ├── Bart_SPARQL ├── README.md ├── data.py ├── predict.py ├── preprocess.py ├── sparql_engine.py └── train.py ├── BlindGRU ├── README.md ├── data.py ├── model.py ├── predict.py ├── preprocess.py └── train.py ├── KVMemNN ├── README.md ├── data.py ├── model.py ├── predict.py ├── preprocess.py └── train.py ├── LICENSE ├── Program ├── data.py ├── executor_rule.py ├── parser.py ├── predict.py ├── preprocess.py ├── readme.md └── train.py ├── README.md ├── RGCN ├── README.md ├── data.py ├── model.py ├── predict.py ├── preprocess.py └── train.py ├── SPARQL ├── README.md ├── data.py ├── model.py ├── predict.py ├── preprocess.py ├── sparql_engine.py └── train.py ├── SRN ├── data.py ├── input │ └── pgrk.txt ├── knowledge_graph.py ├── model.py ├── predict.py ├── preprocess.py ├── readme.md └── train.py ├── evaluate.py └── utils ├── BiGRU.py ├── load_kb.py ├── lr_scheduler.py ├── misc.py ├── pickle_glove.py └── value_class.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.sublime-workspace 2 | *.sublime-project 3 | test_dataset/ 4 | dataset/ 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /Bart_Program/README.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | - python3 3 | - pytorch==1.12.0 4 | - transformers==4.16.2 5 | - kopl==0.0.5 6 | 7 | ## How to run 8 | 1. Install the KoPL engine 9 | ``` 10 | pip install kopl 11 | ``` 12 | 2. Preprocess the training data, and copy `./dataset/kb.json` into your `<output_dir>` 13 | ``` 14 | python -m Bart_Program.preprocess --input_dir ./dataset --output_dir <output_dir> --model_name_or_path <model_name_or_path> 15 | cp ./dataset/kb.json <output_dir> 16 | ``` 17 | 3. Train 18 | ``` 19 | python -m Bart_Program.train --input_dir <output_dir> --output_dir <checkpoint_dir> --save_dir <log_dir> --model_name_or_path <model_name_or_path> 20 | ``` 21 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order. 22 | ``` 23 | python -m Bart_Program.predict --input_dir <output_dir> --save_dir <log_dir> --ckpt <checkpoint_dir> 24 | ``` 25 | 26 | ## Checkpoints 27 | 1. The pretrained Bart-base checkpoint without finetuning can be downloaded here: [bart-base](https://cloud.tsinghua.edu.cn/f/3b59ec6c43034cfc8841/?dl=1) 28 | 2. The checkpoint for the finetuned Bart_Program can be downloaded here: [finetuned](https://cloud.tsinghua.edu.cn/f/5b82ae04f9f64d1c8d1d/?dl=1) 29 | 30 | ## Change Log 31 | 32 | - [2022/8/8] Uploaded evaluation.py; updated the KoPL engine based on [KoPL](https://github.com/THU-KEG/KoPL); updated kb.json in dataset. 33 | 34 | - Switched to a different serializer and added special tokens to the tokenizer (see the example below).
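  For illustration, `Bart_Program.preprocess.get_program_seq` linearizes a KoPL program into a single string, joining functions with the `<func>` token and prefixing each function input with `<arg>`. A sketch with a made-up program (not taken from the dataset):
  ```
  # program: Find("Earth") -> QueryAttr("radius")
  # serialized target sequence:
  Find <arg> Earth <func> QueryAttr <arg> radius
  ```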
Note that the argument is for --model_name_or_path for Bart_Program.train 35 | -------------------------------------------------------------------------------- /Bart_Program/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | from utils.misc import invert_dict 5 | 6 | def load_vocab(path): 7 | vocab = json.load(open(path)) 8 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) 9 | return vocab 10 | 11 | def collate(batch): 12 | batch = list(zip(*batch)) 13 | source_ids = torch.stack(batch[0]) 14 | source_mask = torch.stack(batch[1]) 15 | choices = torch.stack(batch[2]) 16 | if batch[-1][0] is None: 17 | target_ids, answer = None, None 18 | else: 19 | target_ids = torch.stack(batch[3]) 20 | answer = torch.cat(batch[4]) 21 | return source_ids, source_mask, choices, target_ids, answer 22 | 23 | 24 | class Dataset(torch.utils.data.Dataset): 25 | def __init__(self, inputs): 26 | self.source_ids, self.source_mask, self.target_ids, self.choices, self.answers = inputs 27 | self.is_test = len(self.answers)==0 28 | 29 | 30 | def __getitem__(self, index): 31 | source_ids = torch.LongTensor(self.source_ids[index]) 32 | source_mask = torch.LongTensor(self.source_mask[index]) 33 | choices = torch.LongTensor(self.choices[index]) 34 | if self.is_test: 35 | target_ids = None 36 | answer = None 37 | else: 38 | target_ids = torch.LongTensor(self.target_ids[index]) 39 | answer = torch.LongTensor([self.answers[index]]) 40 | return source_ids, source_mask, choices, target_ids, answer 41 | 42 | 43 | def __len__(self): 44 | return len(self.source_ids) 45 | 46 | 47 | class DataLoader(torch.utils.data.DataLoader): 48 | def __init__(self, vocab_json, question_pt, batch_size, training=False): 49 | vocab = load_vocab(vocab_json) 50 | if training: 51 | print('#vocab of answer: %d' % (len(vocab['answer_token_to_idx']))) 52 | 53 | inputs = [] 54 | with open(question_pt, 'rb') as f: 55 | for _ in range(5): 56 | inputs.append(pickle.load(f)) 57 | dataset = Dataset(inputs) 58 | # np.shuffle(dataset) 59 | # dataset = dataset[:(int)(len(dataset) / 10)] 60 | super().__init__( 61 | dataset, 62 | batch_size=batch_size, 63 | shuffle=training, 64 | collate_fn=collate, 65 | ) 66 | self.vocab = vocab -------------------------------------------------------------------------------- /Bart_Program/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import json 7 | from tqdm import tqdm 8 | from datetime import date 9 | from utils.misc import MetricLogger, seed_everything, ProgressBar 10 | from .data import DataLoader 11 | from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer 12 | import torch.optim as optim 13 | import logging 14 | import time 15 | from utils.lr_scheduler import get_linear_schedule_with_warmup 16 | import re 17 | from kopl.kopl import KoPLEngine 18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 19 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 20 | rootLogger = logging.getLogger() 21 | import warnings 22 | warnings.simplefilter("ignore") # hide warnings that caused by invalid sparql query 23 | from termcolor import colored 24 | 25 | def post_process(text): 26 | pattern = re.compile(r'".*?"') 27 | nes = [] 28 | for item in pattern.finditer(text): 29 | 
nes.append((item.group(), item.span())) 30 | pos = [0] 31 | for name, span in nes: 32 | pos += [span[0], span[1]] 33 | pos.append(len(text)) 34 | assert len(pos) % 2 == 0 35 | assert len(pos) / 2 == len(nes) + 1 36 | chunks = [text[pos[i]: pos[i+1]] for i in range(0, len(pos), 2)] 37 | for i in range(len(chunks)): 38 | chunks[i] = chunks[i].replace('?', ' ?').replace('.', ' .') 39 | bingo = '' 40 | for i in range(len(chunks) - 1): 41 | bingo += chunks[i] + nes[i][0] 42 | bingo += chunks[-1] 43 | return bingo 44 | 45 | def vis(args, kb, model, data, device, tokenizer): 46 | while True: 47 | text = input('Input your question:') 48 | with torch.no_grad(): 49 | input_ids = tokenizer.batch_encode_plus([text], max_length = 512, pad_to_max_length = True, return_tensors="pt", truncation = True) 50 | source_ids = input_ids['input_ids'].to(device) 51 | outputs = model.generate( 52 | input_ids=source_ids, 53 | max_length = 500, 54 | ) 55 | outputs = [tokenizer.decode(output_id, skip_special_tokens = True, clean_up_tokenization_spaces = True) for output_id in outputs] 56 | outputs = [post_process(output) for output in outputs] 57 | print(outputs[0]) 58 | 59 | def predict(args, model, data, device, tokenizer, executor): 60 | model.eval() 61 | count, correct = 0, 0 62 | with torch.no_grad(): 63 | all_outputs = [] 64 | for batch in tqdm(data, total=len(data)): 65 | source_ids = batch[0].to(device) 66 | outputs = model.generate( 67 | input_ids=source_ids, 68 | max_length = 500, 69 | ) 70 | 71 | all_outputs.extend(outputs.cpu().numpy()) 72 | 73 | outputs = [tokenizer.decode(output_id, skip_special_tokens = True, clean_up_tokenization_spaces = True) for output_id in all_outputs] 74 | with open(os.path.join(args.save_dir, 'predict.txt'), 'w') as f: 75 | 76 | for output in tqdm(outputs): 77 | chunks = output.split('<func>') 78 | func_list = [] 79 | inputs_list = [] 80 | for chunk in chunks: 81 | chunk = chunk.strip() 82 | res = chunk.split('<arg>') 83 | res = [_.strip() for _ in res] 84 | if len(res) > 0: 85 | func = res[0] 86 | inputs = [] 87 | if len(res) > 1: 88 | for x in res[1:]: 89 | inputs.append(x) 90 | else: 91 | inputs = [] 92 | func_list.append(func) 93 | inputs_list.append(inputs) 94 | ans = executor.forward(func_list, inputs_list, ignore_error = True) 95 | if ans is None: 96 | ans = 'no' 97 | if isinstance(ans, list) and len(ans) > 0: 98 | ans = ans[0] 99 | if isinstance(ans, list) and len(ans) == 0: 100 | ans = 'None' 101 | f.write(ans + '\n') 102 |
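# How the decoded sequence is parsed back into a KoPL program (a sketch with a
# made-up sequence; '<func>'/'<arg>' are the special tokens added in preprocess.py):
#   output = 'Find <arg> Earth <func> QueryAttr <arg> radius'
#   output.split('<func>')          -> ['Find <arg> Earth ', ' QueryAttr <arg> radius']
#   chunk.split('<arg>') per chunk  -> func_list   = ['Find', 'QueryAttr']
#                                      inputs_list = [['Earth'], ['radius']]
# The same parsing is repeated in validate() below.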
103 | def validate(model, data, device, tokenizer, executor): 104 | model.eval() 105 | count, correct = 0, 0 106 | with torch.no_grad(): 107 | all_outputs = [] 108 | all_answers = [] 109 | for batch in tqdm(data, total=len(data)): 110 | source_ids, source_mask, choices, target_ids, answer = [x.to(device) for x in batch] 111 | outputs = model.generate( 112 | input_ids=source_ids, 113 | max_length = 500, 114 | ) 115 | 116 | all_outputs.extend(outputs.cpu().numpy()) 117 | all_answers.extend(answer.cpu().numpy()) 118 | 119 | outputs = [tokenizer.decode(output_id, skip_special_tokens = True, clean_up_tokenization_spaces = True) for output_id in all_outputs] 120 | given_answer = [data.vocab['answer_idx_to_token'][a] for a in all_answers] 121 | for a, output in tqdm(zip(given_answer, outputs)): 122 | chunks = output.split('<func>') 123 | func_list = [] 124 | inputs_list = [] 125 | for chunk in chunks: 126 | chunk = chunk.strip() 127 | res = chunk.split('<arg>') 128 | res = [_.strip() for _ in res] 129 | if len(res) > 0: 130 | func = res[0] 131 | inputs = [] 132 | if len(res) > 1: 133 | for x in res[1:]: 134 | inputs.append(x) 135 | else: 136 | inputs = [] 137 | func_list.append(func) 138 | inputs_list.append(inputs) 139 | ans = executor.forward(func_list, inputs_list, ignore_error = True) 140 | if ans is None: 141 | ans = 'no' 142 | if isinstance(ans, list) and len(ans) > 0: 143 | ans = ans[0] 144 | if ans == a: 145 | correct += 1 146 | count += 1 147 | acc = correct / count 148 | logging.info('acc: {}'.format(acc)) 149 | 150 | return acc 151 | 152 | 153 | 154 | def train(args): 155 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 156 | 157 | logging.info("Create train_loader and val_loader.........") 158 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 159 | val_pt = os.path.join(args.input_dir, 'test.pt') 160 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size) 161 | logging.info("Create model.........") 162 | config_class, model_class, tokenizer_class = (BartConfig, BartForConditionalGeneration, BartTokenizer) 163 | tokenizer = tokenizer_class.from_pretrained(os.path.join(args.ckpt)) 164 | model = model_class.from_pretrained(os.path.join(args.ckpt)) 165 | model = model.to(device) 166 | logging.info(model) 167 | engine = KoPLEngine(json.load(open(os.path.join(args.input_dir, 'kb.json')))) 168 | # validate(model, val_loader, device, tokenizer, engine) 169 | 170 | predict(args, model, val_loader, device, tokenizer, engine) 171 | def main(): 172 | parser = argparse.ArgumentParser() 173 | # input and output 174 | parser.add_argument('--input_dir', required=True) 175 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 176 | parser.add_argument('--ckpt', required=True) 177 | 178 | # training parameters 179 | parser.add_argument('--batch_size', default=256, type=int) 180 | parser.add_argument('--seed', type=int, default=666, help='random seed') 181 | 182 | # validating parameters 183 | # parser.add_argument('--num_return_sequences', default=1, type=int) 184 | # parser.add_argument('--top_p', default=) 185 | # model hyperparameters 186 | parser.add_argument('--dim_hidden', default=1024, type=int) 187 | parser.add_argument('--alpha', default = 1e-4, type = float) 188 | args = parser.parse_args() 189 | 190 | if not os.path.exists(args.save_dir): 191 | os.makedirs(args.save_dir) 192 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 193 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, '{}.predict.log'.format(time_))) 194 | fileHandler.setFormatter(logFormatter) 195 | rootLogger.addHandler(fileHandler) 196 | # args display 197 | for k, v in vars(args).items(): 198 | logging.info(k+':'+str(v)) 199 | 200 | seed_everything(666) 201 | 202 | train(args) 203 | 204 | 205 | if __name__ == '__main__': 206 | main() 207 | 208 | -------------------------------------------------------------------------------- /Bart_Program/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | We need the last function to help extract the final answer of SPARQL, used in check_sparql 3 | """ 4 | 5 | import os 6 | import json 7 | import pickle 8 | import argparse 9 | import numpy as np 10 | from nltk import word_tokenize 11 | from collections import Counter 12 | from itertools import chain 13 | from tqdm import tqdm 14 | import re 15 | 16 | from utils.misc import init_vocab 17 | from transformers import * 18 | 19 | new_tokens = ['<func>', '<arg>'] 20 | 21 | def get_program_seq(program): 22 | seq = [] 23 | for item in program: 24 | func = item['function'] 25 | inputs = item['inputs'] 26 | args = '' 27 | for input in inputs: 28 | args += ' <arg> ' + input 29 | seq.append(func + args) 30 | seq = ' <func> '.join(seq) 31 | return seq 32 |
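# Example (a sketch; the program below is made up for illustration): for a
# program item list like
#   [{'function': 'Find', 'inputs': ['Earth']},
#    {'function': 'QueryAttr', 'inputs': ['radius']}]
# get_program_seq returns the single string:
#   'Find <arg> Earth <func> QueryAttr <arg> radius'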
33 | def encode_dataset(dataset, vocab, tokenizer, test = False): 34 | questions = [] 35 | programs = [] 36 | for item in tqdm(dataset): 37 | question = item['question'] 38 | questions.append(question) 39 | if not test: 40 | program = item['program'] 41 | program = get_program_seq(program) 42 | programs.append(program) 43 | sequences = questions + programs 44 | print('tokenizing') 45 | encoded_inputs = tokenizer(sequences, padding = True) 46 | print('tokenize ended.') 47 | print(encoded_inputs.keys()) 48 | print(encoded_inputs['input_ids'][0]) 49 | print(tokenizer.decode(encoded_inputs['input_ids'][0])) 50 | print(tokenizer.decode(encoded_inputs['input_ids'][-1])) 51 | max_seq_length = len(encoded_inputs['input_ids'][0]) 52 | assert max_seq_length == len(encoded_inputs['input_ids'][-1]) 53 | print(max_seq_length) 54 | questions = [] 55 | programs = [] 56 | choices = [] 57 | answers = [] 58 | for item in tqdm(dataset): 59 | question = item['question'] 60 | questions.append(question) 61 | _ = [vocab['answer_token_to_idx'][w] for w in item['choices']] 62 | choices.append(_) 63 | if not test: 64 | program = item['program'] 65 | program = get_program_seq(program) 66 | programs.append(program) 67 | answers.append(vocab['answer_token_to_idx'].get(item['answer'])) 68 | 69 | input_ids = tokenizer.batch_encode_plus(questions, max_length = max_seq_length, pad_to_max_length = True, truncation = True) 70 | source_ids = np.array(input_ids['input_ids'], dtype = np.int32) 71 | source_mask = np.array(input_ids['attention_mask'], dtype = np.int32) 72 | if not test: 73 | target_ids = tokenizer.batch_encode_plus(programs, max_length = max_seq_length, pad_to_max_length = True, truncation = True) 74 | target_ids = np.array(target_ids['input_ids'], dtype = np.int32) 75 | else: 76 | target_ids = np.array([], dtype = np.int32) 77 | choices = np.array(choices, dtype = np.int32) 78 | answers = np.array(answers, dtype = np.int32) 79 | return source_ids, source_mask, target_ids, choices, answers 80 | 81 | 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--input_dir', required=True) 86 | parser.add_argument('--output_dir', required=True) 87 | parser.add_argument('--model_name_or_path', required=True) 88 | args = parser.parse_args() 89 | 90 | print('Build kb vocabulary') 91 | vocab = { 92 | 'answer_token_to_idx': {} 93 | } 94 | print('Load questions') 95 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json'))) 96 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json'))) 97 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json'))) 98 | for question in chain(train_set, val_set, test_set): 99 | for a in question['choices']: 100 | if not a in vocab['answer_token_to_idx']: 101 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 102 | 103 | if not os.path.isdir(args.output_dir): 104 | os.mkdir(args.output_dir) 105 | fn = os.path.join(args.output_dir, 'vocab.json') 106 | print('Dump vocab to {}'.format(fn)) 107 | with open(fn, 'w') as f: 108 | json.dump(vocab, f, indent=2) 109 | for k in vocab: 110 | print('{}:{}'.format(k, len(vocab[k]))) 111 | tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path) 112 | for token in new_tokens: 113 | # NOTE: in some newer versions of transformers, `special_tokens` needs to be set to False 114 | tokenizer.add_tokens(token, special_tokens = True)
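# NOTE: whenever tokens are added to the tokenizer, the model's embedding matrix
# must be resized to match the new vocabulary size; Bart_Program/train.py does
# this via model.resize_token_embeddings(len(tokenizer)) after calling add_tokens.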
115 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)): 116 | print('Encode {} set'.format(name)) 117 | outputs = encode_dataset(dataset, vocab, tokenizer, name=='test') 118 | assert len(outputs) == 5 119 | print('shape of input_ids of questions, attention_mask of questions, input_ids of programs, choices and answers:') 120 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f: 121 | for o in outputs: 122 | print(o.shape) 123 | pickle.dump(o, f) 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /Bart_Program/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import json 7 | from tqdm import tqdm 8 | from datetime import date 9 | from utils.misc import MetricLogger, seed_everything, ProgressBar 10 | from .data import DataLoader 11 | from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer 12 | import torch.optim as optim 13 | import logging 14 | import time 15 | from utils.lr_scheduler import get_linear_schedule_with_warmup 16 | from Bart_Program.predict import validate 17 | from kopl.kopl import KoPLEngine 18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 19 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 20 | rootLogger = logging.getLogger() 21 | import warnings 22 | warnings.simplefilter("ignore") # hide warnings caused by invalid sparql queries 23 | 24 | new_tokens = ['<func>', '<arg>'] 25 | 26 | def train(args): 27 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 28 | 29 | logging.info("Create train_loader and val_loader.........") 30 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 31 | train_pt = os.path.join(args.input_dir, 'train.pt') 32 | val_pt = os.path.join(args.input_dir, 'val.pt') 33 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True) 34 | val_loader = DataLoader(vocab_json, val_pt, 64) 35 | 36 | engine = KoPLEngine(json.load(open(os.path.join(args.input_dir, 'kb.json')))) 37 | logging.info("Create model.........") 38 | config_class, model_class, tokenizer_class = (BartConfig, BartForConditionalGeneration, BartTokenizer) 39 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) 40 | model = model_class.from_pretrained(args.model_name_or_path) 41 | added_tokens_num = tokenizer.add_tokens(new_tokens, special_tokens = True) 42 | print('added_tokens_num:', added_tokens_num) 43 | if added_tokens_num > 0: 44 | model.resize_token_embeddings(len(tokenizer)) 45 | 46 | model = model.to(device) 47 | logging.info(model) 48 | t_total = len(train_loader) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay) 49 | no_decay = ["bias", "LayerNorm.weight"] 50 | bart_param_optimizer = list(model.named_parameters()) 51 | optimizer_grouped_parameters = [ 52 | {'params': [p for n, p in bart_param_optimizer if not any(nd in n for nd in no_decay)], 53 | 'weight_decay': args.weight_decay, 'lr': args.learning_rate}, 54 | {'params': [p for n, p in bart_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 55 | 'lr': args.learning_rate} 56 | ] 57 | args.warmup_steps = int(t_total * args.warmup_proportion) 58 | optimizer =
optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 59 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, 60 | num_training_steps=t_total) 61 | # Check if saved optimizer or scheduler states exist 62 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 63 | os.path.join(args.model_name_or_path, "scheduler.pt")): 64 | # Load in optimizer and scheduler states 65 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 66 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 67 | 68 | # Train! 69 | logging.info("***** Running training *****") 70 | logging.info(" Num examples = %d", len(train_loader.dataset)) 71 | logging.info(" Num Epochs = %d", args.num_train_epochs) 72 | logging.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 73 | logging.info(" Total optimization steps = %d", t_total) 74 | 75 | global_step = 0 76 | steps_trained_in_current_epoch = 0 77 | # Check if continuing training from a checkpoint 78 | if os.path.exists(args.model_name_or_path) and "checkpoint" in args.model_name_or_path: 79 | # set global_step to gobal_step of last saved checkpoint from model path 80 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0]) 81 | epochs_trained = global_step // (len(train_loader) // args.gradient_accumulation_steps) 82 | steps_trained_in_current_epoch = global_step % (len(train_loader) // args.gradient_accumulation_steps) 83 | logging.info(" Continuing training from checkpoint, will skip to saved global_step") 84 | logging.info(" Continuing training from epoch %d", epochs_trained) 85 | logging.info(" Continuing training from global step %d", global_step) 86 | logging.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 87 | logging.info('Checking...') 88 | logging.info("===================Dev==================") 89 | validate(model, val_loader, device, tokenizer, engine) 90 | tr_loss, logging_loss = 0.0, 0.0 91 | model.zero_grad() 92 | for _ in range(int(args.num_train_epochs)): 93 | pbar = ProgressBar(n_total=len(train_loader), desc='Training') 94 | for step, batch in enumerate(train_loader): 95 | # Skip past any already trained steps if resuming training 96 | if steps_trained_in_current_epoch > 0: 97 | steps_trained_in_current_epoch -= 1 98 | continue 99 | model.train() 100 | batch = tuple(t.to(device) for t in batch) 101 | pad_token_id = tokenizer.pad_token_id 102 | source_ids, source_mask, y = batch[0], batch[1], batch[-2] 103 | y_ids = y[:, :-1].contiguous() 104 | lm_labels = y[:, 1:].clone() 105 | lm_labels[y[:, 1:] == pad_token_id] = -100 106 | 107 | inputs = { 108 | "input_ids": source_ids.to(device), 109 | "attention_mask": source_mask.to(device), 110 | "decoder_input_ids": y_ids.to(device), 111 | "labels": lm_labels.to(device), 112 | } 113 | outputs = model(**inputs) 114 | loss = outputs[0] 115 | loss.backward() 116 | pbar(step, {'loss': loss.item()}) 117 | tr_loss += loss.item() 118 | if (step + 1) % args.gradient_accumulation_steps == 0: 119 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 120 | optimizer.step() 121 | scheduler.step() # Update learning rate schedule 122 | model.zero_grad() 123 | global_step += 1 124 | validate(model, val_loader, device, tokenizer, engine) 125 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 126 | if not 
os.path.exists(output_dir): 127 | os.makedirs(output_dir) 128 | model_to_save = ( 129 | model.module if hasattr(model, "module") else model 130 | ) # Take care of distributed/parallel training 131 | model_to_save.save_pretrained(output_dir) 132 | tokenizer.save_pretrained(output_dir) 133 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 134 | logging.info("Saving model checkpoint to %s", output_dir) 135 | # tokenizer.save_vocabulary(output_dir) 136 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 137 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 138 | logging.info("Saving optimizer and scheduler states to %s", output_dir) 139 | logging.info("\n") 140 | if 'cuda' in str(device): 141 | torch.cuda.empty_cache() 142 | return global_step, tr_loss / global_step 143 | 144 | 145 | def main(): 146 | parser = argparse.ArgumentParser() 147 | # input and output 148 | parser.add_argument('--input_dir', required=True) 149 | parser.add_argument('--output_dir', required=True) 150 | 151 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 152 | parser.add_argument('--model_name_or_path', required = True) 153 | parser.add_argument('--ckpt') 154 | 155 | # training parameters 156 | parser.add_argument('--weight_decay', default=1e-5, type=float) 157 | parser.add_argument('--batch_size', default=16, type=int) 158 | parser.add_argument('--seed', type=int, default=666, help='random seed') 159 | parser.add_argument('--learning_rate', default=3e-5, type = float) 160 | parser.add_argument('--num_train_epochs', default=25, type = int) 161 | parser.add_argument('--save_steps', default=448, type = int) 162 | parser.add_argument('--logging_steps', default=448, type = int) 163 | parser.add_argument('--warmup_proportion', default=0.1, type = float, 164 | help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training.") 165 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 166 | help="Epsilon for Adam optimizer.") 167 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, 168 | help="Number of updates steps to accumulate before performing a backward/update pass.", ) 169 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 170 | help="Max gradient norm.") 171 | 172 | # validating parameters 173 | # parser.add_argument('--num_return_sequences', default=1, type=int) 174 | # parser.add_argument('--top_p', default=) 175 | # model hyperparameters 176 | parser.add_argument('--dim_hidden', default=1024, type=int) 177 | parser.add_argument('--alpha', default = 1e-4, type = float) 178 | args = parser.parse_args() 179 | 180 | if not os.path.exists(args.save_dir): 181 | os.makedirs(args.save_dir) 182 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 183 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, '{}.log'.format(time_))) 184 | fileHandler.setFormatter(logFormatter) 185 | rootLogger.addHandler(fileHandler) 186 | # args display 187 | for k, v in vars(args).items(): 188 | logging.info(k+':'+str(v)) 189 | 190 | seed_everything(666) 191 | 192 | train(args) 193 | 194 | 195 | if __name__ == '__main__': 196 | main() 197 | 198 | -------------------------------------------------------------------------------- /Bart_SPARQL/README.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | - python3.7 3 | - rdflib=4.2.2 or 6.1.1 4 | - transformers 5 | --- 6 | 
**Note for rdflib 4.2.2:** 7 | After installing rdflib via `pip`, `anaconda`, or some other tool, we need to fix some of its bugs. 8 | 9 | First, find your rdflib location. One possible way is to run the following code in ipython 10 | ``` 11 | import rdflib 12 | rdflib.__file__ 13 | ``` 14 | It returns `~/anaconda3/lib/python3.7/site-packages/rdflib/__init__.py` on my computer, so I enter the folder `~/anaconda3/lib/python3.7/site-packages/rdflib`. 15 | 16 | Then open `plugins/sparql/parser.py`, find *Line 68*, and replace its code with 17 | ``` 18 | if i + 1 < l and (not isinstance(terms[i + 1], str) or terms[i + 1] not in ".,;"): 19 | ``` 20 | Remember to keep the original indentation. 21 | Note that *Line 67* is the comment `# is this bnode the subject of more triplets?`. If your line number is different from mine, you can locate the target line by this comment. 22 | 23 | Finally, open `plugins/serializers/turtle.py`, find *Line 328*, and change `use_plain=True` to `use_plain=False` 24 | 25 | 26 | **Note for rdflib 6.1.1:** 27 | If you get an error "can't set attribute" with rdflib=4.2.2, you should try rdflib=6.1.1. 28 | 29 | --- 30 | 31 | - SPARQLWrapper=1.8.4 32 | 33 | --- 34 | **Note:** 35 | When installing `SPARQLWrapper` with `pip`, it may automatically install another package, `keepalive`. You can check whether it is in your environment by 36 | ``` 37 | pip show keepalive 38 | ``` 39 | 40 | If it is installed, it will cause problems when we execute a large number of SPARQL queries. Specifically, the available ports will be exhausted. So we need to manually disable the `keepalive` package. It is okay to directly remove it. 41 | ``` 42 | pip uninstall keepalive 43 | ``` 44 | 45 | --- 46 | 47 | - Virtuoso backend, refer to the next section 48 | 49 | ## How to install virtuoso backend 50 | The virtuoso backend starts up a web service; we can import our kb into it and then execute SPARQL queries via network requests. We installed virtuoso on an Ubuntu 16.04 system. The specific steps follow. 51 | 52 | 1. Download and install virtuoso into our system. 53 | ``` 54 | git clone https://github.com/openlink/virtuoso-opensource.git Virtuoso-Opensource 55 | cd Virtuoso-Opensource 56 | git checkout stable/7 57 | sudo apt-get install libtool gawk gperf autoconf automake libtool flex bison m4 make openssl libssl-dev 58 | sudo ./autogen.sh 59 | sudo ./configure 60 | sudo make 61 | sudo make install 62 | ``` 63 | 64 | 2. Create a new user for the virtuoso service 65 | ``` 66 | sudo useradd virtuoso --home /usr/local/virtuoso-opensource 67 | sudo chown -R virtuoso /usr/local/virtuoso-opensource 68 | ``` 69 | 70 | 3. Modify some necessary configs: 71 | ``` 72 | cd /usr/local/virtuoso-opensource/var/lib/virtuoso/db 73 | sudo vim virtuoso.ini 74 | ``` 75 | Find the item `CheckpointInterval`, and change its value from the default 60 to 0, to avoid the automatic checkpoint process, which causes 404 errors. 76 | 77 | 4. Start up the virtuoso service: 78 | ``` 79 | sudo -H -u virtuoso ../../../../bin/virtuoso-t -f & 80 | ``` 81 | Now you can access the service via the default port 8890. 82 | Enter `[ip]:8890` in a browser, and you will see the virtuoso service page. 83 | 84 | [note] Virtuoso may report the error "There is no configuration file virtuoso.ini" at startup. 85 | ``` 86 | sudo vim /etc/rc.conf 87 | ``` 88 | Add a line: `virtuoso_config="/usr/local/virtuoso-opensource/var/lib/virtuoso/db/virtuoso.ini"` 89 | 90 | 91 | 5. Now we can import our kb into virtuoso. Before that, we need to convert our kb to `ttl` format and move it to the proper position: 92 | ``` 93 | python -m Bart_SPARQL.sparql_engine --kb_path dataset/kb.json --ttl_path dataset/kb.ttl 94 | sudo chmod 777 dataset/kb.ttl 95 | sudo mv dataset/kb.ttl /usr/local/virtuoso-opensource/share/virtuoso/vad 96 | ``` 97 | 98 | 6. Enter the interactive terminal of virtuoso: 99 | ``` 100 | cd /usr/local/virtuoso-opensource/bin 101 | sudo ./isql 102 | ``` 103 | 104 | 7. Import our kb by executing these commands in the terminal: 105 | ``` 106 | SPARQL CREATE GRAPH <[graph_name]>; 107 | SPARQL CLEAR GRAPH <[graph_name]>; 108 | delete from db.dba.load_list; 109 | ld_dir('/usr/local/virtuoso-opensource/share/virtuoso/vad', 'kb.ttl', '[graph_name]'); 110 | rdf_loader_run(); 111 | select * from DB.DBA.load_list; 112 | exit; 113 | ``` 114 | `[graph_name]` could be any legal string, such as *KQAPro*. 115 | The import has succeeded if `rdf_loader_run()` takes about 10 seconds. 116 | 117 | 118 | ## How to run 119 | 1. Follow the last section: start up the virtuoso service and import `kb.ttl`. Then you need to open `sparql_engine.py` and find the lines 120 | ``` 121 | virtuoso_address = "http://127.0.0.1:8890/sparql" 122 | virtuoso_graph_uri = 'sjx' 123 | ``` 124 | Change `virtuoso_address` to your service url (you can visit it in your browser to check whether it is valid) and change `virtuoso_graph_uri` to your `[graph_name]`.
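   To sanity-check the connection before training, you can run a tiny query against the endpoint. A minimal sketch (assuming the default address above and a graph named `KQAPro`; this snippet is not part of the original scripts):
   ```
   from SPARQLWrapper import SPARQLWrapper, JSON

   sparql = SPARQLWrapper("http://127.0.0.1:8890/sparql")
   sparql.setQuery("SELECT * FROM <KQAPro> WHERE { ?s ?p ?o } LIMIT 5")
   sparql.setReturnFormat(JSON)
   print(sparql.query().convert())  # should print a few triples if the import worked
   ```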
125 | 2. Preprocess the training data 126 | ``` 127 | python -m Bart_SPARQL.preprocess --input_dir ./dataset --output_dir <output_dir> --model_name_or_path <model_name_or_path> 128 | cp ./dataset/kb.json <output_dir> 129 | ``` 130 | 3. Train 131 | ``` 132 | python -m Bart_SPARQL.train --input_dir <output_dir> --output_dir <checkpoint_dir> --model_name_or_path <model_name_or_path> --save_dir <log_dir> 133 | ``` 134 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order. 135 | ``` 136 | python -m Bart_SPARQL.predict --input_dir <output_dir> --ckpt <checkpoint_dir> --save_dir <log_dir> 137 | 138 | ``` 139 | 140 | ## Checkpoints 141 | 1. The pretrained Bart-base checkpoint without finetuning can be downloaded here: [bart-base](https://cloud.tsinghua.edu.cn/f/3b59ec6c43034cfc8841/?dl=1) 142 | 2.
The checkpoint for finetuned Bart_SPARQL can be downloaded here [finetuned](https://cloud.tsinghua.edu.cn/f/1b9746dcd96b4fca870d/?dl=1) 143 | -------------------------------------------------------------------------------- /Bart_SPARQL/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | from utils.misc import invert_dict 5 | 6 | def load_vocab(path): 7 | vocab = json.load(open(path)) 8 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) 9 | return vocab 10 | 11 | def collate(batch): 12 | batch = list(zip(*batch)) 13 | source_ids = torch.stack(batch[0]) 14 | source_mask = torch.stack(batch[1]) 15 | choices = torch.stack(batch[2]) 16 | if batch[-1][0] is None: 17 | target_ids, answer = None, None 18 | else: 19 | target_ids = torch.stack(batch[3]) 20 | answer = torch.cat(batch[4]) 21 | return source_ids, source_mask, choices, target_ids, answer 22 | 23 | 24 | class Dataset(torch.utils.data.Dataset): 25 | def __init__(self, inputs): 26 | self.source_ids, self.source_mask, self.target_ids, self.choices, self.answers = inputs 27 | self.is_test = len(self.answers)==0 28 | 29 | 30 | def __getitem__(self, index): 31 | source_ids = torch.LongTensor(self.source_ids[index]) 32 | source_mask = torch.LongTensor(self.source_mask[index]) 33 | choices = torch.LongTensor(self.choices[index]) 34 | if self.is_test: 35 | target_ids = None 36 | answer = None 37 | else: 38 | target_ids = torch.LongTensor(self.target_ids[index]) 39 | answer = torch.LongTensor([self.answers[index]]) 40 | return source_ids, source_mask, choices, target_ids, answer 41 | 42 | 43 | def __len__(self): 44 | return len(self.source_ids) 45 | 46 | 47 | class DataLoader(torch.utils.data.DataLoader): 48 | def __init__(self, vocab_json, question_pt, batch_size, training=False): 49 | vocab = load_vocab(vocab_json) 50 | if training: 51 | print('#vocab of answer: %d' % (len(vocab['answer_token_to_idx']))) 52 | 53 | inputs = [] 54 | with open(question_pt, 'rb') as f: 55 | for _ in range(5): 56 | inputs.append(pickle.load(f)) 57 | dataset = Dataset(inputs) 58 | 59 | super().__init__( 60 | dataset, 61 | batch_size=batch_size, 62 | shuffle=training, 63 | collate_fn=collate, 64 | ) 65 | self.vocab = vocab -------------------------------------------------------------------------------- /Bart_SPARQL/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | from nltk import word_tokenize 7 | from collections import Counter 8 | from itertools import chain 9 | from tqdm import tqdm 10 | import re 11 | 12 | from utils.misc import init_vocab 13 | from transformers import * 14 | 15 | 16 | 17 | def encode_dataset(dataset, vocab, tokenizer, test = False): 18 | questions = [] 19 | sparqls = [] 20 | for item in tqdm(dataset): 21 | question = item['question'] 22 | questions.append(question) 23 | if not test: 24 | sparql = item['sparql'] 25 | sparqls.append(sparql) 26 | sequences = questions + sparqls 27 | encoded_inputs = tokenizer(sequences, padding = True) 28 | print(encoded_inputs.keys()) 29 | print(encoded_inputs['input_ids'][0]) 30 | print(tokenizer.decode(encoded_inputs['input_ids'][0])) 31 | print(tokenizer.decode(encoded_inputs['input_ids'][-1])) 32 | max_seq_length = len(encoded_inputs['input_ids'][0]) 33 | assert max_seq_length == len(encoded_inputs['input_ids'][-1]) 34 | print(max_seq_length) 35 | questions 
= [] 36 | sparqls = [] 37 | choices = [] 38 | answers = [] 39 | for item in tqdm(dataset): 40 | question = item['question'] 41 | questions.append(question) 42 | _ = [vocab['answer_token_to_idx'][w] for w in item['choices']] 43 | choices.append(_) 44 | if not test: 45 | sparql = item['sparql'] 46 | sparqls.append(sparql) 47 | answers.append(vocab['answer_token_to_idx'].get(item['answer'])) 48 | 49 | input_ids = tokenizer.batch_encode_plus(questions, max_length = max_seq_length, pad_to_max_length = True, truncation = True) 50 | source_ids = np.array(input_ids['input_ids'], dtype = np.int32) 51 | source_mask = np.array(input_ids['attention_mask'], dtype = np.int32) 52 | if not test: 53 | target_ids = tokenizer.batch_encode_plus(sparqls, max_length = max_seq_length, pad_to_max_length = True, truncation = True) 54 | target_ids = np.array(target_ids['input_ids'], dtype = np.int32) 55 | else: 56 | target_ids = np.array([], dtype = np.int32) 57 | choices = np.array(choices, dtype = np.int32) 58 | answers = np.array(answers, dtype = np.int32) 59 | return source_ids, source_mask, target_ids, choices, answers 60 | 61 | 62 | 63 | def main(): 64 | parser = argparse.ArgumentParser() 65 | parser.add_argument('--input_dir', required=True) 66 | parser.add_argument('--output_dir', required=True) 67 | parser.add_argument('--model_name_or_path', required=True) 68 | args = parser.parse_args() 69 | 70 | print('Build kb vocabulary') 71 | vocab = { 72 | 'answer_token_to_idx': {} 73 | } 74 | print('Load questions') 75 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json'))) 76 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json'))) 77 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json'))) 78 | for question in chain(train_set, val_set, test_set): 79 | for a in question['choices']: 80 | if not a in vocab['answer_token_to_idx']: 81 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 82 | 83 | if not os.path.isdir(args.output_dir): 84 | os.mkdir(args.output_dir) 85 | fn = os.path.join(args.output_dir, 'vocab.json') 86 | print('Dump vocab to {}'.format(fn)) 87 | with open(fn, 'w') as f: 88 | json.dump(vocab, f, indent=2) 89 | for k in vocab: 90 | print('{}:{}'.format(k, len(vocab[k]))) 91 | tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path) 92 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)): 93 | print('Encode {} set'.format(name)) 94 | outputs = encode_dataset(dataset, vocab, tokenizer, name=='test') 95 | assert len(outputs) == 5 96 | print('shape of input_ids of questions, attention_mask of questions, input_ids of sparqls, choices and answers:') 97 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f: 98 | for o in outputs: 99 | print(o.shape) 100 | pickle.dump(o, f) 101 | if __name__ == '__main__': 102 | main() -------------------------------------------------------------------------------- /Bart_SPARQL/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | import json 8 | from tqdm import tqdm 9 | from datetime import date 10 | from utils.misc import MetricLogger, seed_everything, ProgressBar 11 | from utils.load_kb import DataForSPARQL 12 | from .data import DataLoader 13 | from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer 14 | from .sparql_engine import get_sparql_answer 15 | 
import torch.optim as optim 16 | import logging 17 | import time 18 | from utils.lr_scheduler import get_linear_schedule_with_warmup 19 | 20 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 21 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 22 | rootLogger = logging.getLogger() 23 | import warnings 24 | warnings.simplefilter("ignore") # hide warnings that caused by invalid sparql query 25 | 26 | 27 | 28 | 29 | def train(args): 30 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 31 | 32 | logging.info("Create train_loader and val_loader.........") 33 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 34 | train_pt = os.path.join(args.input_dir, 'train.pt') 35 | val_pt = os.path.join(args.input_dir, 'val.pt') 36 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True) 37 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size) 38 | 39 | vocab = train_loader.vocab 40 | 41 | logging.info("Create model.........") 42 | config_class, model_class, tokenizer_class = (BartConfig, BartForConditionalGeneration, BartTokenizer) 43 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) 44 | model = model_class.from_pretrained(args.model_name_or_path) 45 | model = model.to(device) 46 | logging.info(model) 47 | t_total = len(train_loader) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay) 48 | no_decay = ["bias", "LayerNorm.weight"] 49 | bart_param_optimizer = list(model.named_parameters()) 50 | optimizer_grouped_parameters = [ 51 | {'params': [p for n, p in bart_param_optimizer if not any(nd in n for nd in no_decay)], 52 | 'weight_decay': args.weight_decay, 'lr': args.learning_rate}, 53 | {'params': [p for n, p in bart_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 54 | 'lr': args.learning_rate} 55 | ] 56 | args.warmup_steps = int(t_total * args.warmup_proportion) 57 | optimizer = optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 58 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, 59 | num_training_steps=t_total) 60 | # Check if saved optimizer or scheduler states exist 61 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile( 62 | os.path.join(args.model_name_or_path, "scheduler.pt")): 63 | # Load in optimizer and scheduler states 64 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"))) 65 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt"))) 66 | 67 | # Train! 
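# The loop below trains with teacher forcing: for each target sequence y,
# decoder_input_ids = y[:, :-1] and labels = y[:, 1:], i.e. the decoder predicts
# each token from the previous gold tokens; label positions equal to
# pad_token_id are set to -100 so the cross-entropy loss ignores padding.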
68 | logging.info("***** Running training *****") 69 | logging.info(" Num examples = %d", len(train_loader.dataset)) 70 | logging.info(" Num Epochs = %d", args.num_train_epochs) 71 | logging.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 72 | logging.info(" Total optimization steps = %d", t_total) 73 | 74 | global_step = 0 75 | steps_trained_in_current_epoch = 0 76 | # Check if continuing training from a checkpoint 77 | if os.path.exists(args.model_name_or_path) and "checkpoint" in args.model_name_or_path: 78 | # set global_step to gobal_step of last saved checkpoint from model path 79 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0]) 80 | epochs_trained = global_step // (len(train_loader) // args.gradient_accumulation_steps) 81 | steps_trained_in_current_epoch = global_step % (len(train_loader) // args.gradient_accumulation_steps) 82 | logging.info(" Continuing training from checkpoint, will skip to saved global_step") 83 | logging.info(" Continuing training from epoch %d", epochs_trained) 84 | logging.info(" Continuing training from global step %d", global_step) 85 | logging.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) 86 | logging.info('Checking...') 87 | logging.info("===================Dev==================") 88 | # evaluate(args, model, val_loader, device) 89 | tr_loss, logging_loss = 0.0, 0.0 90 | model.zero_grad() 91 | for _ in range(int(args.num_train_epochs)): 92 | pbar = ProgressBar(n_total=len(train_loader), desc='Training') 93 | for step, batch in enumerate(train_loader): 94 | # Skip past any already trained steps if resuming training 95 | if steps_trained_in_current_epoch > 0: 96 | steps_trained_in_current_epoch -= 1 97 | continue 98 | model.train() 99 | batch = tuple(t.to(device) for t in batch) 100 | pad_token_id = tokenizer.pad_token_id 101 | source_ids, source_mask, y = batch[0], batch[1], batch[-2] 102 | y_ids = y[:, :-1].contiguous() 103 | lm_labels = y[:, 1:].clone() 104 | lm_labels[y[:, 1:] == pad_token_id] = -100 105 | 106 | inputs = { 107 | "input_ids": source_ids.to(device), 108 | "attention_mask": source_mask.to(device), 109 | "decoder_input_ids": y_ids.to(device), 110 | "labels": lm_labels.to(device) 111 | } 112 | outputs = model(**inputs) 113 | loss = outputs[0] 114 | loss.backward() 115 | pbar(step, {'loss': loss.item()}) 116 | tr_loss += loss.item() 117 | if (step + 1) % args.gradient_accumulation_steps == 0: 118 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 119 | optimizer.step() 120 | scheduler.step() # Update learning rate schedule 121 | model.zero_grad() 122 | global_step += 1 123 | # if args.logging_steps > 0 and global_step % args.logging_steps == 0: 124 | # logging.info("===================Dev==================") 125 | # evaluate(args, model, val_loader, device) 126 | # logging.info("===================Test==================") 127 | # evaluate(args, model, test_loader, device) 128 | if args.save_steps > 0 and global_step % args.save_steps == 0: 129 | # Save model checkpoint 130 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 131 | if not os.path.exists(output_dir): 132 | os.makedirs(output_dir) 133 | model_to_save = ( 134 | model.module if hasattr(model, "module") else model 135 | ) # Take care of distributed/parallel training 136 | model_to_save.save_pretrained(output_dir) 137 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 138 | logging.info("Saving model checkpoint to %s", 
output_dir) 139 | tokenizer.save_vocabulary(output_dir) 140 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) 141 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) 142 | logging.info("Saving optimizer and scheduler states to %s", output_dir) 143 | logging.info("\n") 144 | if 'cuda' in str(device): 145 | torch.cuda.empty_cache() 146 | return global_step, tr_loss / global_step 147 | 148 | 149 | def main(): 150 | parser = argparse.ArgumentParser() 151 | # input and output 152 | parser.add_argument('--input_dir', required=True) 153 | parser.add_argument('--output_dir', required=True) 154 | 155 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 156 | parser.add_argument('--model_name_or_path', required = True, help = 'pretrained language models') 157 | parser.add_argument('--ckpt') 158 | 159 | # training parameters 160 | parser.add_argument('--weight_decay', default=1e-5, type=float) 161 | parser.add_argument('--batch_size', default=8, type=int) 162 | parser.add_argument('--seed', type=int, default=666, help='random seed') 163 | parser.add_argument('--learning_rate', default=3e-5, type = float) 164 | parser.add_argument('--num_train_epochs', default=25, type = int) 165 | parser.add_argument('--save_steps', default=448, type = int) 166 | parser.add_argument('--logging_steps', default=448, type = int) 167 | parser.add_argument('--warmup_proportion', default=0.1, type = float, 168 | help="Proportion of training to perform linear learning rate warmup for, e.g., 0.1 = 10% of training.") 169 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 170 | help="Epsilon for Adam optimizer.") 171 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, 172 | help="Number of update steps to accumulate before performing a backward/update pass.", ) 173 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 174 | help="Max gradient norm.") 175 | 176 | # validating parameters 177 | # parser.add_argument('--num_return_sequences', default=1, type=int) 178 | # parser.add_argument('--top_p', default=) 179 | # model hyperparameters 180 | parser.add_argument('--dim_hidden', default=1024, type=int) 181 | parser.add_argument('--alpha', default = 1e-4, type = float) 182 | args = parser.parse_args() 183 | 184 | if not os.path.exists(args.save_dir): 185 | os.makedirs(args.save_dir) 186 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 187 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, '{}.log'.format(time_))) 188 | fileHandler.setFormatter(logFormatter) 189 | rootLogger.addHandler(fileHandler) 190 | # args display 191 | for k, v in vars(args).items(): 192 | logging.info(k+':'+str(v)) 193 | 194 | seed_everything(666) 195 | 196 | train(args) 197 | 198 | 199 | if __name__ == '__main__': 200 | main() 201 | 202 | -------------------------------------------------------------------------------- /BlindGRU/README.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | - python3 3 | - pytorch>=1.2.0 4 | - nltk 5 | 6 | ## How to run 7 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading: 8 | ``` 9 | python -m utils.pickle_glove --input <path/to/glove.840B.300d.txt> --output <glove_pt> 10 | ``` 11 | This step can be skipped if you have obtained the glove pickle file in other models.
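   For reference, the pickled vectors are consumed in `BlindGRU/train.py` roughly like this (the two lines below are taken from the training script further down), so `<glove_pt>` just needs to be readable by `utils.misc.load_glove`:
   ```
   pretrained = load_glove(args.glove_pt, vocab['word_idx_to_token'])
   model.word_embeddings.weight.set_(torch.Tensor(pretrained))
   ```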
12 | 2. Preprocess the training data 13 | ``` 14 | python -m BlindGRU.preprocess --input_dir ./dataset --output_dir <output_dir> 15 | ``` 16 | 3. Train 17 | ``` 18 | python -m BlindGRU.train --input_dir <output_dir> --save_dir <save_dir> --glove_pt <glove_pt> 19 | ``` 20 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order. 21 | ``` 22 | python -m BlindGRU.predict --input_dir <output_dir> --save_dir <save_dir> 23 | ``` 24 | -------------------------------------------------------------------------------- /BlindGRU/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | from utils.misc import invert_dict 5 | 6 | 7 | def load_vocab(path): 8 | vocab = json.load(open(path)) 9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx']) 10 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) 11 | return vocab 12 | 13 | def collate(batch): 14 | batch = list(zip(*batch)) 15 | question = torch.stack(batch[0]) 16 | choices = torch.stack(batch[1]) 17 | if batch[-1][0] is None: 18 | answer = None 19 | else: 20 | answer = torch.cat(batch[2]) 21 | return question, choices, answer 22 | 23 | 24 | class Dataset(torch.utils.data.Dataset): 25 | def __init__(self, inputs): 26 | self.questions, self.choices, self.answers = inputs 27 | self.is_test = len(self.answers)==0 28 | 29 | 30 | def __getitem__(self, index): 31 | question = torch.LongTensor(self.questions[index]) 32 | choices = torch.LongTensor(self.choices[index]) 33 | if self.is_test: 34 | answer = None 35 | else: 36 | answer = torch.LongTensor([self.answers[index]]) 37 | return question, choices, answer 38 | 39 | 40 | def __len__(self): 41 | return len(self.questions) 42 | 43 | 44 | class DataLoader(torch.utils.data.DataLoader): 45 | def __init__(self, vocab_json, question_pt, batch_size, training=False): 46 | vocab = load_vocab(vocab_json) 47 | if training: 48 | print('#vocab of word/answer: %d/%d' % 49 | (len(vocab['word_token_to_idx']), len(vocab['answer_token_to_idx']))) 50 | 51 | inputs = [] 52 | with open(question_pt, 'rb') as f: 53 | for _ in range(3): 54 | inputs.append(pickle.load(f)) 55 | dataset = Dataset(inputs) 56 | 57 | super().__init__( 58 | dataset, 59 | batch_size=batch_size, 60 | shuffle=training, 61 | collate_fn=collate, 62 | ) 63 | self.vocab = vocab 64 | 65 | -------------------------------------------------------------------------------- /BlindGRU/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils.BiGRU import BiGRU 5 | 6 | class GRUClassifier(nn.Module): 7 | def __init__(self, vocab, dim_word, dim_hidden): 8 | super().__init__() 9 | 10 | num_class = len(vocab['answer_token_to_idx']) 11 | num_words = len(vocab['word_token_to_idx']) 12 | 13 | self.word_embeddings = nn.Embedding(num_words, dim_word) 14 | self.word_dropout = nn.Dropout(0.3) 15 | self.question_encoder = BiGRU(dim_word, dim_hidden, num_layers=2, dropout=0.2) 16 | 17 | self.classifier = nn.Sequential( 18 | nn.Linear(dim_hidden, 1024), 19 | nn.ReLU(), 20 | nn.Linear(1024, num_class) 21 | ) 22 | 23 | for m in self.modules(): 24 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 25 | nn.init.kaiming_normal_(m.weight) 26 | if m.bias is not None: 27 | m.bias.data.zero_() 28 |
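# Sketch of how forward() below recovers sequence lengths from padding,
# assuming index 0 is the <PAD> token (as set up by init_vocab in BlindGRU/preprocess.py):
#   questions = [[5, 8, 2, 0, 0]]  -> questions.eq(0).sum(dim=1) = 2 pads
#   question_lens = max_len - num_pads = 5 - 2 = 3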
29 | def forward(self, questions): 30 | """ 31 | Args: 32 | - questions (LongTensor) [bsz, max_len] 33 | """ 34 | question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD> 35 | # print(question_lens) 36 | question_input = self.word_dropout(self.word_embeddings(questions)) 37 | _, question_embeddings, _ = self.question_encoder(question_input, question_lens) 38 | logits = self.classifier(question_embeddings) 39 | return logits 40 | -------------------------------------------------------------------------------- /BlindGRU/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | from tqdm import tqdm 8 | 9 | from .data import DataLoader 10 | from .model import GRUClassifier 11 | 12 | 13 | def predict(args): 14 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 15 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 16 | test_pt = os.path.join(args.input_dir, 'test.pt') 17 | test_loader = DataLoader(vocab_json, test_pt, 128) 18 | vocab = test_loader.vocab 19 | 20 | model = GRUClassifier(vocab, args.dim_word, args.dim_hidden) 21 | model.load_state_dict(torch.load(os.path.join(args.save_dir, 'model.pt'))) 22 | model = model.to(device) 23 | model.eval() 24 | 25 | def write(f, predict): 26 | predict = predict.squeeze().tolist() 27 | for i in predict: 28 | f.write(vocab['answer_idx_to_token'][i] + '\n') 29 | 30 | f1 = open(os.path.join(args.save_dir, 'predict.txt'), 'w') 31 | f2 = open(os.path.join(args.save_dir, 'choice_predict.txt'), 'w') 32 | with torch.no_grad(): 33 | for batch in tqdm(test_loader, total=len(test_loader)): 34 | question, choices = [x.to(device) for x in batch[:2]] 35 | logit = model(question) 36 | predict = logit.max(1)[1] 37 | write(f1, predict) 38 | choiced_logit = torch.gather(logit, 1, choices) # [bsz, num_choices] 39 | choiced_predict = torch.gather(choices, 1, choiced_logit.max(1)[1].unsqueeze(-1)) # [bsz, 1] 40 | write(f2, choiced_predict) 41 | f1.close() 42 | f2.close() 43 | 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser() 48 | # input and output 49 | parser.add_argument('--input_dir', required=True) 50 | parser.add_argument('--save_dir', required=True, help='folder of checkpoint') 51 | 52 | # model hyperparameters 53 | parser.add_argument('--dim_word', default=300, type=int) 54 | parser.add_argument('--dim_hidden', default=1024, type=int) 55 | args = parser.parse_args() 56 | 57 | predict(args) 58 | 59 | 60 | if __name__ == '__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /BlindGRU/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | from nltk import word_tokenize 7 | from collections import Counter 8 | from itertools import chain 9 | from tqdm import tqdm 10 | 11 | from utils.misc import init_vocab 12 | 13 | 14 | def encode_dataset(dataset, vocab, test=False): 15 | questions = [] 16 | choices = [] 17 | answers = [] 18 | for question in tqdm(dataset): 19 | q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>']) 20 | for w in word_tokenize(question['question'].lower())] 21 | questions.append(q) 22 | 23 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']] 24 | choices.append(_) 25 | 26 | if test: 27 | continue 28 | 29 | if 'answer' in question: 30 | answers.append(vocab['answer_token_to_idx'].get(question['answer'])) 31 | 32 | # question padding 33 | max_len = max(len(q) for q in questions)
34 | for q in questions: 35 | while len(q) < max_len: 36 | q.append(vocab['word_token_to_idx']['<PAD>']) 37 | 38 | questions = np.asarray(questions, dtype=np.int32) 39 | choices = np.asarray(choices, dtype=np.int32) 40 | answers = np.asarray(answers, dtype=np.int32) 41 | return questions, choices, answers 42 | 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--input_dir', required=True) 48 | parser.add_argument('--output_dir', required=True) 49 | parser.add_argument('--min_cnt', type=int, default=1) 50 | args = parser.parse_args() 51 | 52 | 53 | 54 | vocab = { 55 | 'word_token_to_idx': init_vocab(), # question text only; BlindGRU ignores the KB 56 | 'answer_token_to_idx': {} 57 | } 58 | print('Load questions') 59 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json'))) 60 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json'))) 61 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json'))) 62 | print('Build question vocabulary') 63 | word_counter = Counter() 64 | for question in train_set: 65 | tokens = word_tokenize(question['question'].lower()) 66 | word_counter.update(tokens) 67 | # add candidate answers 68 | for a in question['choices']: 69 | if a not in vocab['answer_token_to_idx']: 70 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 71 | # filter low-frequency words 72 | for w, c in word_counter.items(): 73 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']: 74 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx']) 75 | # add candidate answers of val and test set 76 | for question in chain(val_set, test_set): 77 | for a in question['choices']: 78 | if a not in vocab['answer_token_to_idx']: 79 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 80 | 81 | 82 | if not os.path.isdir(args.output_dir): 83 | os.mkdir(args.output_dir) 84 | fn = os.path.join(args.output_dir, 'vocab.json') 85 | print('Dump vocab to {}'.format(fn)) 86 | with open(fn, 'w') as f: 87 | json.dump(vocab, f, indent=2) 88 | for k in vocab: 89 | print('{}:{}'.format(k, len(vocab[k]))) 90 | 91 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)): 92 | print('Encode {} set'.format(name)) 93 | outputs = encode_dataset(dataset, vocab, name=='test') 94 | print('shape of questions, choices, answers:') 95 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f: 96 | for o in outputs: 97 | print(o.shape) 98 | pickle.dump(o, f) 99 | 100 | 101 | 102 | 103 | 104 | if __name__ == '__main__': 105 | main() 106 | -------------------------------------------------------------------------------- /BlindGRU/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | from tqdm import tqdm 8 | 9 | from .data import DataLoader 10 | from .model import GRUClassifier 11 | from utils.misc import MetricLogger, load_glove 12 | 13 | import logging 14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 15 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 16 | rootLogger = logging.getLogger() 17 | 18 | 19 | def validate(model, data, device): 20 | model.eval() 21 | count, correct = 0, 0 22 | with torch.no_grad(): 23 | for batch in tqdm(data, total=len(data)): 24 | question, choices, answer = [x.to(device) for x in
batch] 25 | logit = model(question) 26 | predict = logit.max(1)[1] 27 | correct += torch.eq(predict, answer).long().sum().item() 28 | count += len(answer) 29 | 30 | acc = correct / count 31 | logging.info('\nValid Accuracy: %.4f\n' % acc) 32 | return acc 33 | 34 | 35 | def train(args): 36 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 37 | 38 | logging.info("Create train_loader and val_loader.........") 39 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 40 | train_pt = os.path.join(args.input_dir, 'train.pt') 41 | val_pt = os.path.join(args.input_dir, 'val.pt') 42 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True) 43 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size) 44 | vocab = train_loader.vocab 45 | 46 | logging.info("Create model.........") 47 | model = GRUClassifier(vocab, args.dim_word, args.dim_hidden) 48 | logging.info("Load pretrained word vectors.........") 49 | pretrained = load_glove(args.glove_pt, vocab['word_idx_to_token']) 50 | with torch.no_grad(): 51 | model.word_embeddings.weight.set_(torch.Tensor(pretrained)) 52 | model = model.to(device) 53 | logging.info(model) 54 | 55 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 56 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 50], gamma=0.1) 57 | criterion = nn.CrossEntropyLoss().to(device) 58 | 59 | validate(model, val_loader, device) 60 | meters = MetricLogger(delimiter=" ") 61 | best_acc = 0 62 | logging.info("Start training........") 63 | for epoch in range(args.num_epoch): 64 | model.train() 65 | for iteration, batch in enumerate(train_loader): 66 | iteration = iteration + 1 67 | 68 | question, choices, answer = [x.to(device) for x in batch] 69 | logits = model(question) 70 | loss = criterion(logits, answer) 71 | optimizer.zero_grad() 72 | loss.backward() 73 | optimizer.step() 74 | meters.update(loss=loss.item()) 75 | 76 | if iteration % (len(train_loader) // 100) == 0: 77 | logging.info( 78 | meters.delimiter.join( 79 | [ 80 | "progress: {progress:.3f}", 81 | "{meters}", 82 | "lr: {lr:.6f}", 83 | ] 84 | ).format( 85 | progress=epoch + iteration / len(train_loader), 86 | meters=str(meters), 87 | lr=optimizer.param_groups[0]["lr"], 88 | ) 89 | ) 90 | 91 | acc = validate(model, val_loader, device) 92 | scheduler.step() 93 | if acc and acc > best_acc: 94 | best_acc = acc 95 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc)) 96 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt')) 97 | 98 | 99 | def main(): 100 | parser = argparse.ArgumentParser() 101 | # input and output 102 | parser.add_argument('--input_dir', required=True) 103 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 104 | parser.add_argument('--glove_pt', required=True) 105 | 106 | # training parameters 107 | parser.add_argument('--lr', default=0.001, type=float) 108 | parser.add_argument('--weight_decay', default=1e-5, type=float) 109 | parser.add_argument('--num_epoch', default=100, type=int) 110 | parser.add_argument('--batch_size', default=128, type=int) 111 | parser.add_argument('--seed', type=int, default=666, help='random seed') 112 | 113 | # model hyperparameters 114 | parser.add_argument('--dim_word', default=300, type=int) 115 | parser.add_argument('--dim_hidden', default=1024, type=int) 116 | args = parser.parse_args() 117 | 118 | # make logging.info display into both shell and file 119 | if os.path.isdir(args.save_dir): 120 | 
shutil.rmtree(args.save_dir) 121 | os.mkdir(args.save_dir) 122 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt')) 123 | fileHandler.setFormatter(logFormatter) 124 | rootLogger.addHandler(fileHandler) 125 | # args display 126 | for k, v in vars(args).items(): 127 | logging.info(k+':'+str(v)) 128 | 129 | # set random seed 130 | torch.manual_seed(args.seed) 131 | 132 | train(args) 133 | 134 | 135 | if __name__ == '__main__': 136 | main() 137 | -------------------------------------------------------------------------------- /KVMemNN/README.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | - python3 3 | - pytorch>=1.2.0 4 | - nltk 5 | 6 | ## How to run 7 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading: 8 | ``` 9 | python -m utils.pickle_glove --input --output 10 | ``` 11 | This step can be skipped if you have obtained the glove pickle file in other models. 12 | 2. Preprocess the training data 13 | ``` 14 | python -m KVMemNN.preprocess --input_dir ./dataset --output_dir 15 | ``` 16 | 3. Train 17 | ``` 18 | python -m KVMemNN.train --input_dir --save_dir --glove_pt 19 | ``` 20 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order. 21 | ``` 22 | python -m KVMemNN.predict --input_dir --save_dir 23 | ``` 24 | -------------------------------------------------------------------------------- /KVMemNN/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | from utils.misc import invert_dict 5 | 6 | 7 | def load_vocab(path): 8 | vocab = json.load(open(path)) 9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx']) 10 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) 11 | return vocab 12 | 13 | def collate(batch): 14 | batch = list(zip(*batch)) 15 | question, choices, keys, values = list(map(torch.stack, batch[:4])) 16 | if batch[-1][0] is None: 17 | answer = None 18 | else: 19 | answer = torch.cat(batch[-1]) 20 | return question, choices, keys, values, answer 21 | 22 | 23 | class Dataset(torch.utils.data.Dataset): 24 | def __init__(self, all_keys, all_values, inputs): 25 | self.all_keys = all_keys 26 | self.all_values = all_values 27 | self.questions, self.key_indexes, self.choices, self.answers = inputs 28 | self.is_test = len(self.answers)==0 29 | 30 | 31 | def __getitem__(self, index): 32 | question = torch.LongTensor(self.questions[index]) 33 | key_index = self.key_indexes[index] 34 | keys = torch.LongTensor(self.all_keys[key_index]) 35 | values = torch.LongTensor(self.all_values[key_index]) 36 | choices = torch.LongTensor(self.choices[index]) 37 | if self.is_test: 38 | answer = None 39 | else: 40 | answer = torch.LongTensor([self.answers[index]]) 41 | return question, choices, keys, values, answer 42 | 43 | 44 | def __len__(self): 45 | return len(self.questions) 46 | 47 | 48 | class DataLoader(torch.utils.data.DataLoader): 49 | def __init__(self, vocab_json, kb_pt, question_pt, batch_size, training=False): 50 | vocab = load_vocab(vocab_json) 51 | 52 | inputs = [] 53 | with open(question_pt, 'rb') as f: 54 | for _ in range(4): 55 | inputs.append(pickle.load(f)) 56 | with open(kb_pt, 'rb') as f: 57 | all_keys = pickle.load(f) 58 | all_values = 
pickle.load(f) 59 | dataset = Dataset(all_keys, all_values, inputs) 60 | 61 | super().__init__( 62 | dataset, 63 | batch_size=batch_size, 64 | shuffle=training, 65 | collate_fn=collate, 66 | ) 67 | self.vocab = vocab 68 | 69 | -------------------------------------------------------------------------------- /KVMemNN/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils.BiGRU import BiGRU, GRU 5 | 6 | class KVMemNN(nn.Module): 7 | def __init__(self, num_hop, dim_emb, vocab): 8 | super().__init__() 9 | self.num_hop = num_hop 10 | num_vocab = len(vocab['word_token_to_idx']) 11 | num_class = len(vocab['answer_token_to_idx']) 12 | 13 | self.embeddings = nn.Embedding(num_vocab, dim_emb) 14 | self.question_encoder = BiGRU(dim_emb, dim_emb, num_layers=2, dropout=0.2) 15 | self.word_dropout = nn.Dropout(0.3) 16 | self.linears = [] 17 | for i in range(num_hop): 18 | lin = nn.Linear(dim_emb, dim_emb) 19 | self.linears.append(lin) 20 | self.add_module('linear_{}'.format(i), lin) 21 | 22 | self.classifier = nn.Sequential( 23 | nn.Linear(dim_emb, 1024), 24 | nn.ReLU(), 25 | nn.Linear(1024, num_class) 26 | ) 27 | for m in self.modules(): 28 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 29 | nn.init.kaiming_normal_(m.weight) 30 | if m.bias is not None: 31 | m.bias.data.zero_() 32 | 33 | def forward(self, questions, keys, values): 34 | """ 35 | Args: 36 | questions [bsz, max_q_len] 37 | keys [bsz, num_slot, max_k_len] 38 | values [bsz, num_slot, max_v_len] 39 | """ 40 | question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD> 41 | q_word_emb = self.word_dropout(self.embeddings(questions)) 42 | q, q_embeddings, q_hn = self.question_encoder(q_word_emb, question_lens) 43 | q = self.embeddings(questions).sum(dim=1) # [bsz, dim_emb]; NOTE: overwrites the BiGRU output with a bag-of-words question encoding 44 | k = self.embeddings(keys).sum(dim=2) # [bsz, num_slot, dim_emb] 45 | v = self.embeddings(values).sum(dim=2) # [bsz, num_slot, dim_emb] 46 | 47 | for i in range(self.num_hop): 48 | weights = torch.bmm(k, q.unsqueeze(2)).squeeze(2) # [bsz, num_slot] 49 | weights = torch.softmax(weights, dim=1) 50 | o = torch.bmm(weights.unsqueeze(1), v).squeeze(1) # [bsz, dim_emb] 51 | q = self.linears[i](q + o) # [bsz, dim_emb] 52 | logits = self.classifier(q) 53 | return logits 54 | -------------------------------------------------------------------------------- /KVMemNN/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | from tqdm import tqdm 8 | 9 | from .data import DataLoader 10 | from .model import KVMemNN 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser() 15 | # input and output 16 | parser.add_argument('--input_dir', required=True) 17 | parser.add_argument('--save_dir', required=True, help='path of checkpoint') 18 | 19 | # model hyperparameters 20 | parser.add_argument('--dim_emb', default=300, type=int) 21 | parser.add_argument('--num_hop', default=3, type=int) 22 | args = parser.parse_args() 23 | 24 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 25 | 26 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 27 | test_pt = os.path.join(args.input_dir, 'test.pt') 28 | kb_pt = os.path.join(args.input_dir, 'kb.pt') 29 | test_loader = DataLoader(vocab_json, kb_pt, test_pt, 32) 30 | vocab = test_loader.vocab 31 | 32 | 33 | model = KVMemNN( 34 | args.num_hop, 35
| args.dim_emb, 36 | vocab 37 | ) 38 | model.load_state_dict(torch.load(os.path.join(args.save_dir, 'model.pt'))) 39 | model = model.to(device) 40 | model.eval() 41 | 42 | def write(f, predict): 43 | predict = predict.squeeze().tolist() 44 | for i in predict: 45 | f.write(vocab['answer_idx_to_token'][i] + '\n') 46 | 47 | f1 = open(os.path.join(args.save_dir, 'predict.txt'), 'w') 48 | f2 = open(os.path.join(args.save_dir, 'choice_predict.txt'), 'w') 49 | with torch.no_grad(): 50 | for batch in tqdm(test_loader, total=len(test_loader)): 51 | question, choices, keys, values = [x.to(device) for x in batch[:4]] 52 | logit = model(question, keys, values) # [bsz, num_answers] 53 | predict = logit.max(1)[1] 54 | write(f1, predict) 55 | choiced_logit = torch.gather(logit, 1, choices) # [bsz, num_choices] 56 | choiced_predict = torch.gather(choices, 1, choiced_logit.max(1)[1].unsqueeze(-1)) # [bsz, 1] 57 | write(f2, choiced_predict) 58 | f1.close() 59 | f2.close() 60 | 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /KVMemNN/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | from nltk import word_tokenize 7 | from collections import Counter, defaultdict 8 | from itertools import chain 9 | from tqdm import tqdm 10 | 11 | from utils.load_kb import load_as_key_value 12 | from utils.misc import init_vocab 13 | 14 | 15 | def create_inverted(keys): 16 | inverted_index = defaultdict(set) 17 | counter = Counter() 18 | for i in range(len(keys)): 19 | for w in keys[i]: 20 | inverted_index[w].add(i) 21 | counter[w] += 1 22 | return inverted_index 23 | 24 | 25 | def find_candidate_keys(inverted_index, stopwords, question, num_cand_keys): 26 | """ 27 | find keys that are relevant to the question, and return the top num_cand_keys 28 | if there are not enough, pad with 0 29 | """ 30 | words = word_tokenize(question['question'].lower()) 31 | counter = Counter() 32 | for w in words: 33 | if w in stopwords: # skip stopwords 34 | continue 35 | counter.update(inverted_index.get(w, [])) 36 | indexes = [x[0] for x in counter.most_common(num_cand_keys)] 37 | if len(indexes) < num_cand_keys: 38 | indexes += [0] * (num_cand_keys - len(indexes)) 39 | return indexes 40 | 41 | 42 |
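# Illustration (toy data, not from the original file): with
#   keys = [['paris', 'capital'], ['france', 'capital']]
# create_inverted returns {'paris': {0}, 'capital': {0, 1}, 'france': {1}}; for a
# question mentioning 'france' and 'capital', find_candidate_keys counts two hits
# for key 1 and one for key 0, so counter.most_common ranks key 1 first.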
43 | def encode_kb(keys, values, vocab): 44 | encoded_keys = [] 45 | encoded_values = [] 46 | for i in tqdm(range(len(keys))): 47 | encoded_keys.append([vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>']) for w in keys[i]]) 48 | encoded_values.append([vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>']) for w in values[i]]) 49 | keys = encoded_keys 50 | values = encoded_values 51 | max_len = max(len(k) for k in keys) 52 | for k in keys: 53 | while len(k) < max_len: 54 | k.append(vocab['word_token_to_idx']['<PAD>']) 55 | max_len = max(len(k) for k in values) 56 | for k in values: 57 | while len(k) < max_len: 58 | k.append(vocab['word_token_to_idx']['<PAD>']) 59 | keys = np.asarray(keys, dtype=np.int32) 60 | values = np.asarray(values, dtype=np.int32) 61 | return keys, values 62 | 63 | 64 | def encode_dataset(dataset, vocab, inverted_index, stopwords, num_cand_keys): 65 | questions = [] 66 | key_indexes = [] 67 | choices = [] 68 | answers = [] 69 | for question in tqdm(dataset): 70 | q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>']) 71 | for w in word_tokenize(question['question'].lower())] 72 | questions.append(q) 73 | 74 | key_indexes.append(find_candidate_keys(inverted_index, stopwords, question, num_cand_keys)) 75 | 76 | 77 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']] 78 | choices.append(_) 79 | if 'answer' in question: 80 | answers.append(vocab['answer_token_to_idx'].get(question['answer'])) 81 | 82 | # question padding 83 | max_len = max(len(q) for q in questions) 84 | for q in questions: 85 | while len(q) < max_len: 86 | q.append(vocab['word_token_to_idx']['<PAD>']) 87 | 88 | questions = np.asarray(questions, dtype=np.int32) 89 | key_indexes = np.asarray(key_indexes, dtype=np.int32) 90 | choices = np.asarray(choices, dtype=np.int32) 91 | answers = np.asarray(answers, dtype=np.int32) 92 | return questions, key_indexes, choices, answers 93 | 94 | 95 | 96 | def main(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument('--input_dir', required=True) 99 | parser.add_argument('--output_dir', required=True) 100 | parser.add_argument('--min_cnt', type=int, default=1) 101 | parser.add_argument('--stop_thresh', type=int, default=1000) 102 | parser.add_argument('--num_cand_keys', type=int, default=1000) 103 | args = parser.parse_args() 104 | 105 | 106 | print('Build kb vocabulary') 107 | kb_vocab, kb_keys, kb_values = load_as_key_value(os.path.join(args.input_dir, 'kb.json'), args.min_cnt) 108 | vocab = { 109 | 'word_token_to_idx': init_vocab(), 110 | 'answer_token_to_idx': {} 111 | } 112 | print('Load questions') 113 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json'))) 114 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json'))) 115 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json'))) 116 | print('Build question vocabulary') 117 | word_counter = Counter() 118 | for question in train_set: 119 | tokens = word_tokenize(question['question'].lower()) 120 | word_counter.update(tokens) 121 | # add candidate answers 122 | for a in question['choices']: 123 | if a not in vocab['answer_token_to_idx']: 124 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 125 | # filter low-frequency words 126 | stopwords = set() 127 | for w, c in word_counter.items(): 128 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']: 129 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx']) 130 | if w and c >= args.stop_thresh: 131 | stopwords.add(w) 132 | print('number of stop words (>={}): {}'.format(args.stop_thresh, len(stopwords))) 133 | # merge kb vocab 134 | for w in kb_vocab: 135 | if w not in vocab['word_token_to_idx']: 136 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx']) 137 | # add candidate answers of val and test set 138 | for question in chain(val_set, test_set): 139 | for a in question['choices']: 140 | if a not in vocab['answer_token_to_idx']: 141 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 142 | 143 | 144 | if not os.path.isdir(args.output_dir): 145 | os.mkdir(args.output_dir) 146 | fn = os.path.join(args.output_dir, 'vocab.json') 147 | print('Dump vocab to {}'.format(fn)) 148 | with open(fn, 'w') as f: 149 | json.dump(vocab, f, indent=2) 150 | for k in vocab: 151 | print('{}:{}'.format(k, len(vocab[k]))) 152 | 153 | print('Create inverted index for keys') 154 | inverted_index = create_inverted(kb_keys) 155 | 156 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)): 157 | print('Encode {} set'.format(name)) 158 | outputs = encode_dataset(dataset, vocab, inverted_index, stopwords,
args.num_cand_keys) 159 | print('shape of questions, key indexes, choices, answers:') 160 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f: 161 | for o in outputs: 162 | print(o.shape) 163 | pickle.dump(o, f) 164 | 165 | print('Encode kb') 166 | outputs = encode_kb(kb_keys, kb_values, vocab) 167 | print('shape of keys, values:') 168 | with open(os.path.join(args.output_dir, 'kb.pt'), 'wb') as f: 169 | for o in outputs: 170 | print(o.shape) 171 | pickle.dump(o, f) 172 | 173 | 174 | 175 | 176 | if __name__ == '__main__': 177 | main() 178 | -------------------------------------------------------------------------------- /KVMemNN/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | from tqdm import tqdm 8 | 9 | from utils.misc import MetricLogger, load_glove 10 | from .data import DataLoader 11 | from .model import KVMemNN 12 | 13 | import logging 14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 15 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 16 | rootLogger = logging.getLogger() 17 | 18 | torch.set_num_threads(1) # avoid using multiple cpus 19 | 20 | def validate(model, data, device): 21 | model.eval() 22 | count, correct = 0, 0 23 | with torch.no_grad(): 24 | for batch in tqdm(data, total=len(data)): 25 | question, choices, keys, values, answer = [x.to(device) for x in batch] 26 | logit = model(question, keys, values) 27 | predict = logit.max(1)[1] 28 | correct += torch.eq(predict, answer).long().sum().item() 29 | count += len(answer) 30 | 31 | acc = correct / count 32 | logging.info('\nValid Accuracy: %.4f\n' % acc) 33 | return acc 34 | 35 | 36 | def train(args): 37 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 38 | 39 | logging.info("Create train_loader and val_loader.........") 40 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 41 | train_pt = os.path.join(args.input_dir, 'train.pt') 42 | val_pt = os.path.join(args.input_dir, 'val.pt') 43 | kb_pt = os.path.join(args.input_dir, 'kb.pt') 44 | train_loader = DataLoader(vocab_json, kb_pt, train_pt, args.batch_size, training=True) 45 | val_loader = DataLoader(vocab_json, kb_pt, val_pt, args.batch_size) 46 | vocab = train_loader.vocab 47 | 48 | logging.info("Create model.........") 49 | model = KVMemNN( 50 | args.num_hop, 51 | args.dim_emb, 52 | vocab 53 | ) 54 | logging.info("Load pretrained word vectors.........") 55 | pretrained = load_glove(args.glove_pt, vocab['word_idx_to_token']) 56 | with torch.no_grad(): 57 | model.embeddings.weight.set_(torch.Tensor(pretrained)) 58 | model = model.to(device) 59 | logging.info(model) 60 | 61 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 62 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 50], gamma=0.1) 63 | criterion = nn.CrossEntropyLoss().to(device) 64 | 65 | validate(model, val_loader, device) 66 | meters = MetricLogger(delimiter=" ") 67 | best_acc = 0 68 | logging.info("Start training........") 69 | for epoch in range(args.num_epoch): 70 | model.train() 71 | for iteration, batch in enumerate(train_loader): 72 | iteration = iteration + 1 73 | 74 | question, choices, keys, values, answer = [x.to(device) for x in batch] 75 | logits = model(question, keys, values) 76 | loss = criterion(logits, answer) 77 | optimizer.zero_grad() 78 | 
loss.backward() 79 | optimizer.step() 80 | meters.update(loss=loss.item()) 81 | 82 | if iteration % (len(train_loader) // 100) == 0: 83 | logging.info( 84 | meters.delimiter.join( 85 | [ 86 | "progress: {progress:.3f}", 87 | "{meters}", 88 | "lr: {lr:.6f}", 89 | ] 90 | ).format( 91 | progress=epoch + iteration / len(train_loader), 92 | meters=str(meters), 93 | lr=optimizer.param_groups[0]["lr"], 94 | ) 95 | ) 96 | 97 | acc = validate(model, val_loader, device) 98 | scheduler.step() 99 | if acc and acc > best_acc: 100 | best_acc = acc 101 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc)) 102 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt')) 103 | 104 | 105 | def main(): 106 | parser = argparse.ArgumentParser() 107 | # input and output 108 | parser.add_argument('--input_dir', required=True) 109 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 110 | parser.add_argument('--glove_pt', required=True) 111 | 112 | # training parameters 113 | parser.add_argument('--lr', default=0.001, type=float) 114 | parser.add_argument('--weight_decay', default=1e-5, type=float) 115 | parser.add_argument('--num_epoch', default=100, type=int) 116 | parser.add_argument('--batch_size', default=32, type=int) 117 | parser.add_argument('--seed', type=int, default=666, help='random seed') 118 | # model hyperparameters 119 | parser.add_argument('--dim_emb', default=300, type=int) 120 | parser.add_argument('--num_hop', default=3, type=int) 121 | args = parser.parse_args() 122 | 123 | # make logging.info display into both shell and file 124 | if not os.path.exists(args.save_dir): 125 | os.makedirs(args.save_dir) 126 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt')) 127 | fileHandler.setFormatter(logFormatter) 128 | rootLogger.addHandler(fileHandler) 129 | # args display 130 | for k, v in vars(args).items(): 131 | logging.info(k+':'+str(v)) 132 | 133 | # set random seed 134 | torch.manual_seed(args.seed) 135 | 136 | train(args) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 THU-KEG 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Program/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | from utils.misc import invert_dict 5 | 6 | 7 | def load_vocab(path): 8 | vocab = json.load(open(path)) 9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx']) 10 | vocab['function_idx_to_token'] = invert_dict(vocab['function_token_to_idx']) 11 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) 12 | return vocab 13 | 14 | def collate(batch): 15 | batch = list(zip(*batch)) 16 | question = torch.stack(batch[0]) 17 | choices = torch.stack(batch[1]) 18 | if batch[-1][0] is None: 19 | program, prog_depends, prog_inputs, answer = None, None, None, None 20 | else: 21 | program, prog_depends, prog_inputs = list(map(torch.stack, batch[2:5])) 22 | answer = torch.cat(batch[5]) 23 | return question, choices, program, prog_depends, prog_inputs, answer 24 | 25 | 26 | class Dataset(torch.utils.data.Dataset): 27 | def __init__(self, inputs): 28 | self.questions, self.functions, self.func_depends, self.func_inputs, \ 29 | self.choices, self.answers = inputs 30 | self.is_test = len(self.answers)==0 31 | 32 | 33 | def __getitem__(self, index): 34 | question = torch.LongTensor(self.questions[index]) 35 | choices = torch.LongTensor(self.choices[index]) 36 | if self.is_test: 37 | program = None 38 | prog_depends = None 39 | prog_inputs = None 40 | answer = None 41 | else: 42 | program = torch.LongTensor(self.functions[index]) 43 | prog_depends = torch.LongTensor(self.func_depends[index]) 44 | prog_inputs = torch.LongTensor(self.func_inputs[index]) 45 | answer = torch.LongTensor([self.answers[index]]) 46 | # dependency is not necessary because it can be inferred based on functions 47 | return question, choices, program, prog_depends, prog_inputs, answer 48 | 49 | 50 | def __len__(self): 51 | return len(self.questions) 52 | 53 | 54 | class DataLoader(torch.utils.data.DataLoader): 55 | def __init__(self, vocab_json, question_pt, batch_size, training=False): 56 | vocab = load_vocab(vocab_json) 57 | if training: 58 | print('#vocab of word: %d' % len(vocab['word_token_to_idx'])) 59 | print('#vocab of answer: %d' % len(vocab['answer_token_to_idx'])) 60 | 61 | inputs = [] 62 | with open(question_pt, 'rb') as f: 63 | for _ in range(6): 64 | inputs.append(pickle.load(f)) 65 | dataset = Dataset(inputs) 66 | 67 | super().__init__( 68 | dataset, 69 | batch_size=batch_size, 70 | shuffle=training, 71 | collate_fn=collate, 72 | ) 73 | self.vocab = vocab 74 | -------------------------------------------------------------------------------- /Program/parser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils.BiGRU import GRU, BiGRU 5 | 6 | class Parser(nn.Module): 7 | def __init__(self, vocab, dim_word, dim_hidden, max_dec_len=20, max_inp=3): 8 | super().__init__() 9 | num_func = len(vocab['function_token_to_idx']) 10 | num_words = len(vocab['word_token_to_idx']) 11 | self.vocab = vocab 12 | self.dim_word = dim_word 13 | self.dim_hidden = dim_hidden 14 | self.max_dec_len = max_dec_len 15 | self.max_inp = max_inp 16 | 17 | self.word_embeddings = nn.Embedding(num_words, dim_word) 18 | self.word_dropout = nn.Dropout(0.2) 19 | self.question_encoder = GRU(dim_word, dim_hidden, num_layers=2, dropout=0.2) 20 | 21 | self.func_embeddings = nn.Embedding(num_func, dim_word) 
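# The program is decoded as two aligned sequences: self.decoder (below) emits one
# function token per step, while self.inp_decoder (further down) emits max_inp=3
# textual inputs for each emitted function; both decoders are initialized from the
# question encoder's final hidden state (h_0=q_hn in train_phase and inference).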
22 | self.decoder = GRU(dim_word, dim_hidden, num_layers=2, dropout=0.2) 23 | 24 | self.func_classifier = nn.Sequential( 25 | nn.Linear(dim_hidden, 1024), 26 | nn.ReLU(), 27 | nn.Linear(1024, num_func), 28 | ) 29 | 30 | self.inp_embeddings = nn.Embedding(num_words, dim_word) 31 | self.inp_decoder = GRU(dim_word + dim_hidden, dim_hidden, num_layers=2, dropout=0.2) 32 | self.inp_classifier = nn.Sequential( 33 | nn.Linear(dim_hidden, 1024), 34 | nn.ReLU(), 35 | nn.Linear(1024, num_words), 36 | ) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 40 | nn.init.kaiming_normal_(m.weight) 41 | if m.bias is not None: 42 | m.bias.data.zero_() 43 | 44 | def forward(self, questions, programs=None, inputs=None): 45 | """ 46 | Args: 47 | questions [bsz, max_q] 48 | programs [bsz, max_prog] 49 | inputs [bsz, max_prog, max_inp=3] 50 | Return: 51 | if programs are given, then return losses 52 | else, return predicted programs 53 | """ 54 | question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD> 55 | q_word_emb = self.word_dropout(self.word_embeddings(questions)) 56 | q_word_h, q_embeddings, q_hn = self.question_encoder(q_word_emb, question_lens) 57 | # [bsz, max_q, dim_h], [bsz, dim_h], [num_layers, bsz, dim_h] 58 | 59 | if programs is None: # during inference 60 | return self.inference(q_word_h, q_embeddings, q_hn) 61 | else: 62 | return self.train_phase(q_word_h, q_embeddings, q_hn, programs, inputs) 63 | 64 | 65 | def train_phase(self, q_word_h, q_embeddings, q_hn, programs, inputs): 66 | bsz, max_prog = programs.size(0), programs.size(1) 67 | device = programs.device 68 | program_lens = programs.size(1) - programs.eq(0).long().sum(dim=1) # 0 means <PAD> 69 | program_mask = programs.ne(0).long() 70 | 71 | p_word_emb = self.word_dropout(self.func_embeddings(programs)) 72 | p_word_h, _, _ = self.decoder(p_word_emb, program_lens, h_0=q_hn) # [bsz, max_prog, dim_h] 73 | # attention over question words 74 | attn = torch.softmax(torch.bmm(p_word_h, q_word_h.permute(0, 2, 1)), dim=2) # [bsz, max_prog, max_q] 75 | attn_word_h = torch.bmm(attn, q_word_h) # [bsz, max_prog, dim_h] 76 | # sum up 77 | p_word_h = p_word_h + attn_word_h # [bsz, max_prog, dim_h] 78 | 79 | 80 | criterion_CE = nn.CrossEntropyLoss().to(device) 81 | # predict function 82 | logit_func = self.func_classifier(p_word_h) # [bsz, max_prog, num_func] 83 | loss_func = criterion_CE(logit_func.permute(0, 2, 1)[:,:,:-1], programs[:,1:]) # remember to shift the gt 84 | 85 | # remove inputs of the <START> function 86 | inputs = inputs[:,1:,:].view(bsz, -1) # [bsz, (max_prog-1)*3] 87 | # add an extra <START> at the beginning, for convenience of inference 88 | start_token = torch.zeros((bsz, 1)).to(device).fill_(self.vocab['word_token_to_idx']['<START>']).long() 89 | inputs = torch.cat((start_token, inputs), dim=1) # [bsz, 1+(max_prog-1)*3]
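# Illustration: with max_prog = 3, the flattened input sequence built above is
# [<START>, i_1a, i_1b, i_1c, i_2a, i_2b, i_2c], i.e. one leading <START> followed
# by the 3 input slots of every function after <START>, hence length
# 1 + (max_prog-1)*3 = 7.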
90 | inp_emb = self.word_dropout(self.inp_embeddings(inputs)) # [bsz, 1+(max_prog-1)*3, dim_w] 91 | 92 | rep_p_word_h = p_word_h.view(bsz, max_prog, 1, -1).expand(-1, -1, 3, -1).\ 93 | reshape(bsz, max_prog*3, -1).contiguous() # [bsz, max_prog*3, dim_h] 94 | # align, so that the <START> func is used to predict the 3 inputs of the first function 95 | rep_p_word_h = rep_p_word_h[:, :1+(max_prog-1)*3] 96 | inp_h, _, _ = self.inp_decoder(torch.cat((inp_emb, rep_p_word_h), dim=2), 97 | 1+(program_lens-1)*3, h_0=q_hn) # [bsz, 1+(max_prog-1)*3, dim_h] 98 | # attention over question words 99 | attn = torch.softmax(torch.bmm(inp_h, q_word_h.permute(0, 2, 1)), dim=2) 100 | attn_word_h = torch.bmm(attn, q_word_h) 101 | # sum up 102 | inp_h = inp_h + attn_word_h # [bsz, 1+(max_prog-1)*3, dim_h] 103 | # logit 104 | logit_inp = self.inp_classifier(inp_h) # [bsz, 1+(max_prog-1)*3, num_words] 105 | loss_inp = criterion_CE(logit_inp.permute(0, 2, 1)[:,:,:-1], inputs[:,1:]) # shift the input 106 | 107 | loss = loss_func + loss_inp 108 | 109 | return loss 110 | 111 | 112 | def inference(self, q_word_h, q_embeddings, q_hn): 113 | """ 114 | Predict programs, and inputs 115 | """ 116 | bsz = q_word_h.size(0) 117 | device = q_word_h.device 118 | start_id = self.vocab['function_token_to_idx']['<START>'] 119 | end_id = self.vocab['function_token_to_idx']['<END>'] 120 | 121 | latest_func = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ] 122 | last_h = q_hn 123 | finished = torch.zeros((bsz,)).byte().to(device) # record whether <END> is produced 124 | 125 | latest_inp = torch.LongTensor([self.vocab['word_token_to_idx']['<START>']]*bsz).to(device) # [bsz, ] 126 | last_inp_h = q_hn 127 | 128 | # store predictions at each step 129 | programs = [latest_func] 130 | inputs = [torch.zeros((bsz, self.max_inp)).long().to(device)] 131 | 132 | for i in range(self.max_dec_len): 133 | p_word_emb = self.word_dropout(self.func_embeddings(latest_func)).unsqueeze(1) # [bsz, 1, dim_w] 134 | p_word_h, last_h = self.decoder.forward_one_step(p_word_emb, last_h) # [bsz, 1, dim_h] 135 | # attention over question words 136 | attn = torch.softmax(torch.bmm(p_word_h, q_word_h.permute(0, 2, 1)), dim=2) # [bsz, 1, max_q] 137 | attn_word_h = torch.bmm(attn, q_word_h) # [bsz, 1, dim_h] 138 | # sum up 139 | p_word_h = p_word_h + attn_word_h # [bsz, 1, dim_h] 140 | 141 | # predict function 142 | logit_func = self.func_classifier(p_word_h).squeeze(1) # [bsz, num_func] 143 | latest_func = torch.argmax(logit_func, dim=1) # [bsz, ] 144 | programs.append(latest_func) 145 | 146 | # predict input 147 | pred_inp = [] 148 | for _ in range(self.max_inp): 149 | inp_emb = self.word_dropout(self.inp_embeddings(latest_inp)).unsqueeze(1) # [bsz, 1, dim_w] 150 | inp_h, last_inp_h = self.inp_decoder.forward_one_step( 151 | torch.cat((inp_emb, p_word_h), dim=2), 152 | last_inp_h) # [bsz, 1, dim_h] 153 | attn = torch.softmax(torch.bmm(inp_h, q_word_h.permute(0, 2, 1)), dim=2) 154 | attn_word_h = torch.bmm(attn, q_word_h) 155 | inp_h = inp_h + attn_word_h # [bsz, 1, dim_h] 156 | 157 | logit_inp = self.inp_classifier(inp_h).squeeze(1) # [bsz, num_word] 158 | latest_inp = torch.argmax(logit_inp, dim=1) # [bsz, ] 159 | pred_inp.append(latest_inp) 160 | pred_inp = torch.stack(pred_inp, dim=1) # [bsz, 3] 161 | inputs.append(pred_inp) 162 | 163 | finished = finished | latest_func.eq(end_id).byte() 164 | if finished.sum().item() == bsz: 165 | # print('finished at step {}'.format(i)) 166 | break 167 | 168 | programs = torch.stack(programs, dim=1) # [bsz, max_prog] 169 | inputs = torch.stack(inputs, dim=1) # [bsz, max_prog, 3] 170 | return programs, inputs 171 | 172 |
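# Usage sketch (illustrative; assumes tensors produced by Program.preprocess):
#   parser = Parser(vocab, dim_word=300, dim_hidden=1024)
#   loss = parser(questions, programs, inputs)   # training: returns a scalar loss
#   programs, inputs = parser(questions)         # inference: greedy decoding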
-------------------------------------------------------------------------------- /Program/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import shutil 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | from .data import DataLoader 9 | from .parser import Parser 10 | from .executor_rule import RuleExecutor 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser() 14 | # input and output 15 | parser.add_argument('--input_dir', required=True) 16 | parser.add_argument('--save_dir', required=True, help='path of checkpoint') 17 | # model hyperparameters 18 | parser.add_argument('--dim_word', default=300, type=int) 19 | parser.add_argument('--dim_hidden', default=1024, type=int) 20 | args = parser.parse_args() 21 | 22 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 23 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 24 | test_pt = os.path.join(args.input_dir, 'test.pt') 25 | test_loader = DataLoader(vocab_json, test_pt, 128) 26 | vocab = test_loader.vocab 27 | 28 | rule_executor = RuleExecutor(vocab, os.path.join(args.input_dir, 'kb.json')) 29 | model = Parser(vocab, args.dim_word, args.dim_hidden) 30 | 31 | print("load ckpt from {}".format(args.save_dir)) 32 | model.load_state_dict( 33 | torch.load(os.path.join(args.save_dir, 'model.pt'), map_location={'cuda': 'cpu'})) 34 | model = model.to(device) 35 | model.eval() 36 | 37 | with open(os.path.join(args.save_dir, 'predict.txt'), 'w') as f: 38 | with torch.no_grad(): 39 | for batch in tqdm(test_loader, total=len(test_loader)): 40 | question, choices = [x.to(device) for x in batch[:2]] 41 | pred_program, pred_inputs = model(question) 42 | 43 | pred_program, pred_inputs = [x.cpu().numpy() for x in (pred_program, pred_inputs)] 44 | for i in range(len(pred_program)): 45 | pred = rule_executor.forward(pred_program[i], pred_inputs[i], ignore_error=True) 46 | f.write(str(pred) + '\n') 47 | print("save predictions into {}".format(os.path.join(args.save_dir, 'predict.txt'))) 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /Program/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | from nltk import word_tokenize 7 | from collections import Counter, defaultdict 8 | from itertools import chain 9 | from tqdm import tqdm 10 | 11 | from utils.misc import init_vocab 12 | 13 | max_dep = 2 14 | max_inp = 3 15 | 16 | 17 | def encode_dataset(dataset, vocab, test=False): 18 | questions = [] 19 | functions = [] 20 | func_depends = [] 21 | func_inputs = [] 22 | choices = [] 23 | answers = [] 24 | for question in tqdm(dataset): 25 | q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>']) 26 | for w in word_tokenize(question['question'].lower())] 27 | questions.append(q) 28 | 29 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']] 30 | choices.append(_) 31 | 32 | if test: 33 | continue 34 | 35 | func, dep, inp = [], [], [] 36 | # wrap program with <START> and <END> flags 37 | program = [{'function':'<START>','dependencies':[-1,-1],'inputs':['<PAD>']}] + \ 38 | question['program'] + \ 39 | [{'function':'<END>','dependencies':[-1,-1],'inputs':['<PAD>']}] 40 | for f in program: 41 | func.append(vocab['function_token_to_idx'][f['function']]) 42 | dep.append(f['dependencies']) 43 | inp.append([vocab['word_token_to_idx'].get(i, vocab['word_token_to_idx']['<UNK>']) 44 | for i in f['inputs']]) 45 | 46 | functions.append(func) 47 | func_depends.append(dep) 48 | func_inputs.append(inp) 49 | 50 | if 'answer' in question: 51 | answers.append(vocab['answer_token_to_idx'].get(question['answer'])) 52 |
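# Illustration (function names are examples from the dataset's program vocabulary):
# a two-function program Find -> QueryAttr is wrapped and encoded above as
#   func = [<START>, Find, QueryAttr, <END>]   (as function-vocabulary indices)
# with its dependencies and textual inputs collected in the parallel lists dep and
# inp; padding everything to a rectangular shape happens below.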
53 | # question padding 54 | max_len = max(len(q) for q in questions) 55 | for i in range(len(questions)): 56 | while len(questions[i]) < max_len: 57 | questions[i].append(vocab['word_token_to_idx']['<PAD>']) 58 | 59 | if not test: 60 | # function padding 61 | max_len = max(len(f) for f in functions) 62 | for i in range(len(functions)): 63 | while len(functions[i]) < max_len: 64 | functions[i].append(vocab['function_token_to_idx']['<PAD>']) 65 | func_depends[i].append([-1, -1]) 66 | func_inputs[i].append([]) 67 | for j in range(max_len): 68 | while len(func_depends[i][j]) < max_dep: 69 | func_depends[i][j].append(-1) # use -1 to pad dependency 70 | while len(func_inputs[i][j]) < max_inp: 71 | func_inputs[i][j].append(vocab['word_token_to_idx']['<PAD>']) 72 | 73 | questions = np.asarray(questions, dtype=np.int32) 74 | functions = np.asarray(functions, dtype=np.int32) 75 | func_depends = np.asarray(func_depends, dtype=np.int32) 76 | # Because we wrap a <START> before the program, dependencies should shift to the right 77 | # After that, all dependencies >= 0 and 0 means padding 78 | func_depends = func_depends + 1 79 | 80 | func_inputs = np.asarray(func_inputs, dtype=np.int32) 81 | choices = np.asarray(choices, dtype=np.int32) 82 | answers = np.asarray(answers, dtype=np.int32) 83 | return questions, functions, func_depends, func_inputs, choices, answers 84 | 85 | 86 | 87 | def main(): 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument('--input_dir', required=True) 90 | parser.add_argument('--output_dir', required=True) 91 | parser.add_argument('--min_cnt', type=int, default=1) 92 | args = parser.parse_args() 93 | 94 | 95 | vocab = { 96 | 'word_token_to_idx': init_vocab(), 97 | 'function_token_to_idx': init_vocab(), 98 | 'answer_token_to_idx': {} 99 | } 100 | print('Load questions') 101 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json'))) 102 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json'))) 103 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json'))) 104 | print('Build question vocabulary') 105 | word_counter = Counter() 106 | for question in train_set: 107 | tokens = word_tokenize(question['question'].lower()) 108 | word_counter.update(tokens) 109 | # add candidate answers 110 | for a in question['choices']: 111 | if a not in vocab['answer_token_to_idx']: 112 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 113 | # add functions 114 | for f in question['program']: 115 | a = f['function'] 116 | if a not in vocab['function_token_to_idx']: 117 | vocab['function_token_to_idx'][a] = len(vocab['function_token_to_idx']) 118 | word_counter.update(f['inputs']) 119 | # filter low-frequency words 120 | for w, c in word_counter.items(): 121 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']: 122 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx']) 123 | # add candidate answers of val and test set 124 | for question in chain(val_set, test_set): 125 | for a in question['choices']: 126 | if a not in vocab['answer_token_to_idx']: 127 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 128 | 129 | 130 | if not os.path.isdir(args.output_dir): 131 | os.mkdir(args.output_dir) 132 | fn = os.path.join(args.output_dir, 'vocab.json') 133 | print('Dump vocab to {}'.format(fn)) 134 | with open(fn, 'w') as f: 135 | json.dump(vocab, f, indent=2) 136 | for k in vocab: 137 | print('{}:{}'.format(k, len(vocab[k]))) 138 | 139 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)): 140 | print('Encode {} set'.format(name)) 141 | outputs = encode_dataset(dataset, vocab, test=name=='test') 142 | assert len(outputs) == 6 143 | print('shape of questions, functions, func_depends, func_inputs, choices, answers:') 144 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f: 145 | for o in outputs: 146 |
print(o.shape) 147 | pickle.dump(o, f) 148 | 149 | 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /Program/readme.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | - python3 3 | - pytorch>=1.2.0 4 | - nltk 5 | 6 | ## How to run 7 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading: 8 | ``` 9 | python -m utils.pickle_glove --input --output 10 | ``` 11 | This step can be skipped if you have obtained the glove pickle file in other models. 12 | 13 | 2. Preprocess the training data, and copy the `./dataset/kb.json` into `output_dir` 14 | ``` 15 | python -m Program.preprocess --input_dir ./dataset --output_dir 16 | cp ./dataset/kb.json 17 | ``` 18 | 3. Train 19 | ``` 20 | python -m Program.train --input_dir --save_dir --glove_pt 21 | ``` 22 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order. 23 | ``` 24 | python -m Program.predict --input_dir --save_dir 25 | ``` 26 | -------------------------------------------------------------------------------- /Program/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | from tqdm import tqdm 8 | import numpy as np 9 | 10 | from utils.misc import MetricLogger, load_glove 11 | from .data import DataLoader 12 | from .parser import Parser 13 | from .executor_rule import RuleExecutor 14 | from IPython import embed # needed by the interactive debugging branch in validate_executor 15 | import logging 16 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 17 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 18 | rootLogger = logging.getLogger() 19 | 20 | 21 | def validate_executor(executor, data): 22 | # validate whether the executor is correct 23 | correct = 0 24 | count = 0 25 | for batch in tqdm(data, total=len(data)): 26 | question, choices, gt_program, gt_dep, gt_inputs, answer = batch 27 | gt_program, gt_dep, gt_inputs = [x.cpu().numpy() for x in (gt_program, gt_dep, gt_inputs)] 28 | answer = [data.vocab['answer_idx_to_token'][a.item()] for a in answer] 29 | preds = [] 30 | for i in range(len(gt_program)): 31 | pred = executor.forward(gt_program[i], gt_inputs[i], ignore_error=True) 32 | if pred == answer[i]: 33 | correct += 1 34 | else: 35 | print(pred, answer[i]) 36 | pred = executor.forward(gt_program[i], gt_dep[i], gt_inputs[i], ignore_error=True, show_details=True) 37 | embed() 38 | count += 1 39 | if count >= 10000: 40 | break 41 | print('{}/{}/{:.4f}'.format(correct, count, correct/count)) 42 | 43 | 44 | def validate(model, data, device, executor=None): 45 | model.eval() 46 | end_id = data.vocab['function_token_to_idx']['<END>'] 47 | match_prog_num = 0 48 | match_dep_num = 0 49 | match_inp_num = 0 50 | match_all_num = 0 51 | correct = 0 52 | count = 0 53 | with torch.no_grad(): 54 | for batch in tqdm(data, total=len(data)): 55 | question, choices, gt_program, gt_dep, gt_inputs, answer = [x.to(device) for x in batch] 56 | pred_program, pred_inputs = model(question) 57 | 58 | gt_program, gt_inputs = [x.cpu().numpy() for x in (gt_program, gt_inputs)] 59 | pred_program, pred_inputs = [x.cpu().numpy() for x in (pred_program,
pred_inputs)] 60 | 61 | for i in range(len(gt_program)): 62 | l = min(len(gt_program[i]), len(pred_program[i])) # default compare length, in case <END> is never jointly produced below 63 | # print(gt_program[i]) 64 | # print(gt_inputs[i]) 65 | # print('---') 66 | # print(pred_program[i]) 67 | # print(pred_inputs[i]) 68 | # print('==========') 69 | 70 | match = True 71 | for j in range(min(len(gt_program[i]), len(pred_program[i]))): 72 | if gt_program[i, j] != pred_program[i, j]: 73 | match = False 74 | break 75 | if gt_program[i, j] == end_id and pred_program[i, j] == end_id: 76 | l = j 77 | break 78 | if match: 79 | match_prog_num += 1 80 | if np.all(gt_inputs[i,1:l,:]==pred_inputs[i,1:l,:]): 81 | match_inp_num += 1 82 | 83 | count += len(gt_program) 84 | 85 | if executor: 86 | answer = [data.vocab['answer_idx_to_token'][a.item()] for a in answer] 87 | for i in range(len(gt_program)): 88 | pred = executor.forward(pred_program[i], pred_inputs[i], ignore_error=True) 89 | if pred == answer[i]: 90 | correct += 1 91 | 92 | logging.info('\nValid match program: {:.4f}, inputs: {:.4f}\n'.format( 93 | match_prog_num / count, 94 | match_inp_num / count, 95 | )) 96 | if executor: 97 | logging.info('Accuracy: {:.4f}\n'.format(correct / count)) 98 | return correct / count 99 | else: 100 | return None 101 | 102 | 103 | def train(args): 104 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 105 | 106 | logging.info("Create train_loader and val_loader.........") 107 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 108 | train_pt = os.path.join(args.input_dir, 'train.pt') 109 | val_pt = os.path.join(args.input_dir, 'val.pt') 110 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True) 111 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size) 112 | vocab = train_loader.vocab 113 | 114 | rule_executor = RuleExecutor(vocab, os.path.join(args.input_dir, 'kb.json')) 115 | 116 | logging.info("Create model.........") 117 | model = Parser(vocab, args.dim_word, args.dim_hidden) 118 | logging.info("Load pretrained word vectors.........") 119 | pretrained = load_glove(args.glove_pt, vocab['word_idx_to_token']) 120 | with torch.no_grad(): 121 | model.word_embeddings.weight.set_(torch.Tensor(pretrained)) 122 | model = model.to(device) 123 | logging.info(model) 124 | if args.ckpt and os.path.exists(args.ckpt): 125 | logging.info("load ckpt from {}".format(args.ckpt)) 126 | model.load_state_dict(torch.load(args.ckpt, map_location={'cuda': 'cpu'})) 127 | 128 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 129 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 50], gamma=0.1) 130 | 131 | # validate_executor(rule_executor, train_loader) # accuracy of val_loader is about 80% due to OOV issue 132 | validate(model, val_loader, device) 133 | 134 | meters = MetricLogger(delimiter="  ") 135 | best_acc = 0 136 | logging.info("Start training........") 137 | for epoch in range(args.num_epoch): 138 | model.train() 139 | for iteration, batch in enumerate(train_loader): 140 | iteration = iteration + 1 141 | 142 | question, choices, program, prog_depends, prog_inputs, answer = [x.to(device) for x in batch] 143 | loss = model(question, program, prog_inputs) 144 | optimizer.zero_grad() 145 | loss.backward() 146 | optimizer.step() 147 | meters.update(loss=loss.item()) 148 | 149 | if iteration % (len(train_loader) // 100) == 0: 150 | logging.info( 151 | meters.delimiter.join( 152 | [ 153 | "progress: {progress:.3f}", 154 | "{meters}", 155 | "lr: {lr:.6f}", 156 | ] 157 | ).format( 158 | progress=epoch + iteration / len(train_loader), 159 |
meters=str(meters), 160 | lr=optimizer.param_groups[0]["lr"], 161 | ) 162 | ) 163 | 164 | scheduler.step() 165 | if epoch == args.num_epoch-1 or (epoch+1)%5 == 0: 166 | acc = validate(model, val_loader, device, rule_executor) 167 | else: 168 | acc = validate(model, val_loader, device) 169 | if acc and acc > best_acc: 170 | best_acc = acc 171 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc)) 172 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt')) 173 | 174 | 175 | def main(): 176 | parser = argparse.ArgumentParser() 177 | # input and output 178 | parser.add_argument('--input_dir', required=True) 179 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 180 | parser.add_argument('--glove_pt', required=True) 181 | parser.add_argument('--ckpt') 182 | 183 | # training parameters 184 | parser.add_argument('--lr', default=0.001, type=float) 185 | parser.add_argument('--weight_decay', default=1e-5, type=float) 186 | parser.add_argument('--num_epoch', default=100, type=int) 187 | parser.add_argument('--batch_size', default=64, type=int) 188 | parser.add_argument('--seed', type=int, default=666, help='random seed') 189 | # model hyperparameters 190 | parser.add_argument('--dim_word', default=300, type=int) 191 | parser.add_argument('--dim_hidden', default=1024, type=int) 192 | args = parser.parse_args() 193 | 194 | # make logging.info write to both the shell and a log file 195 | if os.path.isdir(args.save_dir): 196 | shutil.rmtree(args.save_dir) 197 | os.mkdir(args.save_dir) 198 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt')) 199 | fileHandler.setFormatter(logFormatter) 200 | rootLogger.addHandler(fileHandler) 201 | # args display 202 | for k, v in vars(args).items(): 203 | logging.info(k+':'+str(v)) 204 | 205 | # set random seed 206 | torch.manual_seed(args.seed) 207 | 208 | train(args) 209 | 210 | 211 | if __name__ == '__main__': 212 | main() 213 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KQA Pro Baselines 2 | [KQA Pro](https://arxiv.org/abs/2007.03875) is a large-scale dataset of complex question answering over a knowledge base, which provides strong supervision of SPARQL queries and programs for each question. 3 | [Here is its homepage](http://thukeg.gitee.io/kqa-pro/). The dataset is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License. 4 | 5 | This repo implements several baselines for the dataset: 6 | 7 | - Blind GRU. It predicts the answer from the input question alone, ignoring the knowledge base. We use it to measure dataset bias. 8 | - [KVMemNN](https://www.aclweb.org/anthology/D16-1147/) (Key-Value Memory Networks) 9 | - [RGCN](https://arxiv.org/abs/1703.06103) (Relational Graph Convolutional Networks) 10 | - [SRN](https://dl.acm.org/doi/10.1145/3336191.3371812) (Stepwise Relational Networks) 11 | - RNN seq2seq SPARQL parser 12 | - RNN seq2seq program parser 13 | - [BART](https://arxiv.org/abs/1910.13461) seq2seq SPARQL parser 14 | - [BART](https://arxiv.org/abs/1910.13461) seq2seq program parser 15 | 16 | Instructions for running these models are given in their README files. 17 | Before trying them, first download the [dataset](https://cloud.tsinghua.edu.cn/f/04ce81541e704a648b03/?dl=1) and unzip it into the folder `./dataset`. 18 | The file tree should look like 19 | ``` 20 | .
21 | +-- dataset 22 | | +-- kb.json 23 | | +-- train.json 24 | | +-- val.json 25 | | +-- test.json 26 | +-- GRU 27 | | +-- preprocess.py 28 | | +-- train.py 29 | | +-- ... 30 | +-- KVMemNN 31 | +-- RGCN 32 | ... 33 | ``` 34 | -------------------------------------------------------------------------------- /RGCN/README.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | - python3 3 | - pytorch>=1.2.0 4 | - nltk 5 | - [dgl>=0.4.3](https://github.com/dmlc/dgl/) 6 | 7 | ## How to run 8 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading: 9 | ``` 10 | python -m utils.pickle_glove --input --output 11 | ``` 12 | This step can be skipped if you have obtained the glove pickle file in other models. 13 | 2. Preprocess the training data, and copy the `./dataset/kb.json` into `output_dir` 14 | ``` 15 | python -m RGCN.preprocess --input_dir ./dataset --output_dir 16 | ``` 17 | 3. Train 18 | ``` 19 | python -m RGCN.train --input_dir --save_dir --glove_pt 20 | ``` 21 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order. 22 | ``` 23 | python -m RGCN.predict --input_dir --save_dir 24 | ``` 25 | -------------------------------------------------------------------------------- /RGCN/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | from utils.misc import invert_dict 5 | 6 | 7 | def load_vocab(path): 8 | vocab = json.load(open(path)) 9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx']) 10 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) 11 | vocab['kb_idx_to_token'] = invert_dict(vocab['kb_token_to_idx']) 12 | vocab['predicate_idx_to_token'] = invert_dict(vocab['predicate_token_to_idx']) 13 | return vocab 14 | 15 | def collate(batch): 16 | batch = list(zip(*batch)) 17 | question, choices = list(map(torch.stack, batch[:2])) 18 | if batch[-1][0] is None: 19 | answer = None 20 | else: 21 | answer = torch.cat(batch[-1]) 22 | return question, choices, answer 23 | 24 | 25 | class Dataset(torch.utils.data.Dataset): 26 | def __init__(self, inputs): 27 | self.questions, self.choices, self.answers = inputs 28 | self.is_test = len(self.answers)==0 29 | 30 | 31 | def __getitem__(self, index): 32 | question = torch.LongTensor(self.questions[index]) 33 | choices = torch.LongTensor(self.choices[index]) 34 | if self.is_test: 35 | answer = None 36 | else: 37 | answer = torch.LongTensor([self.answers[index]]) 38 | return question, choices, answer 39 | 40 | 41 | def __len__(self): 42 | return len(self.questions) 43 | 44 | 45 | class DataLoader(torch.utils.data.DataLoader): 46 | def __init__(self, vocab_json, kb_pt, question_pt, batch_size, training=False): 47 | vocab = load_vocab(vocab_json) 48 | 49 | inputs = [] 50 | with open(question_pt, 'rb') as f: 51 | for _ in range(3): 52 | inputs.append(pickle.load(f)) 53 | with open(kb_pt, 'rb') as f: 54 | self.node_descs = torch.LongTensor(pickle.load(f)) 55 | self.triples = torch.LongTensor(pickle.load(f)) 56 | 57 | dataset = Dataset(inputs) 58 | 59 | super().__init__( 60 | dataset, 61 | batch_size=batch_size, 62 | shuffle=training, 63 | collate_fn=collate, 64 | ) 65 | self.vocab = vocab 66 | 67 | 
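# Usage sketch (illustrative): besides iterating over batches of (question, choices, answer),
# the loader exposes the static KB tensors that the model consumes:
#   loader = DataLoader(vocab_json, kb_pt, train_pt, 32, training=True)
#   node_descs, triples = loader.node_descs, loader.triples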
-------------------------------------------------------------------------------- /RGCN/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Refer to https://github.com/dmlc/dgl/tree/master/examples/pytorch/rgcn 3 | """ 4 | import math 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | from dgl import DGLGraph 9 | from dgl.nn.pytorch import RelGraphConv 10 | 11 | from utils.BiGRU import BiGRU 12 | 13 | class RGCN(nn.Module): 14 | def __init__(self, in_dim, h_dim, out_dim, num_rels, num_bases, 15 | num_hidden_layers=2, dropout=0, 16 | use_self_loop=False, use_cuda=True): 17 | super().__init__() 18 | self.in_dim = in_dim 19 | self.h_dim = h_dim 20 | self.out_dim = out_dim 21 | self.num_rels = num_rels 22 | self.num_bases = None if num_bases < 0 else num_bases 23 | self.num_hidden_layers = num_hidden_layers 24 | self.dropout = dropout 25 | self.use_self_loop = use_self_loop 26 | self.use_cuda = use_cuda 27 | 28 | # create rgcn layers 29 | self.build_model() 30 | 31 | def build_model(self): 32 | self.layers = nn.ModuleList() 33 | # i2h 34 | i2h = self.build_input_layer() 35 | if i2h is not None: 36 | self.layers.append(i2h) 37 | # h2h 38 | for idx in range(self.num_hidden_layers): 39 | h2h = self.build_hidden_layer(idx) 40 | self.layers.append(h2h) 41 | # h2o 42 | h2o = self.build_output_layer() 43 | if h2o is not None: 44 | self.layers.append(h2o) 45 | 46 | def build_input_layer(self): 47 | return None 48 | 49 | def build_hidden_layer(self, idx): 50 | return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "basis", 51 | self.num_bases, activation=F.relu, self_loop=self.use_self_loop, 52 | dropout=self.dropout) 53 | 54 | def build_output_layer(self): 55 | return None 56 | # return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis", 57 | # self.num_bases, activation=None, 58 | # self_loop=self.use_self_loop) 59 | 60 | def forward(self, g, h, r, norm=None): 61 | for layer in self.layers: 62 | h = layer(g, h, r, norm) 63 | return h 64 | 65 | 66 | class QuesAnsByRGCN(nn.Module): 67 | def __init__(self, vocab, node_descs, edge_triples, 68 | dim_word, dim_hidden, dim_g, num_bases=1, num_hidden_layers=1): 69 | """ 70 | Args: 71 | - edge_triples (np.array) [#triple, 3] 72 | """ 73 | super().__init__() 74 | num_rels = len(vocab['predicate_token_to_idx']) 75 | num_desc_word = len(vocab['kb_token_to_idx']) 76 | num_question_word = len(vocab['word_token_to_idx']) 77 | num_class = len(vocab['answer_token_to_idx']) 78 | 79 | self.rgcn = RGCN(dim_g, dim_g, dim_g, num_rels, num_bases, num_hidden_layers) 80 | edge_src = edge_triples[:,0] 81 | edge_type = edge_triples[:,1] 82 | edge_dst = edge_triples[:,2] 83 | self.edge_type = edge_type 84 | self.num_nodes = len(node_descs) 85 | self.node_descs = node_descs # [#node, max_desc] 86 | self.dim_g = dim_g 87 | 88 | self.desc_embeddings = nn.Embedding(num_desc_word, dim_g) 89 | nn.init.normal_(self.desc_embeddings.weight, mean=0, std=1/math.sqrt(dim_g)) 90 | 91 | self.input_embeddings = nn.Embedding(num_question_word, dim_word) 92 | nn.init.normal_(self.input_embeddings.weight, mean=0, std=1/math.sqrt(dim_word)) 93 | 94 | self.word_dropout = nn.Dropout(0.3) 95 | self.question_encoder = BiGRU(dim_word, dim_hidden, num_layers=1, dropout=0.0) 96 | 97 | # create graph 98 | self.g = DGLGraph() 99 | self.g.add_nodes(self.num_nodes) 100 | self.g.add_edges(edge_src, edge_dst) 101 | 102 | self.lin_h_to_g = nn.Linear(dim_hidden, dim_g) 103 | self.classifier = nn.Sequential( 104 | 
nn.Linear(dim_g + dim_hidden, 1024), 105 | nn.ReLU(), 106 | nn.Linear(1024, num_class) 107 | ) 108 | 109 | for m in self.modules(): 110 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 111 | nn.init.kaiming_normal_(m.weight) 112 | if m.bias is not None: 113 | m.bias.data.zero_() 114 | 115 | 116 | def forward(self, questions, only_q=False): 117 | question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD> 118 | question_input = self.word_dropout(self.input_embeddings(questions)) 119 | _, question_embeddings, _ = self.question_encoder(question_input, question_lens) 120 | # [bsz, dim_h] 121 | 122 | if only_q: 123 | bsz = question_embeddings.size(0) 124 | device = question_embeddings.device 125 | empty = torch.zeros((bsz, self.dim_g)).to(device) 126 | feat = torch.cat((empty, question_embeddings), dim=1) 127 | logits = self.classifier(feat) 128 | return logits 129 | 130 | 131 | agg_feats = [] 132 | bsz = len(questions) 133 | for i in range(bsz): 134 | # construct initial node features 135 | q = question_embeddings[i].view(1, 1, -1) # [1, 1, dim_h] 136 | node_desc_emb = self.word_dropout(self.desc_embeddings(self.node_descs)) 137 | # [#node, max_desc, dim_g] 138 | q_g = self.lin_h_to_g(q) # [1, 1, dim_g] 139 | attn = torch.softmax(torch.sum(node_desc_emb * q_g, dim=2), dim=1) # [#node, max_desc] 140 | node_feat = torch.sum(attn.unsqueeze(2) * node_desc_emb, dim=1) # [#node, dim_g] 141 | 142 | # rgcn 143 | node_feat = self.rgcn(self.g, node_feat, self.edge_type) # [#node, dim_g] 144 | 145 | # answer feature 146 | q_g = q_g.view(1, -1) # [1, dim_g] 147 | attn = torch.softmax(torch.sum(node_feat * q_g, dim=1, keepdim=True), dim=0) # [#node, 1] 148 | node_agg = torch.sum(node_feat * attn, dim=0) # [dim_g] 149 | node_agg = torch.cat((node_agg, q.view(-1)), dim=0) # [dim_g+dim_h] 150 | agg_feats.append(node_agg) 151 | 152 | agg_feats = torch.stack(agg_feats) # [bsz, dim_g+dim_h] 153 | logits = self.classifier(agg_feats) 154 | return logits 155 | -------------------------------------------------------------------------------- /RGCN/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from tqdm import tqdm 6 | 7 | from .data import DataLoader 8 | from .model import QuesAnsByRGCN 9 | 10 | 11 | def test(args): 12 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 13 | 14 | print('load test data') 15 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 16 | test_pt = os.path.join(args.input_dir, 'test.pt') 17 | kb_pt = os.path.join(args.input_dir, 'kb.pt') 18 | data = DataLoader(vocab_json, kb_pt, test_pt, 4) 19 | vocab = data.vocab 20 | 21 | print('load model') 22 | node_descs = data.node_descs.to(device) 23 | node_descs = node_descs[:, :args.max_desc] 24 | triples = data.triples.to(device) 25 | triples = triples[:args.max_triple] 26 | model = QuesAnsByRGCN(vocab, 27 | node_descs, triples, 28 | args.dim_word, args.dim_hidden, args.dim_g) 29 | model = model.to(device) 30 | model.eval() 31 | model.load_state_dict(torch.load(os.path.join(args.save_dir, 'model.pt'))) 32 | 33 | fn_open = open(os.path.join(args.save_dir, 'predict.txt'), 'w') 34 | fn_choice = open(os.path.join(args.save_dir, 'choice_predict.txt'), 'w') 35 | for batch in tqdm(data, total=len(data)): 36 | question, choices, answer = batch 37 | question = question.to(device) 38 | logit = model(question) 39 | logit = logit.detach().cpu() 40 | 41 | for l, c in zip(logit, choices): 42 | a = 
l.max(0)[1].item() 43 | a = vocab['answer_idx_to_token'][a] 44 | fn_open.write(a + '\n') 45 | # mask for multi-choice 46 | l = torch.softmax(l, 0) 47 | mask = torch.ones((len(l),)).bool() 48 | mask[c] = 0 49 | l[mask] = 0 50 | a = l.max(0)[1].item() 51 | a = vocab['answer_idx_to_token'][a] 52 | fn_choice.write(a + '\n') 53 | fn_open.close() 54 | fn_choice.close() 55 | 56 | 57 | def main(): 58 | parser = argparse.ArgumentParser() 59 | # input and output 60 | parser.add_argument('--input_dir', required=True) 61 | parser.add_argument('--save_dir', required=True, help='path to store predictions') 62 | 63 | # model hyperparameters 64 | parser.add_argument('--dim_word', default=300, type=int) 65 | parser.add_argument('--dim_hidden', default=512, type=int) 66 | parser.add_argument('--dim_g', default=32, type=int) 67 | parser.add_argument('--max_desc', default=20, type=int) 68 | parser.add_argument('--max_triple', default=200000, type=int) 69 | args = parser.parse_args() 70 | 71 | test(args) 72 | 73 | 74 | if __name__ == '__main__': 75 | main() 76 | -------------------------------------------------------------------------------- /RGCN/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import numpy as np 5 | from nltk import word_tokenize 6 | from collections import Counter 7 | from itertools import chain 8 | from tqdm import tqdm 9 | import argparse 10 | 11 | from utils.load_kb import load_as_graph 12 | from utils.misc import init_vocab 13 | 14 | 15 | def encode_dataset(dataset, vocab): 16 | questions = [] 17 | choices = [] 18 | answers = [] 19 | for question in tqdm(dataset): 20 | q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>']) 21 | for w in word_tokenize(question['question'].lower())] 22 | questions.append(q) 23 | 24 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']] 25 | choices.append(_) 26 | 27 | if 'answer' in question: 28 | answers.append(vocab['answer_token_to_idx'].get(question['answer'])) 29 | 30 | # question padding 31 | max_len = max(len(q) for q in questions) 32 | for q in questions: 33 | while len(q) < max_len: 34 | q.append(vocab['word_token_to_idx']['<PAD>']) 35 | 36 | questions = np.asarray(questions, dtype=np.int32) 37 | choices = np.asarray(choices, dtype=np.int32) 38 | answers = np.asarray(answers, dtype=np.int32) 39 | return questions, choices, answers 40 | 41 | 42 | 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--input_dir', required=True) 47 | parser.add_argument('--output_dir', required=True) 48 | parser.add_argument('--min_cnt', type=int, default=1) 49 | parser.add_argument('--max_desc', type=int, default=200) 50 | args = parser.parse_args() 51 | 52 | 53 | print('Load and encode kb...') 54 | kb_vocab, node_descs, triples, nodeid2idx, pred2idx = \ 55 | load_as_graph(os.path.join(args.input_dir, 'kb.json'), args.max_desc) 56 | node_descs = np.asarray(node_descs) 57 | triples = np.asarray(triples) 58 | print("shape of node_descs and triples:", node_descs.shape, triples.shape) 59 | print(node_descs[-10:]) 60 | print(triples[:10]) 61 | 62 | with open(os.path.join(args.output_dir, 'kb.pt'), 'wb') as f: 63 | pickle.dump(node_descs, f) 64 | pickle.dump(triples, f) 65 | 66 | 67 | vocab = { 68 | 'kb_token_to_idx': kb_vocab, 69 | 'predicate_token_to_idx': pred2idx, 70 | 'word_token_to_idx': init_vocab(), 71 | 'answer_token_to_idx': {} 72 | } 73 | 74 | print('Load questions') 75 | train_set = 
json.load(open(os.path.join(args.input_dir, 'train.json'))) 76 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json'))) 77 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json'))) 78 | print('Build question vocabulary') 79 | word_counter = Counter() 80 | for question in train_set: 81 | tokens = word_tokenize(question['question'].lower()) 82 | word_counter.update(tokens) 83 | # add candidate answers 84 | for a in question['choices']: 85 | if a not in vocab['answer_token_to_idx']: 86 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 87 | for w, c in word_counter.items(): 88 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']: 89 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx']) 90 | # add candidate answers of val and test set 91 | for question in chain(val_set, test_set): 92 | for a in question['choices']: 93 | if a not in vocab['answer_token_to_idx']: 94 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 95 | 96 | if not os.path.isdir(args.output_dir): 97 | os.mkdir(args.output_dir) 98 | fn = os.path.join(args.output_dir, 'vocab.json') 99 | print('Dump vocab to {}'.format(fn)) 100 | with open(fn, 'w') as f: 101 | json.dump(vocab, f, indent=2) 102 | for k in vocab: 103 | print('{}:{}'.format(k, len(vocab[k]))) 104 | 105 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)): 106 | print('Encode {} set'.format(name)) 107 | outputs = encode_dataset(dataset, vocab) 108 | print('shape of questions, choices, answers:') 109 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f: 110 | for o in outputs: 111 | print(o.shape) 112 | pickle.dump(o, f) 113 | 114 | 115 | 116 | 117 | if __name__ == '__main__': 118 | main() 119 | 120 | -------------------------------------------------------------------------------- /RGCN/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | import numpy as np 5 | import random 6 | from tqdm import tqdm 7 | import os 8 | import pickle 9 | import argparse 10 | import shutil 11 | 12 | from utils.misc import MetricLogger, load_glove 13 | from .data import DataLoader 14 | from .model import QuesAnsByRGCN 15 | 16 | import logging 17 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 18 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 19 | rootLogger = logging.getLogger() 20 | 21 | 22 | def validate(model, data, device): 23 | model.eval() 24 | count, correct = 0, 0 25 | with torch.no_grad(): 26 | for batch in tqdm(data, total=len(data)): 27 | question, choices, answer = [x.to(device) for x in batch] 28 | logit = model(question) 29 | predict = logit.max(1)[1] 30 | correct += torch.eq(predict, answer).long().sum().item() 31 | count += len(answer) 32 | 33 | acc = correct / count 34 | logging.info('\nValid Accuracy: %.4f\n' % acc) 35 | return acc 36 | 37 | 38 | def train(args): 39 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 40 | 41 | logging.info("Create train_loader and val_loader.........") 42 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 43 | train_pt = os.path.join(args.input_dir, 'train.pt') 44 | val_pt = os.path.join(args.input_dir, 'val.pt') 45 | kb_pt = os.path.join(args.input_dir, 'kb.pt') 46 | train_loader = DataLoader(vocab_json, kb_pt, train_pt, args.batch_size, training=True) 47 | train_loader_large_bsz = 
DataLoader(vocab_json, kb_pt, train_pt, 128, training=True) 48 | val_loader = DataLoader(vocab_json, kb_pt, val_pt, 128) 49 | vocab = train_loader.vocab 50 | 51 | logging.info("Create model.........") 52 | node_descs = train_loader.node_descs.to(device) 53 | node_descs = node_descs[:, :args.max_desc] 54 | triples = train_loader.triples.to(device) 55 | triples = triples[:args.max_triple] 56 | model = QuesAnsByRGCN(vocab, 57 | node_descs, triples, 58 | args.dim_word, args.dim_hidden, args.dim_g) 59 | logging.info("Load pretrained word vectors.........") 60 | pretrained = load_glove(args.glove_pt, vocab['word_idx_to_token']) 61 | with torch.no_grad(): 62 | model.input_embeddings.weight.set_(torch.Tensor(pretrained)) 63 | model = model.to(device) 64 | logging.info(model) 65 | 66 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 67 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 20], gamma=0.1) 68 | criterion = nn.CrossEntropyLoss().to(device) 69 | 70 | # validate(model, val_loader, device) 71 | meters = MetricLogger(delimiter=" ") 72 | best_acc = 0 73 | logging.info("Start training........") 74 | for epoch in range(args.num_epoch): 75 | model.train() 76 | if epoch < 2: 77 | _train_loader = train_loader_large_bsz 78 | only_q = True 79 | else: 80 | _train_loader = train_loader 81 | only_q = False 82 | for iteration, batch in enumerate(_train_loader): 83 | iteration = iteration + 1 84 | 85 | question, choices, answer = [x.to(device) for x in batch] 86 | logits = model(question, only_q) 87 | loss = criterion(logits, answer) 88 | optimizer.zero_grad() 89 | loss.backward() 90 | optimizer.step() 91 | meters.update(loss=loss.item()) 92 | 93 | if iteration % (len(train_loader) // 1000) == 0: 94 | logging.info( 95 | meters.delimiter.join( 96 | [ 97 | "progress: {progress:.3f}", 98 | "{meters}", 99 | "lr: {lr:.6f}", 100 | ] 101 | ).format( 102 | progress=epoch + iteration / len(_train_loader), 103 | meters=str(meters), 104 | lr=optimizer.param_groups[0]["lr"], 105 | ) 106 | ) 107 | 108 | if epoch == args.num_epoch-1 or (epoch+1)%2 == 0: 109 | acc = validate(model, val_loader, device) 110 | else: 111 | acc = None 112 | scheduler.step() 113 | if acc and acc > best_acc: 114 | best_acc = acc 115 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc)) 116 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt')) 117 | 118 | 119 | 120 | def main(): 121 | parser = argparse.ArgumentParser() 122 | # input and output 123 | parser.add_argument('--input_dir', required=True) 124 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 125 | parser.add_argument('--glove_pt', required=True) 126 | 127 | # training parameters 128 | parser.add_argument('--lr', default=0.001, type=float) 129 | parser.add_argument('--weight_decay', default=1e-5, type=float) 130 | parser.add_argument('--num_epoch', default=40, type=int) 131 | parser.add_argument('--batch_size', default=6, type=int) 132 | parser.add_argument('--seed', type=int, default=666, help='random seed') 133 | # model hyperparameters 134 | parser.add_argument('--dim_word', default=300, type=int) 135 | parser.add_argument('--dim_hidden', default=512, type=int) 136 | parser.add_argument('--dim_g', default=32, type=int) 137 | parser.add_argument('--max_desc', default=20, type=int) 138 | parser.add_argument('--max_triple', default=200000, type=int) 139 | args = parser.parse_args() 140 | 141 | # make logging.info display into both shell and 
file 142 | if not os.path.exists(args.save_dir): 143 | os.makedirs(args.save_dir) 144 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt')) 145 | fileHandler.setFormatter(logFormatter) 146 | rootLogger.addHandler(fileHandler) 147 | # args display 148 | for k, v in vars(args).items(): 149 | logging.info(k + ':' + str(v)) 150 | 151 | # set random seed 152 | torch.manual_seed(args.seed) 153 | 154 | train(args) 155 | 156 | if __name__ == '__main__': 157 | main() 158 | 159 | -------------------------------------------------------------------------------- /SPARQL/README.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | - python3 3 | - rdflib==4.2.2 4 | --- 5 | **Note:** 6 | After installing rdflib via `pip`, `anaconda`, or some other tool, we need to fix a couple of its bugs. 7 | 8 | First, find your rdflib location. One possible way is to run the following code in IPython: 9 | ``` 10 | import rdflib 11 | rdflib.__file__ 12 | ``` 13 | On my machine it returns `~/anaconda3/lib/python3.7/site-packages/rdflib/__init__.py`, so I enter the folder `~/anaconda3/lib/python3.7/site-packages/rdflib`. 14 | 15 | Then open `plugins/sparql/parser.py`, find *Line 68*, and replace its code with 16 | ``` 17 | if i + 1 < l and (not isinstance(terms[i + 1], str) or terms[i + 1] not in ".,;"): 18 | ``` 19 | Remember to keep the original indentation. 20 | Note that *Line 67* is a comment of `# is this bnode the subject of more triplets?`. If your line number is different from mine, you can locate the target line by searching for this comment. 21 | 22 | Finally, open `plugins/serializers/turtle.py`, find *Line 328*, and change `use_plain=True` to `use_plain=False`. 23 | 24 | --- 25 | 26 | - SPARQLWrapper==1.8.4 27 | 28 | --- 29 | **Note:** 30 | When installing `SPARQLWrapper` with `pip`, it may automatically install another package, `keepalive`. You can check whether it is in your environment with 31 | ``` 32 | pip show keepalive 33 | ``` 34 | 35 | If it is installed, it will cause problems when we execute a large number of SPARQL queries; specifically, the available ports will be exhausted. So we need to disable the `keepalive` package; it is okay to remove it directly. 36 | ``` 37 | pip uninstall keepalive 38 | ``` 39 | 40 | --- 41 | 42 | - Virtuoso backend; refer to the next section 43 | 44 | ## How to install virtuoso backend 45 | The virtuoso backend starts up a web service; we can import our kb into it and then execute SPARQL queries via network requests. We install virtuoso on an Ubuntu 16.04 system. The specific steps follow. 46 | 47 | 1. Download and install virtuoso into the system. 48 | ``` 49 | git clone https://github.com/openlink/virtuoso-opensource.git Virtuoso-Opensource 50 | cd Virtuoso-Opensource 51 | git checkout stable/7 52 | sudo apt-get install libtool gawk gperf autoconf automake libtool flex bison m4 make openssl libssl-dev 53 | sudo ./autogen.sh 54 | sudo ./configure 55 | sudo make 56 | sudo make install 57 | ``` 58 | 59 | 2. Create a new user for the virtuoso service 60 | ``` 61 | sudo useradd virtuoso --home /usr/local/virtuoso-opensource 62 | sudo chown -R virtuoso /usr/local/virtuoso-opensource 63 | ``` 64 | 65 | 3. Modify some necessary configs: 66 | ``` 67 | cd /usr/local/virtuoso-opensource/var/lib/virtuoso/db 68 | sudo vim virtuoso.ini 69 | ``` 70 | Find the item `CheckpointInterval` and change its value from the default 60 to 0, to avoid the automatic checkpoint process, which would cause 404 errors. 71 | 72 | 4. 
Start up the virtuoso service: 73 | ``` 74 | sudo -H -u virtuoso ../../../../bin/virtuoso-t -f & 75 | ``` 76 | Now you can access the service via the default port 8890. 77 | Enter `[ip]:8890` in a browser and you will see the virtuoso service page. 78 | 79 | 5. Now we can import our kb into virtuoso. Before that, we need to convert our kb to `ttl` format and move it to the proper location: 80 | ``` 81 | python sparql_engine.py --kb_path ./dataset/kb.json --ttl_path ./dataset/kb.ttl 82 | sudo chmod 777 ./dataset/kb.ttl 83 | sudo mv ./dataset/kb.ttl /usr/local/virtuoso-opensource/share/virtuoso/vad 84 | ``` 85 | 86 | 6. Enter the interactive terminal of virtuoso: 87 | ``` 88 | cd /usr/local/virtuoso-opensource/bin 89 | sudo ./isql 90 | ``` 91 | 92 | 7. Import our kb by executing these commands in the terminal: 93 | ``` 94 | SPARQL CREATE GRAPH <[graph_name]>; 95 | SPARQL CLEAR GRAPH <[graph_name]>; 96 | delete from db.dba.load_list; 97 | ld_dir('/usr/local/virtuoso-opensource/share/virtuoso/vad', 'kb.ttl', '[graph_name]'); 98 | rdf_loader_run(); 99 | select * from DB.DBA.load_list; 100 | exit; 101 | ``` 102 | `[graph_name]` could be any legal string, such as *KQAPro*. 103 | The import has succeeded if `rdf_loader_run()` takes about 10 seconds. 104 | 105 | 106 | ## How to run 107 | 1. Follow the last section: start up the virtuoso service and import `kb.ttl`. Then open `sparql_engine.py` and find the lines 108 | ``` 109 | virtuoso_address = "http://127.0.0.1:8890/sparql" 110 | virtuoso_graph_uri = 'sjx' 111 | ``` 112 | Change `virtuoso_address` to your service URL (you can visit it in your browser to check whether it is valid) and change `virtuoso_graph_uri` to your `[graph_name]`. 113 | 2. Preprocess the training data 114 | ``` 115 | python -m SPARQL.preprocess --input_dir ./dataset --output_dir <output_dir> 116 | cp ./dataset/kb.json <output_dir> 117 | ``` 118 | 3. Train 119 | ``` 120 | python -m SPARQL.train --input_dir <output_dir> --save_dir <save_dir> 121 | ``` 122 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order. 
123 | ``` 124 | python -m SPARQL.predict --input_dir <output_dir> --save_dir <save_dir> 125 | ``` 126 | -------------------------------------------------------------------------------- /SPARQL/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | from utils.misc import invert_dict 5 | 6 | 7 | def load_vocab(path): 8 | vocab = json.load(open(path)) 9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx']) 10 | vocab['sparql_idx_to_token'] = invert_dict(vocab['sparql_token_to_idx']) 11 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) 12 | return vocab 13 | 14 | def collate(batch): 15 | batch = list(zip(*batch)) 16 | question = torch.stack(batch[0]) 17 | choices = torch.stack(batch[1]) 18 | if batch[-1][0] is None: 19 | sparql, answer = None, None 20 | else: 21 | sparql = torch.stack(batch[2]) 22 | answer = torch.cat(batch[3]) 23 | return question, choices, sparql, answer 24 | 25 | 26 | class Dataset(torch.utils.data.Dataset): 27 | def __init__(self, inputs): 28 | self.questions, self.sparqls, self.choices, self.answers = inputs 29 | self.is_test = len(self.answers)==0 30 | 31 | 32 | def __getitem__(self, index): 33 | question = torch.LongTensor(self.questions[index]) 34 | choices = torch.LongTensor(self.choices[index]) 35 | if self.is_test: 36 | sparql = None 37 | answer = None 38 | else: 39 | sparql = torch.LongTensor(self.sparqls[index]) 40 | answer = torch.LongTensor([self.answers[index]]) 41 | return question, choices, sparql, answer 42 | 43 | 44 | def __len__(self): 45 | return len(self.questions) 46 | 47 | 48 | class DataLoader(torch.utils.data.DataLoader): 49 | def __init__(self, vocab_json, question_pt, batch_size, training=False): 50 | vocab = load_vocab(vocab_json) 51 | if training: 52 | print('#vocab of word/sparql/answer: %d/%d/%d' % 53 | (len(vocab['word_token_to_idx']), len(vocab['sparql_token_to_idx']), len(vocab['answer_token_to_idx']))) 54 | 55 | inputs = [] 56 | with open(question_pt, 'rb') as f: 57 | for _ in range(4): 58 | inputs.append(pickle.load(f)) 59 | dataset = Dataset(inputs) 60 | 61 | super().__init__( 62 | dataset, 63 | batch_size=batch_size, 64 | shuffle=training, 65 | collate_fn=collate, 66 | ) 67 | self.vocab = vocab 68 | 69 | -------------------------------------------------------------------------------- /SPARQL/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from utils.BiGRU import GRU, BiGRU 5 | 6 | class SPARQLParser(nn.Module): 7 | def __init__(self, vocab, dim_word, dim_hidden, max_dec_len): 8 | super().__init__() 9 | num_words = len(vocab['word_token_to_idx']) 10 | num_sparql = len(vocab['sparql_token_to_idx']) 11 | self.vocab = vocab 12 | self.dim_word = dim_word 13 | self.dim_hidden = dim_hidden 14 | self.max_dec_len = max_dec_len 15 | 16 | self.word_embeddings = nn.Embedding(num_words, dim_word) 17 | self.word_dropout = nn.Dropout(0.3) 18 | self.question_encoder = GRU(dim_word, dim_hidden, num_layers=2, dropout=0.2) 19 | 20 | self.sparql_embeddings = nn.Embedding(num_sparql, dim_word) 21 | self.decoder = GRU(dim_word, dim_hidden, num_layers=2, dropout=0.2) 22 | 23 | self.sparql_classifier = nn.Sequential( 24 | nn.Linear(dim_hidden, 1024), 25 | nn.ReLU(), 26 | nn.Linear(1024, num_sparql), 27 | ) 28 | 29 | self.att_lin = nn.Linear(dim_hidden, dim_hidden) 30 | 31 | for m in self.modules(): 32 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 33 
| nn.init.kaiming_normal_(m.weight) 34 | if m.bias is not None: 35 | m.bias.data.zero_() 36 | 37 | def forward(self, questions, sparqls=None): 38 | """ 39 | Args: 40 | questions [bsz, max_q] 41 | sparqls [bsz, max_s] 42 | Return: 43 | if sparqls are given, then return losses 44 | else, return predicted sparqls 45 | """ 46 | question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD> 47 | q_word_emb = self.word_dropout(self.word_embeddings(questions)) 48 | q_word_h, q_embeddings, q_hn = self.question_encoder(q_word_emb, question_lens) 49 | # [bsz, max_q, dim_h], [bsz, dim_h], [num_layers, bsz, dim_h] 50 | 51 | if sparqls is None: # during inference 52 | return self.inference(q_word_h, q_embeddings, q_hn) 53 | else: 54 | return self.train_phase(q_word_h, q_embeddings, q_hn, sparqls) 55 | 56 | 57 | def train_phase(self, q_word_h, q_embeddings, q_hn, sparqls): 58 | bsz, max_s = sparqls.size(0), sparqls.size(1) 59 | device = sparqls.device 60 | sparql_lens = max_s - sparqls.eq(0).long().sum(dim=1) # 0 means <PAD> 61 | sparql_mask = sparqls.ne(0).long() 62 | 63 | s_word_emb = self.word_dropout(self.sparql_embeddings(sparqls)) 64 | s_word_h, _, _ = self.decoder(s_word_emb, sparql_lens, h_0=q_hn) # [bsz, max_s, dim_h] 65 | # attention over question words 66 | attn = torch.softmax(torch.bmm(s_word_h, q_word_h.permute(0, 2, 1)), dim=2) # [bsz, max_s, max_q] 67 | attn_word_h = torch.bmm(attn, q_word_h) # [bsz, max_s, dim_h] 68 | # sum up 69 | s_word_h = s_word_h + attn_word_h # [bsz, max_s, dim_h] 70 | 71 | criterion = nn.CrossEntropyLoss().to(device) 72 | logit = self.sparql_classifier(s_word_h) # [bsz, max_s, num_sparql] 73 | loss = criterion(logit.permute(0, 2, 1)[:,:,:-1], sparqls[:,1:]) # remember to shift the gt 74 | 75 | return loss 76 | 77 | 78 | def inference(self, q_word_h, q_embeddings, q_hn): 79 | """ 80 | Predict sparqls 81 | """ 82 | bsz = q_word_h.size(0) 83 | device = q_word_h.device 84 | start_id = self.vocab['sparql_token_to_idx']['<START>'] 85 | end_id = self.vocab['sparql_token_to_idx']['<END>'] 86 | 87 | latest_sparql = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ] 88 | last_h = q_hn 89 | finished = torch.zeros((bsz,)).byte().to(device) # record whether <END> has been produced 90 | 91 | # store predictions at each step 92 | sparqls = [latest_sparql] 93 | 94 | for i in range(self.max_dec_len): 95 | s_word_emb = self.word_dropout(self.sparql_embeddings(latest_sparql)).unsqueeze(1) # [bsz, 1, dim_w] 96 | s_word_h, last_h = self.decoder.forward_one_step(s_word_emb, last_h) # [bsz, 1, dim_h] 97 | # attention over question words 98 | attn = torch.softmax(torch.bmm(s_word_h, q_word_h.permute(0, 2, 1)), dim=2) # [bsz, 1, max_q] 99 | attn_word_h = torch.bmm(attn, q_word_h) # [bsz, 1, dim_h] 100 | # sum up 101 | s_word_h = s_word_h + attn_word_h # [bsz, 1, dim_h] 102 | 103 | logit = self.sparql_classifier(s_word_h).squeeze(1) # [bsz, num_sparql] 104 | latest_sparql = torch.argmax(logit, dim=1) # [bsz, ] 105 | sparqls.append(latest_sparql) 106 | 107 | finished = finished | latest_sparql.eq(end_id).byte() 108 | if finished.sum().item() == bsz: 109 | # print('finished at step {}'.format(i)) 110 | break 111 | 112 | sparqls = torch.stack(sparqls, dim=1) # [bsz, max_s] 113 | 114 | return sparqls 115 | -------------------------------------------------------------------------------- /SPARQL/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import json 5 | from tqdm import tqdm 6 | 7 | from utils.load_kb 
import DataForSPARQL 8 | from .data import DataLoader 9 | from .model import SPARQLParser 10 | from .sparql_engine import get_sparql_answer 11 | from .preprocess import postprocess_sparql_tokens 12 | 13 | import warnings 14 | warnings.simplefilter("ignore") # hide warnings that are caused by invalid sparql queries 15 | 16 | 17 | def test(args): 18 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | 20 | print('load test data') 21 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 22 | test_pt = os.path.join(args.input_dir, 'test.pt') 23 | data = DataLoader(vocab_json, test_pt, 128, training=False) 24 | vocab = data.vocab 25 | kb = DataForSPARQL(os.path.join(args.input_dir, 'kb.json')) 26 | 27 | print('load model') 28 | model = SPARQLParser(vocab, args.dim_word, args.dim_hidden, args.max_dec_len) 29 | model = model.to(device) 30 | model.load_state_dict(torch.load(os.path.join(args.save_dir, 'model.pt'))) 31 | 32 | f = open(os.path.join(args.save_dir, 'predict.txt'), 'w') 33 | for batch in tqdm(data, total=len(data)): 34 | question, choices, sparql, answer = batch 35 | question = question.to(device) 36 | pred_sparql = model(question) 37 | 38 | pred_sparql = pred_sparql.cpu().numpy().tolist() 39 | for s in pred_sparql: 40 | s = [vocab['sparql_idx_to_token'][i] for i in s] 41 | end_idx = len(s) 42 | if '<END>' in s: 43 | end_idx = s.index('<END>') 44 | s = ' '.join(s[1:end_idx]) 45 | s = postprocess_sparql_tokens(s) 46 | answer = str(get_sparql_answer(s, kb)) 47 | f.write(answer + '\n') 48 | f.close() 49 | 50 | 51 | 52 | def main(): 53 | parser = argparse.ArgumentParser() 54 | # input and output 55 | parser.add_argument('--input_dir', required=True) 56 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 57 | 58 | # model hyperparameters 59 | parser.add_argument('--dim_word', default=300, type=int) 60 | parser.add_argument('--dim_hidden', default=1024, type=int) 61 | parser.add_argument('--max_dec_len', default=100, type=int) 62 | args = parser.parse_args() 63 | 64 | test(args) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /SPARQL/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | We need the last function to help extract the final answer of SPARQL, used in check_sparql 3 | """ 4 | 5 | import os 6 | import json 7 | import pickle 8 | import argparse 9 | import numpy as np 10 | from nltk import word_tokenize 11 | from collections import Counter 12 | from itertools import chain 13 | from tqdm import tqdm 14 | import re 15 | 16 | from utils.misc import init_vocab 17 | 18 | def tokenize_sparql(s): 19 | # separate punctuations 20 | s = s.replace('"', ' " ').replace('^^', ' ^^ ') 21 | # NOTE: after decoding, these extra spaces must be removed 22 | # this may cause some mistakes, but the ratio is very small, about one in a thousand 23 | return s.split() 24 | 25 | def postprocess_sparql_tokens(s): 26 | # organize the predicted sparql tokens into a valid query 27 | s = s.replace(' ^^ ', '^^') 28 | skip_idxs = set() 29 | for i in range(len(s)): 30 | if s[i] == '"': 31 | if i > 2 and s[i-1]==' ' and s[i-2] not in {'>'}: 32 | skip_idxs.add(i-1) 33 | if i < len(s)-2 and s[i+1]==' ' and s[i+2] not in {'<'}: 34 | skip_idxs.add(i+1) 35 | s = ''.join([s[i] for i in range(len(s)) if i not in skip_idxs]) 36 | return s 37 | 38 | def encode_dataset(dataset, vocab, test=False): 39 | questions = [] 40 | sparqls = [] 41 | choices 
= [] 42 | answers = [] 43 | for question in tqdm(dataset): 44 | q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>']) 45 | for w in word_tokenize(question['question'].lower())] 46 | questions.append(q) 47 | 48 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']] 49 | choices.append(_) 50 | 51 | if test: 52 | continue 53 | 54 | _ = [vocab['sparql_token_to_idx'].get(w, vocab['sparql_token_to_idx']['<UNK>']) 55 | for w in tokenize_sparql(question['sparql'])] 56 | # wrap with <START> and <END> 57 | _ = [vocab['sparql_token_to_idx']['<START>']] + _ + [vocab['sparql_token_to_idx']['<END>']] 58 | sparqls.append(_) 59 | 60 | if 'answer' in question: 61 | answers.append(vocab['answer_token_to_idx'].get(question['answer'])) 62 | 63 | # question padding 64 | max_len = max(len(q) for q in questions) 65 | for q in questions: 66 | while len(q) < max_len: 67 | q.append(vocab['word_token_to_idx']['<PAD>']) 68 | if not test: 69 | # sparql padding 70 | max_len = max(len(s) for s in sparqls) 71 | for s in sparqls: 72 | while len(s) < max_len: 73 | s.append(vocab['sparql_token_to_idx']['<PAD>']) 74 | 75 | questions = np.asarray(questions, dtype=np.int32) 76 | sparqls = np.asarray(sparqls, dtype=np.int32) 77 | choices = np.asarray(choices, dtype=np.int32) 78 | answers = np.asarray(answers, dtype=np.int32) 79 | return questions, sparqls, choices, answers 80 | 81 | 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--input_dir', required=True) 86 | parser.add_argument('--output_dir', required=True) 87 | parser.add_argument('--min_cnt', type=int, default=1) 88 | args = parser.parse_args() 89 | 90 | 91 | print('Build kb vocabulary') 92 | vocab = { 93 | 'word_token_to_idx': init_vocab(), 94 | 'sparql_token_to_idx': init_vocab(), 95 | 'answer_token_to_idx': {} 96 | } 97 | print('Load questions') 98 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json'))) 99 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json'))) 100 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json'))) 101 | print('Build question vocabulary') 102 | word_counter = Counter() 103 | for question in train_set: 104 | tokens = word_tokenize(question['question'].lower()) 105 | word_counter.update(tokens) 106 | # add candidate answers 107 | for a in question['choices']: 108 | if a not in vocab['answer_token_to_idx']: 109 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 110 | # add sparql 111 | for a in tokenize_sparql(question['sparql']): 112 | if a not in vocab['sparql_token_to_idx']: 113 | vocab['sparql_token_to_idx'][a] = len(vocab['sparql_token_to_idx']) 114 | 115 | # filter low-frequency words 116 | for w, c in word_counter.items(): 117 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']: 118 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx']) 119 | # add candidate answers of val and test set 120 | for question in chain(val_set, test_set): 121 | for a in question['choices']: 122 | if a not in vocab['answer_token_to_idx']: 123 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx']) 124 | 125 | 126 | if not os.path.isdir(args.output_dir): 127 | os.mkdir(args.output_dir) 128 | fn = os.path.join(args.output_dir, 'vocab.json') 129 | print('Dump vocab to {}'.format(fn)) 130 | with open(fn, 'w') as f: 131 | json.dump(vocab, f, indent=2) 132 | for k in vocab: 133 | print('{}:{}'.format(k, len(vocab[k]))) 134 | 135 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)): 136 | print('Encode {} 
set'.format(name)) 137 | outputs = encode_dataset(dataset, vocab, name=='test') 138 | assert len(outputs) == 4 139 | print('shape of questions, sparqls, choices, answers:') 140 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f: 141 | for o in outputs: 142 | print(o.shape) 143 | pickle.dump(o, f) 144 | 145 | 146 | 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | -------------------------------------------------------------------------------- /SPARQL/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | import json 8 | from tqdm import tqdm 9 | from datetime import date 10 | 11 | from utils.misc import MetricLogger 12 | from utils.load_kb import DataForSPARQL 13 | from .data import DataLoader 14 | from .model import SPARQLParser 15 | from .sparql_engine import get_sparql_answer 16 | from .preprocess import postprocess_sparql_tokens 17 | 18 | import logging 19 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 20 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 21 | rootLogger = logging.getLogger() 22 | import warnings 23 | warnings.simplefilter("ignore") # hide warnings that are caused by invalid sparql queries 24 | 25 | def whether_equal(answer, pred): 26 | """ 27 | check whether the two arguments are equal as attribute values 28 | """ 29 | def truncate_float(x): 30 | # convert answer from '100.0 meters' to '100 meters' 31 | try: 32 | v, *u = x.split() 33 | v = float(v) 34 | if v - int(v) < 1e-5: 35 | v = int(v) 36 | if len(u) == 0: 37 | x = str(v) 38 | else: 39 | x = '{} {}'.format(str(v), ' '.join(u)) 40 | except: 41 | pass 42 | return x 43 | 44 | def equal_as_date(x, y): 45 | # check whether x and y are equal as type of date or year 46 | try: 47 | x_split = x.split('-') 48 | y_split = y.split('-') 49 | if len(x_split) == 3: 50 | x = date(int(x_split[0]), int(x_split[1]), int(x_split[2])) 51 | else: 52 | x = int(x) 53 | if len(y_split) == 3: 54 | y = date(int(y_split[0]), int(y_split[1]), int(y_split[2])) 55 | else: 56 | y = int(y) 57 | if isinstance(x, date) and isinstance(y, date): 58 | return x == y 59 | else: 60 | x = x.year if isinstance(x, date) else x 61 | y = y.year if isinstance(y, date) else y 62 | return x == y 63 | except: 64 | return False 65 | 66 | answer = truncate_float(answer) 67 | pred = truncate_float(pred) 68 | if equal_as_date(answer, pred): 69 | return True 70 | else: 71 | return answer == pred 72 | 73 | 74 | def validate(args, kb, model, data, device): 75 | model.eval() 76 | count, correct = 0, 0 77 | with torch.no_grad(): 78 | for batch in tqdm(data, total=len(data)): 79 | question, choices, sparql, answer = [x.to(device) for x in batch] 80 | pred_sparql = model(question) 81 | 82 | answer, pred_sparql = [x.cpu().numpy().tolist() for x in (answer, pred_sparql)] 83 | for a, s in zip(answer, pred_sparql): 84 | given_answer = data.vocab['answer_idx_to_token'][a] 85 | s = [data.vocab['sparql_idx_to_token'][i] for i in s] 86 | end_idx = len(s) 87 | if '<END>' in s: 88 | end_idx = s.index('<END>') 89 | s = ' '.join(s[1:end_idx]) 90 | s = postprocess_sparql_tokens(s) 91 | pred_answer = get_sparql_answer(s, kb) 92 | is_match = whether_equal(given_answer, pred_answer) 93 | if is_match: 94 | correct += 1 95 | count += len(answer) 96 | acc = correct / count 97 | logging.info('\nValid Accuracy: %.4f\n' % acc) 98 | return 
acc 99 | 100 | def test_sparql(args): 101 | # check whether the SPARQL engine is correct, with the training set 102 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 103 | train_pt = os.path.join(args.input_dir, 'train.pt') 104 | data = DataLoader(vocab_json, train_pt, args.batch_size, training=False) 105 | kb = DataForSPARQL(os.path.join(args.input_dir, 'kb.json')) 106 | 107 | count, correct = 0, 0 108 | for batch in tqdm(data, total=len(data)): 109 | question, choices, sparql, answer = batch 110 | pred_sparql = sparql 111 | 112 | answer = answer.cpu().numpy().tolist() 113 | pred_sparql = pred_sparql.cpu().numpy().tolist() 114 | for a, s in zip(answer, pred_sparql): 115 | given_answer = data.vocab['answer_idx_to_token'][a] 116 | s = [data.vocab['sparql_idx_to_token'][i] for i in s] 117 | end_idx = len(s) 118 | if '<END>' in s: 119 | end_idx = s.index('<END>') 120 | s = ' '.join(s[1:end_idx]) 121 | s = postprocess_sparql_tokens(s) 122 | pred_answer = get_sparql_answer(s, kb) 123 | is_match = whether_equal(given_answer, pred_answer) 124 | count += 1 125 | if is_match: 126 | correct += 1 127 | else: 128 | print(given_answer, pred_answer) 129 | 130 | def train(args): 131 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 132 | 133 | logging.info("Create train_loader and val_loader.........") 134 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 135 | train_pt = os.path.join(args.input_dir, 'train.pt') 136 | val_pt = os.path.join(args.input_dir, 'val.pt') 137 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True) 138 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size) 139 | vocab = train_loader.vocab 140 | kb = DataForSPARQL(os.path.join(args.input_dir, 'kb.json')) 141 | 142 | logging.info("Create model.........") 143 | model = SPARQLParser(vocab, args.dim_word, args.dim_hidden, args.max_dec_len) 144 | model = model.to(device) 145 | logging.info(model) 146 | 147 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 148 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 50], gamma=0.1) 149 | 150 | # validate(args, kb, model, val_loader, device) 151 | meters = MetricLogger(delimiter="  ") 152 | best_acc = 0 153 | logging.info("Start training........") 154 | for epoch in range(args.num_epoch): 155 | model.train() 156 | for iteration, batch in enumerate(train_loader): 157 | iteration = iteration + 1 158 | 159 | question, choices, sparql, answer = [x.to(device) for x in batch] 160 | loss = model(question, sparql) 161 | optimizer.zero_grad() 162 | loss.backward() 163 | optimizer.step() 164 | meters.update(loss=loss.item()) 165 | 166 | if iteration % (len(train_loader) // 100) == 0: 167 | logging.info( 168 | meters.delimiter.join( 169 | [ 170 | "progress: {progress:.3f}", 171 | "{meters}", 172 | "lr: {lr:.6f}", 173 | ] 174 | ).format( 175 | progress=epoch + iteration / len(train_loader), 176 | meters=str(meters), 177 | lr=optimizer.param_groups[0]["lr"], 178 | ) 179 | ) 180 | 181 | acc = validate(args, kb, model, val_loader, device) 182 | scheduler.step() 183 | if acc and acc > best_acc: 184 | best_acc = acc 185 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc)) 186 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt')) 187 | 188 | 189 | def main(): 190 | parser = argparse.ArgumentParser() 191 | # input and output 192 | parser.add_argument('--input_dir', required=True) 193 | parser.add_argument('--save_dir', required=True, help='path to save 
checkpoints and logs') 194 | 195 | # training parameters 196 | parser.add_argument('--lr', default=0.001, type=float) 197 | parser.add_argument('--weight_decay', default=1e-5, type=float) 198 | parser.add_argument('--num_epoch', default=100, type=int) 199 | parser.add_argument('--batch_size', default=64, type=int) 200 | parser.add_argument('--seed', type=int, default=666, help='random seed') 201 | # model hyperparameters 202 | parser.add_argument('--dim_word', default=300, type=int) 203 | parser.add_argument('--dim_hidden', default=1024, type=int) 204 | parser.add_argument('--max_dec_len', default=100, type=int) 205 | args = parser.parse_args() 206 | 207 | # make logging.info display into both shell and file 208 | if os.path.isdir(args.save_dir): 209 | shutil.rmtree(args.save_dir) 210 | os.mkdir(args.save_dir) 211 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt')) 212 | fileHandler.setFormatter(logFormatter) 213 | rootLogger.addHandler(fileHandler) 214 | # args display 215 | for k, v in vars(args).items(): 216 | logging.info(k+':'+str(v)) 217 | 218 | # set random seed 219 | torch.manual_seed(args.seed) 220 | 221 | train(args) 222 | # test_sparql(args) 223 | 224 | 225 | if __name__ == '__main__': 226 | main() 227 | -------------------------------------------------------------------------------- /SRN/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | from utils.misc import invert_dict 5 | 6 | 7 | def load_vocab(path): 8 | vocab = json.load(open(path)) 9 | vocab['id2word'] = invert_dict(vocab['word2id']) 10 | vocab['id2entity'] = invert_dict(vocab['entity2id']) 11 | vocab['id2relation'] = invert_dict(vocab['relation2id']) 12 | # vocab['entity2name'] = invert_dict(vocab['name2entity']) 13 | return vocab 14 | 15 | def collate(batch): 16 | batch = list(zip(*batch)) 17 | question, topic_entity, answer = list(map(torch.stack, batch)) 18 | return question, topic_entity, answer 19 | 20 | 21 | class Dataset(torch.utils.data.Dataset): 22 | def __init__(self, inputs): 23 | self.questions, self.topic_entities, self.answers = inputs 24 | print(self.questions.shape) 25 | print(self.topic_entities.shape) 26 | print(self.answers.shape) 27 | 28 | def __getitem__(self, index): 29 | question = torch.LongTensor(self.questions[index]) 30 | topic_entity = torch.LongTensor(self.topic_entities[index]) 31 | answer = torch.LongTensor(self.answers[index]) 32 | return question, topic_entity, answer 33 | 34 | 35 | def __len__(self): 36 | return len(self.questions) 37 | 38 | 39 | class DataLoader(torch.utils.data.DataLoader): 40 | def __init__(self, vocab_json, question_pt, batch_size, training=False): 41 | vocab = load_vocab(vocab_json) 42 | 43 | inputs = [] 44 | with open(question_pt, 'rb') as f: 45 | for _ in range(3): 46 | inputs.append(pickle.load(f)) 47 | dataset = Dataset(inputs) 48 | 49 | super().__init__( 50 | dataset, 51 | batch_size=batch_size, 52 | shuffle=training, 53 | collate_fn=collate, 54 | ) 55 | self.vocab = vocab 56 | 57 | -------------------------------------------------------------------------------- /SRN/knowledge_graph.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import pickle 4 | from collections import defaultdict 5 | import torch 6 | import torch.nn as nn 7 | from utils.misc import * 8 | 9 | class KnowledgeGraph(nn.Module): 10 | def __init__(self, args, vocab): 11 | super(KnowledgeGraph, 
self).__init__() 12 | self.args = args 13 | self.entity2id, self.id2entity = vocab['entity2id'], vocab['id2entity'] 14 | self.relation2id, self.id2relation = vocab['relation2id'], vocab['id2relation'] 15 | self.adj_list = None 16 | self.action_space = None 17 | self.action_mask = None 18 | self.bandwidth = args.bandwidth 19 | with open(os.path.join(args.input_dir, 'adj_list.pt'), 'rb') as f: 20 | self.adj_list = pickle.load(f) 21 | self.vectorize_action_space() 22 | self.relation_embeddings = nn.Embedding(self.num_relations, args.dim_hidden) 23 | nn.init.xavier_normal_(self.relation_embeddings.weight) 24 | 25 | 26 | def vectorize_action_space(self): 27 | def load_pgrk_score(): 28 | pgrk_scores = defaultdict(float) 29 | with open(os.path.join(self.args.input_dir, 'pgrk.txt')) as f: 30 | for line in f: 31 | e, score = line.strip().split(':') 32 | pgrk_scores[(int)(e)] = float(score) 33 | return pgrk_scores 34 | 35 | page_rank_scores = load_pgrk_score() 36 | 37 | def get_action_space(e1): 38 | action_space = [] 39 | if e1 in self.adj_list: 40 | for r in self.adj_list[e1]: 41 | targets = self.adj_list[e1][r] 42 | for e2 in targets: 43 | action_space.append((r, e2)) 44 | if len(action_space) + 1 >= self.bandwidth: 45 | # Base graph pruning 46 | sorted_action_space = \ 47 | sorted(action_space, key=lambda x: page_rank_scores[x[1]], reverse=True) 48 | action_space = sorted_action_space[:self.bandwidth] 49 | action_space.insert(0, (NO_OP_RELATION_ID, e1)) 50 | return action_space 51 | 52 | def vectorize_action_space(action_space_list, action_space_size): 53 | bucket_size = len(action_space_list) 54 | r_space = torch.zeros(bucket_size, action_space_size) + self.dummy_r 55 | e_space = torch.zeros(bucket_size, action_space_size) + self.dummy_e 56 | action_mask = torch.zeros(bucket_size, action_space_size) 57 | for i, action_space in enumerate(action_space_list): 58 | for j, (r, e) in enumerate(action_space): 59 | r_space[i, j] = r 60 | e_space[i, j] = e 61 | action_mask[i, j] = 1 62 | return (r_space.long(), e_space.long()), action_mask 63 | 64 | self.action_space_buckets = {} 65 | action_space_buckets_discrete = defaultdict(list) 66 | self.entity2bucketid = torch.zeros(self.num_entities, 2).long() 67 | num_facts_saved_in_action_table = 0 68 | for e1 in range(self.num_entities): 69 | action_space = get_action_space(e1) 70 | key = int(len(action_space) / self.args.bucket_interval) + 1 71 | self.entity2bucketid[e1, 0] = key 72 | self.entity2bucketid[e1, 1] = len(action_space_buckets_discrete[key]) 73 | action_space_buckets_discrete[key].append(action_space) 74 | num_facts_saved_in_action_table += len(action_space) 75 | print('Sanity check: {} facts saved in action table'.format(num_facts_saved_in_action_table - self.num_entities)) 76 | for key in action_space_buckets_discrete: 77 | self.action_space_buckets[key] = vectorize_action_space(action_space_buckets_discrete[key], key * self.args.bucket_interval) 78 | print('Vectorize action spaces bucket {} with size {} finished'.format(key, len(self.action_space_buckets[key][-1]))) 79 | print('Sanity check: {} action space bucket in total'.format(len(self.action_space_buckets))) 80 | 81 | 82 | @property 83 | def num_entities(self): 84 | return len(self.entity2id) 85 | 86 | @property 87 | def num_relations(self): 88 | return len(self.relation2id) 89 | 90 | @property 91 | def self_edge(self): 92 | return NO_OP_RELATION_ID 93 | 94 | @property 95 | def self_e(self): 96 | return NO_OP_ENTITY_ID 97 | 98 | @property 99 | def dummy_r(self): 100 | return 
DUMMY_RELATION_ID 101 | 102 | @property 103 | def dummy_e(self): 104 | return DUMMY_ENTITY_ID 105 | 106 | @property 107 | def dummy_start_r(self): 108 | return START_RELATION_ID 109 | -------------------------------------------------------------------------------- /SRN/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | from tqdm import tqdm 8 | 9 | from utils.misc import MetricLogger 10 | from SRN.data import DataLoader 11 | from SRN.model import SRN 12 | 13 | import logging 14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 15 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 16 | rootLogger = logging.getLogger() 17 | 18 | torch.set_num_threads(1) # avoid using multiple cpus 19 | 20 | def validate(args, vocab, model, data, device): 21 | def write(f, predict): 22 | predict = predict.squeeze().tolist() 23 | for i in predict: 24 | f.write(vocab['id2entity'][i] + '\n') 25 | model.eval() 26 | count, correct = 0, 0 27 | f1 = open(os.path.join(args.save_dir, 'predict.txt'), 'w') 28 | with torch.no_grad(): 29 | for batch in tqdm(data, total=len(data)): 30 | questions, topic_entities, answers = [x.to(device) for x in batch] 31 | predict = model(questions, topic_entities) 32 | 33 | pred_e2s = predict['pred_e2s'] 34 | pred_e2_scores = predict['pred_e2_scores'] 35 | search_traces = predict['search_traces'] 36 | pred_top_e2 = pred_e2s[:, 0].unsqueeze(-1) # [bsz, beam_size] => [bsz] => [bsz, 1] 37 | write(f1, pred_top_e2) 38 | correct += torch.any(pred_top_e2 == answers, dim=1).float().sum().item() 39 | count += len(answers) 40 | acc = correct / count 41 | f1.close() 42 | logging.info('\nValid Accuracy: %.4f\n' % acc) 43 | return acc 44 | 45 | def train(args): 46 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 47 | 48 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 49 | train_pt = os.path.join(args.input_dir, 'train.pt') 50 | val_pt = os.path.join(args.input_dir, 'val.pt') 51 | test_pt = os.path.join(args.input_dir, 'test.pt') 52 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True) 53 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size) 54 | test_loader = DataLoader(vocab_json, test_pt, args.batch_size) 55 | vocab = train_loader.vocab 56 | 57 | model = SRN(args, args.dim_word, args.dim_hidden, vocab) 58 | model.load_state_dict(torch.load(args.ckpt)) 59 | model = model.to(device) 60 | validate(args, vocab, model, test_loader, device) 61 | 62 | 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser() 67 | # input and output 68 | parser.add_argument('--input_dir', required=True) 69 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 70 | parser.add_argument('--glove_pt', default='/data/csl/resources/word2vec/glove.840B.300d.py36.pt') 71 | 72 | # training parameters 73 | parser.add_argument('--lr', default=0.001, type=float) 74 | parser.add_argument('--weight_decay', default=1e-5, type=float) 75 | parser.add_argument('--num_epoch', default=60, type=int) 76 | parser.add_argument('--batch_size', default=512, type=int) 77 | parser.add_argument('--seed', type=int, default=666, help='random seed') 78 | # model hyperparameters 79 | parser.add_argument('--dim_emb', default=300, type=int) 80 | parser.add_argument('--num_rollout_steps', default=3, type=int) 81 | 
parser.add_argument('--num_rollouts', default=10, type=int) 82 | parser.add_argument('--dim_word', default=300, type=int) 83 | parser.add_argument('--dim_hidden', default=300, type=int) 84 | parser.add_argument('--bucket_interval', default = 3, type = int) 85 | parser.add_argument('--opt', default = 'adam', type = str) 86 | parser.add_argument('--bandwidth', default = 100, type = int) 87 | parser.add_argument('--gamma', default = 0.95, type = float) 88 | parser.add_argument('--eta', default = 0.95, type = float) 89 | parser.add_argument('--beta', default = 0, type =float) 90 | parser.add_argument('--beam_size', default = 32, type = int) 91 | parser.add_argument('--log_name', default = 'log.txt', type = str) 92 | parser.add_argument('--model_name', default = 'model.pt', type = str) 93 | parser.add_argument('--rel', action = 'store_true') 94 | parser.add_argument('--ckpt', required=True) 95 | args = parser.parse_args() 96 | 97 | # set random seed 98 | torch.manual_seed(args.seed) 99 | 100 | train(args) 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /SRN/readme.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | - python3 3 | - pytorch>=1.2.0 4 | - nltk 5 | 6 | ## How to run 7 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading: 8 | ``` 9 | python -m utils.pickle_glove --input <glove.840B.300d.txt> --output <glove_pt> 10 | ``` 11 | This step can be skipped if you have already created the GloVe pickle file for another model. 12 | 13 | 2. Preprocess the training data 14 | ``` 15 | python -m SRN.preprocess --input_dir ./dataset --output_dir <output_dir> 16 | ``` 17 | 3. 
3. Train
```
python -m SRN.train --input_dir <output_dir> --save_dir <save_dir> --glove_pt <glove_pt>
```
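4. Predict answers of the test set with a trained checkpoint. This step is not spelled out in the original readme; the command below is a sketch inferred from the argument parser of `SRN/predict.py` (`--input_dir`, `--save_dir` and `--ckpt` are required). It writes one predicted entity per line to `predict.txt` under `--save_dir`:
```
python -m SRN.predict --input_dir <output_dir> --save_dir <save_dir> --ckpt <save_dir>/best_score_model.pt
```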
--------------------------------------------------------------------------------
/SRN/train.py:
--------------------------------------------------------------------------------
import os
import copy
import torch
import torch.optim as optim
import argparse
from tqdm import tqdm

from utils.misc import MetricLogger, load_glove
from SRN.data import DataLoader
from SRN.model import SRN

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
rootLogger = logging.getLogger()

torch.set_num_threads(1) # avoid using multiple cpus

def validate(model, data, device):
    model.eval()
    count, correct = 0, 0
    with torch.no_grad():
        for batch in tqdm(data, total=len(data)):
            questions, topic_entities, answers = [x.to(device) for x in batch]
            predict = model(questions, topic_entities)
            pred_e2s = predict['pred_e2s']
            pred_e2_scores = predict['pred_e2_scores']
            search_traces = predict['search_traces']
            pred_top_e2 = pred_e2s[:, 0].unsqueeze(-1)  # [bsz, beam_size] => [bsz] => [bsz, 1]
            correct += torch.any(pred_top_e2 == answers, dim=1).float().sum().item()
            count += len(answers)
    acc = correct / count
    logging.info('\nValid Accuracy: %.4f' % acc)
    return acc

def train(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    logging.info("Create train_loader, val_loader.........")
    vocab_json = os.path.join(args.input_dir, 'vocab.json')
    train_pt = os.path.join(args.input_dir, 'train.pt')
    val_pt = os.path.join(args.input_dir, 'val.pt')
    train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True)
    val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
    vocab = train_loader.vocab

    logging.info("Create model.........")
    model = SRN(args, args.dim_word, args.dim_hidden, vocab)
    logging.info("Load pretrained word vectors.........")
    pretrained = load_glove(args.glove_pt, vocab['id2word'])
    model.word_embeddings.weight.data = torch.Tensor(pretrained)
    model = model.to(device)
    logging.info(model)
    if args.opt == 'adam':
        optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
    elif args.opt == 'sgd':
        optimizer = optim.SGD(model.parameters(), args.lr, weight_decay=args.weight_decay)
    elif args.opt == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), args.lr, weight_decay=args.weight_decay)
    else:
        raise NotImplementedError
    # scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[3], gamma=0.1)

    validate(model, val_loader, device)  # sanity check before training
    meters = MetricLogger(delimiter=" ")
    logging.info("Start training........")
    best_model = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    no_update = 0  # epochs since the last improvement, used for early stopping
    eps = 0.00001
    log_every = max(1, len(train_loader) // 100)  # guard against loaders with fewer than 100 batches
    for epoch in range(args.num_epoch):
        model.train()
        for iteration, batch in enumerate(train_loader):
            iteration = iteration + 1

            question, topic_entity, answer = [x.to(device) for x in batch]
            loss, pt_loss = model(question, topic_entity, answer)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            meters.update(loss=pt_loss.item())

            if iteration % log_every == 0:
                logging.info(
                    meters.delimiter.join(
                        [
                            "progress: {progress:.3f}",
                            "{meters}",
                            "lr: {lr:.6f}",
                        ]
                    ).format(
                        progress=epoch + iteration / len(train_loader),
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                    )
                )

        acc = validate(model, val_loader, device)
        if acc > best_acc + eps:
            best_acc = acc
            no_update = 0
            best_model = copy.deepcopy(model.state_dict())
            logging.info("Validation accuracy increased to {:.4f}".format(acc))
            torch.save(model.state_dict(), os.path.join(args.save_dir, '%s-%s-%s-%s.pt' % (args.opt, str(args.lr), str(args.bandwidth), str(epoch))))
        elif (acc < best_acc + eps) and (no_update < args.patience):
            no_update += 1
            logging.info("Validation accuracy %.4f did not beat best %.4f, %d more epoch(s) to check" % (acc, best_acc, args.patience - no_update))
        elif no_update == args.patience:
            logging.info("Model has exceeded patience. Saving best model and exiting")
            torch.save(best_model, os.path.join(args.save_dir, "best_score_model.pt"))
            exit()

    # also save the best model if training ends before patience runs out
    torch.save(best_model, os.path.join(args.save_dir, "best_score_model.pt"))
    # acc = validate(model, test_loader, device)
    # scheduler.step()


def main():
    parser = argparse.ArgumentParser()
    # input and output
    parser.add_argument('--input_dir', required=True)
    parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
    parser.add_argument('--glove_pt', required=True)

    # training parameters
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--num_epoch', default=100, type=int)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--seed', type=int, default=666, help='random seed')
    # model hyperparameters
    parser.add_argument('--dim_emb', default=300, type=int)
    parser.add_argument('--num_rollout_steps', default=3, type=int)
    parser.add_argument('--num_rollouts', default=10, type=int)
    parser.add_argument('--dim_word', default=300, type=int)
    parser.add_argument('--dim_hidden', default=300, type=int)
    parser.add_argument('--bucket_interval', default=3, type=int)
    parser.add_argument('--opt', default='adam', type=str)
    parser.add_argument('--bandwidth', default=50, type=int)
    parser.add_argument('--gamma', default=0.95, type=float)
    parser.add_argument('--eta', default=0.95, type=float)
    parser.add_argument('--beta', default=0, type=float)
    parser.add_argument('--beam_size', default=32, type=int)
    parser.add_argument('--log_name', default='log.txt', type=str)
    parser.add_argument('--model_name', default='model.pt', type=str)
    parser.add_argument('--patience', default=10, type=int)
    args = parser.parse_args()

    # make logging.info display into both shell and file
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    fileHandler = logging.FileHandler(os.path.join(args.save_dir, args.log_name))
    fileHandler.setFormatter(logFormatter)
    rootLogger.addHandler(fileHandler)
    # args display
    for k, v in vars(args).items():
        logging.info(k + ':' + str(v))

    # set random seed
    torch.manual_seed(args.seed)

    train(args)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
import os
import sys
import json
from datetime import date
from collections import defaultdict, Counter
from tqdm import tqdm

def whether_equal(answer, pred):
    def truncate_float(x):
        # convert answer from '100.0 meters' to '100 meters'
        try:
            v, *u = x.split()
            v = float(v)
            if v - int(v) < 1e-5:
                v = int(v)
            if len(u) == 0:
                x = str(v)
            else:
                x = '{} {}'.format(str(v), ' '.join(u))
        except Exception:
            pass
        return x

    def equal_as_date(x, y):
        # check whether x and y are equal as type of date or year
        try:
            x_split = x.split('-')
            y_split = y.split('-')
            if len(x_split) == 3:
                x = date(int(x_split[0]), int(x_split[1]), int(x_split[2]))
            else:
                x = int(x)
            if len(y_split) == 3:
                y = date(int(y_split[0]), int(y_split[1]), int(y_split[2]))
            else:
                y = int(y)
            if isinstance(x, date) and isinstance(y, date):
                return x == y
            else:
                # a year compared with a full date: compare by year
                x = x.year if isinstance(x, date) else x
                y = y.year if isinstance(y, date) else y
                return x == y
        except Exception:
            return False

    answer = truncate_float(answer)
    pred = truncate_float(pred)
    if equal_as_date(answer, pred):
        return True
    else:
        return answer == pred
def load(f):
    data = []
    for line in f:
        data.append(json.loads(line.strip()))
    return data

def main():
    gt_folder, pred_fn = sys.argv[1], sys.argv[2]

    gt_fn = os.path.join(gt_folder, 'test_answer.json')
    gt = json.load(open(gt_fn))
    pred = [x.strip() for x in open(pred_fn).readlines()]  # one prediction per line
    train_set = json.load(open(os.path.join(gt_folder, 'train.json')))
    train_answer_set = set(x['answer'] for x in train_set)

    labels = ['overall', 'multihop', 'qualifier', 'comparison', 'logical', 'count', 'verify', 'zero-shot']
    total = {k: 0 for k in labels}
    correct = {k: 0 for k in labels}
    for i in tqdm(range(len(pred))):
        cur_labels = ['overall']
        functions = [f['function'] for f in gt[i]['program']]

        for f in functions:
            if f in {'Relate'} or f.startswith('Filter'):
                cur_labels.append('multihop')
                break
        for f in functions:
            if f in {'QFilterStr', 'QFilterNum', 'QFilterYear', 'QFilterDate', 'QueryAttrUnderCondition', 'QueryAttrQualifier', 'QueryRelationQualifier'}:
                cur_labels.append('qualifier')
                break
        for f in functions:
            if f in {'SelectBetween', 'SelectAmong'}:
                cur_labels.append('comparison')
                break
        for f in functions:
            if f in {'And', 'Or'}:
                cur_labels.append('logical')
                break
        for f in functions:
            if f in {'Count'}:
                cur_labels.append('count')
                break
        for f in functions:
            if f in {'VerifyStr', 'VerifyNum', 'VerifyYear', 'VerifyDate'}:
                cur_labels.append('verify')
                break

        answer = gt[i]['answer']
        if answer not in train_answer_set:
            cur_labels.append('zero-shot')

        if whether_equal(answer, pred[i]):
            for k in cur_labels:
                correct[k] += 1
        for k in cur_labels:
            total[k] += 1

    for k in labels:
        if total[k] == 0:  # skip empty categories to avoid division by zero
            continue
        print('{}: {:.2f}% ({}/{})'.format(k, correct[k] / total[k] * 100, correct[k], total[k]))
    if len(pred) < len(gt):
        print('WARNING: there are only {} predictions (need {})'.format(len(pred), len(gt)))


if __name__ == '__main__':
    main()
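A quick sanity check of `whether_equal`'s normalization rules (a minimal sketch, assuming `evaluate.py` is importable from the repo root):
```
from evaluate import whether_equal

assert whether_equal('100.0 metres', '100 metres')    # trailing .0 is truncated
assert whether_equal('2001', '2001-01-01')            # year vs. date compares by year
assert not whether_equal('2001-01-02', '2001-01-01')  # full dates must match exactly
```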
--------------------------------------------------------------------------------
/utils/BiGRU.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

class GRU(nn.Module):

    def __init__(self, dim_word, dim_h, num_layers, dropout):
        super().__init__()
        self.encoder = nn.GRU(input_size=dim_word,
                              hidden_size=dim_h,
                              num_layers=num_layers,
                              dropout=dropout,
                              batch_first=True,
                              bidirectional=False)

    def forward_one_step(self, input, last_h):
        """
        Args:
            - input (bsz, 1, w_dim)
            - last_h (num_layers, bsz, h_dim)
        """
        hidden, new_h = self.encoder(input, last_h)
        return hidden, new_h  # (bsz, 1, h_dim), (num_layers, bsz, h_dim)

    def generate_sequence(self, word_lookup_func, h_0, classifier, vocab, max_step, early_stop=True):
        bsz = h_0.size(1)
        device = h_0.device
        start_id, end_id, pad_id = vocab['<START>'], vocab['<END>'], vocab['<PAD>']

        latest = torch.LongTensor([start_id] * bsz).to(device)  # [bsz, ]
        results = [latest]
        last_h = h_0
        finished = torch.zeros((bsz,)).bool().to(device)  # record whether <END> is produced
        for i in range(max_step - 1):  # exclude <START>
            word_emb = word_lookup_func(latest).unsqueeze(1)  # [bsz, 1, dim_w]
            word_h, last_h = self.forward_one_step(word_emb, last_h)  # [bsz, 1, dim_h]

            logit = classifier(word_h).squeeze(1)  # [bsz, num_func]
            latest = torch.argmax(logit, dim=1).long()  # [bsz, ]
            latest[finished] = pad_id  # set to <PAD> after <END>
            results.append(latest)

            finished = finished | latest.eq(end_id).bool()
            if early_stop and finished.sum().item() == bsz:
                # print('finished at step {}'.format(i))
                break
        results = torch.stack(results, dim=1)  # [bsz, max_len']
        return results

    def forward(self, input, length, h_0=None):
        """
        Args:
            - input (bsz, len, w_dim)
            - length (bsz, )
            - h_0 (num_layers, bsz, h_dim)
        Return:
            - hidden (bsz, len, dim): hidden state of each word
            - output (bsz, dim): sentence embedding
        """
        bsz, max_len = input.size(0), input.size(1)
        sorted_seq_lengths, indices = torch.sort(length, descending=True)
        _, desorted_indices = torch.sort(indices, descending=False)
        input = input[indices]
        # lengths must live on CPU for pack_padded_sequence in recent PyTorch versions
        packed_input = nn.utils.rnn.pack_padded_sequence(input, sorted_seq_lengths.cpu(), batch_first=True)
        if h_0 is None:
            hidden, h_n = self.encoder(packed_input)
        else:
            h_0 = h_0[:, indices]
            hidden, h_n = self.encoder(packed_input, h_0)
        # h_n is (num_layers, bsz, h_dim)
        hidden = nn.utils.rnn.pad_packed_sequence(hidden, batch_first=True, total_length=max_len)[0]  # (bsz, max_len, h_dim)

        output = h_n[-1, :, :]  # (bsz, h_dim), take the last layer's state

        # recover the original order
        hidden = hidden[desorted_indices]
        output = output[desorted_indices]
        h_n = h_n[:, desorted_indices]
        return hidden, output, h_n



class BiGRU(nn.Module):

    def __init__(self, dim_word, dim_h, num_layers, dropout):
        super().__init__()
        self.encoder = nn.GRU(input_size=dim_word,
                              hidden_size=dim_h // 2,
                              num_layers=num_layers,
                              dropout=dropout,
                              batch_first=True,
                              bidirectional=True)

    def forward(self, input, length):
        """
        Args:
            - input (bsz, len, w_dim)
            - length (bsz, )
        Return:
            - hidden (bsz, len, dim): hidden state of each word
            - output (bsz, dim): sentence embedding
            - h_n (num_layers * 2, bsz, dim//2)
        """
        bsz, max_len = input.size(0), input.size(1)
        sorted_seq_lengths, indices = torch.sort(length, descending=True)
        _, desorted_indices = torch.sort(indices, descending=False)
        input = input[indices]
        packed_input = nn.utils.rnn.pack_padded_sequence(input, sorted_seq_lengths.cpu(), batch_first=True)
        hidden, h_n = self.encoder(packed_input)
        # h_n is (num_layers * num_directions, bsz, h_dim//2)
        hidden = nn.utils.rnn.pad_packed_sequence(hidden, batch_first=True, total_length=max_len)[0]  # (bsz, max_len, h_dim)

        output = h_n[-2:, :, :]  # (2, bsz, h_dim//2), take the last layer's state
        output = output.permute(1, 0, 2).contiguous().view(bsz, -1)  # (bsz, h_dim), merge forward and backward h_n

        # recover the original order
        hidden = hidden[desorted_indices]
        output = output[desorted_indices]
        h_n = h_n[:, desorted_indices]
        return hidden, output, h_n
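A minimal usage sketch of the `BiGRU` encoder on a padded batch (all sizes here are arbitrary); `forward` sorts by length internally, so `length` may come in any order:
```
import torch
from utils.BiGRU import BiGRU

enc = BiGRU(dim_word=300, dim_h=256, num_layers=2, dropout=0.2)
x = torch.randn(4, 10, 300)            # (bsz, max_len, dim_word), zero-padded
lengths = torch.tensor([10, 7, 5, 2])  # true sequence lengths
hidden, output, h_n = enc(x, lengths)
print(hidden.shape)  # torch.Size([4, 10, 256]): per-token states
print(output.shape)  # torch.Size([4, 256]): forward/backward final states merged
print(h_n.shape)     # torch.Size([4, 4, 128]) = (num_layers * 2, bsz, dim_h // 2)
```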
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
from collections import defaultdict, Counter, deque
import torch
import json
import pickle
import numpy as np
import torch.nn as nn
import random
import os
import time
######################################################
##################### used in SRN ####################
START_RELATION = 'START_RELATION'
NO_OP_RELATION = 'NO_OP_RELATION'
NO_OP_ENTITY = 'NO_OP_ENTITY'
DUMMY_RELATION = 'DUMMY_RELATION'
DUMMY_ENTITY = 'DUMMY_ENTITY'

DUMMY_RELATION_ID = 0
START_RELATION_ID = 1
NO_OP_RELATION_ID = 2
DUMMY_ENTITY_ID = 0
NO_OP_ENTITY_ID = 1

EPSILON = float(np.finfo(float).eps)
HUGE_INT = 1e31

def format_path(path_trace, id2entity, id2relation):
    def get_most_recent_relation(j):
        relation_id = int(path_trace[j][0])
        if relation_id == NO_OP_RELATION_ID:
            return ''
        else:
            return id2relation[relation_id]

    def get_most_recent_entity(j):
        return id2entity[int(path_trace[j][1])]

    path_str = get_most_recent_entity(0)
    for j in range(1, len(path_trace)):
        rel = get_most_recent_relation(j)
        if not rel.endswith('_inv'):
            path_str += ' -{}-> '.format(rel)
        else:
            path_str += ' <-{}- '.format(rel[:-4])
        path_str += get_most_recent_entity(j)
    return path_str

def pad_and_cat(a, padding_value, padding_dim=1):
    max_dim_size = max([x.size()[padding_dim] for x in a])
    padded_a = []
    for x in a:
        if x.size()[padding_dim] < max_dim_size:
            res_len = max_dim_size - x.size()[1]
            pad = nn.ConstantPad1d((0, res_len), padding_value)
            padded_a.append(pad(x))
        else:
            padded_a.append(x)
    return torch.cat(padded_a, dim=0)

def safe_log(x):
    return torch.log(x + EPSILON)

def entropy(p):
    return torch.sum(-p * safe_log(p), 1)

def init_word2id():
    # NOTE: the special tokens here are a reconstruction; the HTML-like markers
    # were stripped from the original dump, <PAD>/<UNK> are the usual choices
    return {
        '<PAD>': 0,
        '<UNK>': 1,
        'E_S': 2,
    }

def init_entity2id():
    return {
        DUMMY_ENTITY: DUMMY_ENTITY_ID,
        NO_OP_ENTITY: NO_OP_ENTITY_ID
    }

def init_relation2id():
    return {
        DUMMY_RELATION: DUMMY_RELATION_ID,
        START_RELATION: START_RELATION_ID,
        NO_OP_RELATION: NO_OP_RELATION_ID
    }

def add_item_to_x2id(item, x2id):
    if item not in x2id:
        x2id[item] = len(x2id)

def tile_along_beam(v, beam_size, dim=0):
    """
    Tile a tensor along a specified dimension for the specified beam size.
    :param v: Input tensor.
    :param beam_size: Beam size.
    :param dim: Dimension along which to tile.
    """
    if dim == -1:
        dim = len(v.size()) - 1
    v = v.unsqueeze(dim + 1)
    v = torch.cat([v] * beam_size, dim=dim + 1)
    new_size = []
    for i, d in enumerate(v.size()):
        if i == dim + 1:
            new_size[-1] *= d
        else:
            new_size.append(d)
    return v.view(new_size)
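# A worked example of the tiling above (sketch): with dim=0, each row of a
# (bsz, d) tensor is repeated beam_size times consecutively, giving a
# (bsz * beam_size, d) tensor with the copies of one example grouped together:
#   >>> v = torch.tensor([[1, 2], [3, 4]])
#   >>> tile_along_beam(v, 3)
#   tensor([[1, 2], [1, 2], [1, 2], [3, 4], [3, 4], [3, 4]])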
##################### used in SRN ####################
######################################################



def init_vocab():
    return {
        '<PAD>': 0,
        '<UNK>': 1,
        '<START>': 2,
        '<END>': 3
    }

def invert_dict(d):
    return {v: k for k, v in d.items()}

def load_glove(glove_pt, idx_to_token):
    glove = pickle.load(open(glove_pt, 'rb'))
    dim = len(glove['the'])
    matrix = []
    for i in range(len(idx_to_token)):
        token = idx_to_token[i]
        tokens = token.split()
        if len(tokens) > 1:
            # average the vectors of multi-word tokens
            v = np.zeros((dim,))
            for token in tokens:
                v = v + glove.get(token, glove['the'])
            v = v / len(tokens)
        else:
            v = glove.get(token, glove['the'])
        matrix.append(v)
    matrix = np.asarray(matrix)
    return matrix


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.series = []
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.series.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque))
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
            )
        return self.delimiter.join(loss_str)


def seed_everything(seed=1029):
    '''
    Set the random seed for the whole environment.
    :param seed:
    :return:
    '''
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # some cudnn methods can be random even after fixing the seed
    # unless you tell it to be deterministic
    torch.backends.cudnn.deterministic = True


class ProgressBar(object):
    '''
    Custom progress bar.
    Example:
        >>> pbar = ProgressBar(n_total=30, desc='training')
        >>> step = 2
        >>> pbar(step=step)
    '''
    def __init__(self, n_total, width=30, desc='Training'):
        self.width = width
        self.n_total = n_total
        self.start_time = time.time()
        self.desc = desc

    def __call__(self, step, info={}):
        now = time.time()
        current = step + 1
        recv_per = current / self.n_total
        bar = f'[{self.desc}] {current}/{self.n_total} ['
        if recv_per >= 1:
            recv_per = 1
        prog_width = int(self.width * recv_per)
        if prog_width > 0:
            bar += '=' * (prog_width - 1)
            if current < self.n_total:
                bar += ">"
            else:
                bar += '='
        bar += '.' * (self.width - prog_width)
        bar += ']'
        show_bar = f"\r{bar}"
        time_per_unit = (now - self.start_time) / current
        if current < self.n_total:
            eta = time_per_unit * (self.n_total - current)
            if eta > 3600:
                eta_format = ('%d:%02d:%02d' %
                              (eta // 3600, (eta % 3600) // 60, eta % 60))
            elif eta > 60:
                eta_format = '%d:%02d' % (eta // 60, eta % 60)
            else:
                eta_format = '%ds' % eta
            time_info = f' - ETA: {eta_format}'
        else:
            if time_per_unit >= 1:
                time_info = f' {time_per_unit:.1f}s/step'
            elif time_per_unit >= 1e-3:
                time_info = f' {time_per_unit * 1e3:.1f}ms/step'
            else:
                time_info = f' {time_per_unit * 1e6:.1f}us/step'

        show_bar += time_info
        if len(info) != 0:
            show_info = f'{show_bar} ' + \
                "-".join([f' {key}: {value:.4f} ' for key, value in info.items()])
            print(show_info, end='')
        else:
            print(show_bar, end='')
--------------------------------------------------------------------------------
/utils/pickle_glove.py:
--------------------------------------------------------------------------------
import pickle
import argparse
import numpy as np
from tqdm import tqdm

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', required=True)
    parser.add_argument('--output', required=True)
    args = parser.parse_args()

    res = {}
    for line in tqdm(open(args.input, encoding="latin-1")):
        word, *vec = line.split()
        try:
            vec = np.asarray(list(map(float, vec)))
            res[word] = vec
        except Exception:
            print('bad word: %s' % word)

    with open(args.output, 'wb') as f:
        pickle.dump(res, f)


if __name__ == '__main__':
    main()
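The resulting pickle is a plain dict mapping each token to a 300-d numpy vector, which is exactly what `load_glove` in `utils/misc.py` consumes (a minimal sketch; the path is a placeholder):
```
import pickle

glove = pickle.load(open('<glove_pt>', 'rb'))
vec = glove['the']            # numpy array of shape (300,)
print(len(glove), vec.shape)  # vocabulary size, (300,)
```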
--------------------------------------------------------------------------------
/utils/value_class.py:
--------------------------------------------------------------------------------
def comp(a, b, op):
    """
    Args:
        - a (ValueClass): attribute value of a certain entity
        - b (ValueClass): comparison target
        - op: =/>/</!=
    """
    if op == '=':
        return a == b
    elif op == '<':
        return a < b
    elif op == '>':
        return a > b
    elif op == '!=':
        return a != b

class ValueClass():
    def __init__(self, type, value, unit=None):
        """
        When type is
            - string, value is a str
            - quantity, value is a number and unit is required
            - year, value is an int
            - date, value is a date object
        """
        self.type = type
        self.value = value
        self.unit = unit

    def isTime(self):
        return self.type in {'year', 'date'}

    def can_compare(self, other):
        if self.type == 'string':
            return other.type == 'string'
        elif self.type == 'quantity':
            # NOTE: two quantities can compare only when they have the same unit
            return other.type == 'quantity' and other.unit == self.unit
        else:
            # year can compare with date
            return other.type == 'year' or other.type == 'date'

    def contains(self, other):
        """
        Check whether self contains other; unlike __eq__, the result is asymmetric.
        Used for conditions like whether 2001-01-01 is in 2001, or whether 2001 is in 2001-01-01.
        """
        if self.type == 'year':  # a year can contain a year and a date
            other_value = other.value if other.type == 'year' else other.value.year
            return self.value == other_value
        elif self.type == 'date':  # a date can only contain a date
            return other.type == 'date' and self.value == other.value
        else:
            raise Exception('not supported type: %s' % self.type)


    def __eq__(self, other):
        """
        2001 and 2001-01-01 are not equal
        """
        assert self.can_compare(other)
        return self.type == other.type and self.value == other.value

    def __lt__(self, other):
        """
        Comparison between a year and a date will convert them both to year
        """
        assert self.can_compare(other)
        if self.type == 'string':
            raise Exception('try to compare two strings')
        elif self.type == 'quantity':
            return self.value < other.value
        elif self.type == 'year':
            other_value = other.value if other.type == 'year' else other.value.year
            return self.value < other_value
        elif self.type == 'date':
            if other.type == 'year':
                return self.value.year < other.value
            else:
                return self.value < other.value

    def __gt__(self, other):
        assert self.can_compare(other)
        if self.type == 'string':
            raise Exception('try to compare two strings')
        elif self.type == 'quantity':
            return self.value > other.value
        elif self.type == 'year':
            other_value = other.value if other.type == 'year' else other.value.year
            return self.value > other_value
        elif self.type == 'date':
            if other.type == 'year':
                return self.value.year > other.value
            else:
                return self.value > other.value

    def __str__(self):
        if self.type == 'string':
            return self.value
        elif self.type == 'quantity':
            if self.value - int(self.value) < 1e-5:
                v = int(self.value)
            else:
                v = self.value
            return '{} {}'.format(v, self.unit) if self.unit != '1' else str(v)
        elif self.type == 'year':
            return str(self.value)
        elif self.type == 'date':
            return self.value.isoformat()
--------------------------------------------------------------------------------
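A minimal usage sketch of `ValueClass` and `comp` (note the comparison branches of `comp` above are reconstructed from its docstring): quantities compare only under matching units, and a year compares with a date by year:
```
from datetime import date
from utils.value_class import ValueClass, comp

h1 = ValueClass('quantity', 100.0, 'metre')
h2 = ValueClass('quantity', 120, 'metre')
print(comp(h1, h2, '<'))                                  # True
print(h1.can_compare(ValueClass('quantity', 1, 'mile')))  # False: unit mismatch

y = ValueClass('year', 2001)
d = ValueClass('date', date(2001, 1, 1))
print(y.contains(d))  # True: 2001-01-01 falls within 2001
print(str(h1))        # '100 metre' (the trailing .0 is dropped)
```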