├── .gitignore
├── Bart_Program
│   ├── README.md
│   ├── data.py
│   ├── predict.py
│   ├── preprocess.py
│   └── train.py
├── Bart_SPARQL
│   ├── README.md
│   ├── data.py
│   ├── predict.py
│   ├── preprocess.py
│   ├── sparql_engine.py
│   └── train.py
├── BlindGRU
│   ├── README.md
│   ├── data.py
│   ├── model.py
│   ├── predict.py
│   ├── preprocess.py
│   └── train.py
├── KVMemNN
│   ├── README.md
│   ├── data.py
│   ├── model.py
│   ├── predict.py
│   ├── preprocess.py
│   └── train.py
├── LICENSE
├── Program
│   ├── data.py
│   ├── executor_rule.py
│   ├── parser.py
│   ├── predict.py
│   ├── preprocess.py
│   ├── readme.md
│   └── train.py
├── README.md
├── RGCN
│   ├── README.md
│   ├── data.py
│   ├── model.py
│   ├── predict.py
│   ├── preprocess.py
│   └── train.py
├── SPARQL
│   ├── README.md
│   ├── data.py
│   ├── model.py
│   ├── predict.py
│   ├── preprocess.py
│   ├── sparql_engine.py
│   └── train.py
├── SRN
│   ├── data.py
│   ├── input
│   │   └── pgrk.txt
│   ├── knowledge_graph.py
│   ├── model.py
│   ├── predict.py
│   ├── preprocess.py
│   ├── readme.md
│   └── train.py
├── evaluate.py
└── utils
    ├── BiGRU.py
    ├── load_kb.py
    ├── lr_scheduler.py
    ├── misc.py
    ├── pickle_glove.py
    └── value_class.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.sublime-workspace
2 | *.sublime-project
3 | test_dataset/
4 | dataset/
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
--------------------------------------------------------------------------------
/Bart_Program/README.md:
--------------------------------------------------------------------------------
1 | ## Requirements
2 | - python3
3 | - pytorch==1.12.0
4 | - transformers==4.16.2
5 | - kopl==0.0.5
6 |
7 | ## How to run
8 | 1. Install the KoPL engine (a usage sketch of the engine follows this list)
9 | ```
10 | pip install kopl
11 | ```
12 | 2. Preprocess the training data, and copy `./dataset/kb.json` into `<output_dir>`
13 | ```
14 | python -m Bart_Program.preprocess --input_dir ./dataset --output_dir <output_dir> --model_name_or_path <model_name_or_path>
15 | cp ./dataset/kb.json <output_dir>
16 | ```
17 | 3. Train
18 | ```
19 | python -m Bart_Program.train --input_dir <output_dir> --output_dir <checkpoint_dir> --save_dir <save_dir> --model_name_or_path <model_name_or_path>
20 | ```
21 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order.
22 | ```
23 | python -m Bart_Program.predict --input_dir <output_dir> --save_dir <save_dir> --ckpt <checkpoint>
24 | ```
25 |
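At prediction time, each generated sequence is split on `<func>`/`<arg>` and executed by the KoPL engine. Below is a minimal, illustrative sketch of that execution step; the entity and attribute names are placeholders and not guaranteed to exist in your `kb.json`. `engine.forward` is the same call used in `Bart_Program/predict.py`.

```python
import json
from kopl.kopl import KoPLEngine

engine = KoPLEngine(json.load(open('./dataset/kb.json')))
# A program is a list of functions, each paired with its own inputs:
ans = engine.forward(
    ['Find', 'QueryAttr'],            # function sequence
    [['LeBron James'], ['height']],   # inputs of each function
    ignore_error=True,                # return None instead of raising
)
print(ans)
```
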
26 | ## Checkpoints
27 | 1. The pretrained Bart-base checkpoint without finetuning can be downloaded here [bart-base](https://cloud.tsinghua.edu.cn/f/3b59ec6c43034cfc8841/?dl=1)
28 | 2. The checkpoint for finetuned Bart_Program can be downloaded here [finetuned](https://cloud.tsinghua.edu.cn/f/5b82ae04f9f64d1c8d1d/?dl=1)
29 |
30 | ## Change Log
31 |
32 | - [2022/8/8] Uploaded evaluate.py; updated the KoPL engine based on [KoPL](https://github.com/THU-KEG/KoPL); updated kb.json in the dataset.
33 |
34 | - Added a different program serializer and the special tokens `<func>`/`<arg>` to the tokenizer. Note that this affects the `--model_name_or_path` argument of `Bart_Program.train`.
35 |
--------------------------------------------------------------------------------
/Bart_Program/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import torch
4 | from utils.misc import invert_dict
5 |
6 | def load_vocab(path):
7 | vocab = json.load(open(path))
8 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])
9 | return vocab
10 |
11 | def collate(batch):
12 | batch = list(zip(*batch))
13 | source_ids = torch.stack(batch[0])
14 | source_mask = torch.stack(batch[1])
15 | choices = torch.stack(batch[2])
16 | if batch[-1][0] is None:
17 | target_ids, answer = None, None
18 | else:
19 | target_ids = torch.stack(batch[3])
20 | answer = torch.cat(batch[4])
21 | return source_ids, source_mask, choices, target_ids, answer
22 |
23 |
24 | class Dataset(torch.utils.data.Dataset):
25 | def __init__(self, inputs):
26 | self.source_ids, self.source_mask, self.target_ids, self.choices, self.answers = inputs
27 | self.is_test = len(self.answers)==0
28 |
29 |
30 | def __getitem__(self, index):
31 | source_ids = torch.LongTensor(self.source_ids[index])
32 | source_mask = torch.LongTensor(self.source_mask[index])
33 | choices = torch.LongTensor(self.choices[index])
34 | if self.is_test:
35 | target_ids = None
36 | answer = None
37 | else:
38 | target_ids = torch.LongTensor(self.target_ids[index])
39 | answer = torch.LongTensor([self.answers[index]])
40 | return source_ids, source_mask, choices, target_ids, answer
41 |
42 |
43 | def __len__(self):
44 | return len(self.source_ids)
45 |
46 |
47 | class DataLoader(torch.utils.data.DataLoader):
48 | def __init__(self, vocab_json, question_pt, batch_size, training=False):
49 | vocab = load_vocab(vocab_json)
50 | if training:
51 | print('#vocab of answer: %d' % (len(vocab['answer_token_to_idx'])))
52 |
53 | inputs = []
54 | with open(question_pt, 'rb') as f:
55 | for _ in range(5):
56 | inputs.append(pickle.load(f))
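# The five pickles are, in order: source_ids, source_mask, target_ids,
# choices, answers -- the order in which Bart_Program/preprocess.py dumps them.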
57 | dataset = Dataset(inputs)
58 | # np.shuffle(dataset)
59 | # dataset = dataset[:(int)(len(dataset) / 10)]
60 | super().__init__(
61 | dataset,
62 | batch_size=batch_size,
63 | shuffle=training,
64 | collate_fn=collate,
65 | )
66 | self.vocab = vocab
--------------------------------------------------------------------------------
/Bart_Program/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import json
7 | from tqdm import tqdm
8 | from datetime import date
9 | from utils.misc import MetricLogger, seed_everything, ProgressBar
10 | from .data import DataLoader
11 | from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer
12 | import torch.optim as optim
13 | import logging
14 | import time
15 | from utils.lr_scheduler import get_linear_schedule_with_warmup
16 | import re
17 | from kopl.kopl import KoPLEngine
18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
19 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
20 | rootLogger = logging.getLogger()
21 | import warnings
22 | warnings.simplefilter("ignore") # hide warnings caused by invalid queries
23 | from termcolor import colored
24 |
25 | def post_process(text):
26 | pattern = re.compile(r'".*?"')
27 | nes = []
28 | for item in pattern.finditer(text):
29 | nes.append((item.group(), item.span()))
30 | pos = [0]
31 | for name, span in nes:
32 | pos += [span[0], span[1]]
33 | pos.append(len(text))
34 | assert len(pos) % 2 == 0
35 | assert len(pos) / 2 == len(nes) + 1
36 | chunks = [text[pos[i]: pos[i+1]] for i in range(0, len(pos), 2)]
37 | for i in range(len(chunks)):
38 | chunks[i] = chunks[i].replace('?', ' ?').replace('.', ' .')
39 | bingo = ''
40 | for i in range(len(chunks) - 1):
41 | bingo += chunks[i] + nes[i][0]
42 | bingo += chunks[-1]
43 | return bingo
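# Example (illustrative): post_process('How many "Harry Potter?" films?')
# returns 'How many "Harry Potter?" films ?' -- punctuation inside the
# quoted entity is untouched, punctuation outside gets a leading space.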
44 |
45 | def vis(args, kb, model, data, device, tokenizer):
46 | while True:
47 | text = input('Input your question:')
48 | with torch.no_grad():
49 | input_ids = tokenizer.batch_encode_plus([text], max_length = 512, pad_to_max_length = True, return_tensors="pt", truncation = True)
50 | source_ids = input_ids['input_ids'].to(device)
51 | outputs = model.generate(
52 | input_ids=source_ids,
53 | max_length = 500,
54 | )
55 | outputs = [tokenizer.decode(output_id, skip_special_tokens = True, clean_up_tokenization_spaces = True) for output_id in outputs]
56 | outputs = [post_process(output) for output in outputs]
57 | print(outputs[0])
58 |
59 | def predict(args, model, data, device, tokenizer, executor):
60 | model.eval()
61 | count, correct = 0, 0
62 | with torch.no_grad():
63 | all_outputs = []
64 | for batch in tqdm(data, total=len(data)):
65 | source_ids = batch[0].to(device)
66 | outputs = model.generate(
67 | input_ids=source_ids,
68 | max_length = 500,
69 | )
70 |
71 | all_outputs.extend(outputs.cpu().numpy())
72 |
73 | outputs = [tokenizer.decode(output_id, skip_special_tokens = True, clean_up_tokenization_spaces = True) for output_id in all_outputs]
74 | with open(os.path.join(args.save_dir, 'predict.txt'), 'w') as f:
75 |
76 | for output in tqdm(outputs):
77 | chunks = output.split('<func>')  # the decoded program: functions joined by <func>
78 | func_list = []
79 | inputs_list = []
80 | for chunk in chunks:
81 | chunk = chunk.strip()
82 | res = chunk.split('<arg>')  # function name, then its arguments
83 | res = [_.strip() for _ in res]
84 | if len(res) > 0:
85 | func = res[0]
86 | inputs = []
87 | if len(res) > 1:
88 | for x in res[1:]:
89 | inputs.append(x)
90 | else:
91 | inputs = []
92 | func_list.append(func)
93 | inputs_list.append(inputs)
94 | ans = executor.forward(func_list, inputs_list, ignore_error = True)
95 | if ans is None:
96 | ans = 'no'
97 | if isinstance(ans, list) and len(ans) > 0:
98 | ans = ans[0]
99 | if isinstance(ans, list) and len(ans) == 0:
100 | ans = 'None'
101 | f.write(ans + '\n')
102 |
103 | def validate(model, data, device, tokenizer, executor):
104 | model.eval()
105 | count, correct = 0, 0
106 | with torch.no_grad():
107 | all_outputs = []
108 | all_answers = []
109 | for batch in tqdm(data, total=len(data)):
110 | source_ids, source_mask, choices, target_ids, answer = [x.to(device) for x in batch]
111 | outputs = model.generate(
112 | input_ids=source_ids,
113 | max_length = 500,
114 | )
115 |
116 | all_outputs.extend(outputs.cpu().numpy())
117 | all_answers.extend(answer.cpu().numpy())
118 |
119 | outputs = [tokenizer.decode(output_id, skip_special_tokens = True, clean_up_tokenization_spaces = True) for output_id in all_outputs]
120 | given_answer = [data.vocab['answer_idx_to_token'][a] for a in all_answers]
121 | for a, output in tqdm(zip(given_answer, outputs)):
122 | chunks = output.split('<func>')
123 | func_list = []
124 | inputs_list = []
125 | for chunk in chunks:
126 | chunk = chunk.strip()
127 | res = chunk.split('<arg>')
128 | res = [_.strip() for _ in res]
129 | if len(res) > 0:
130 | func = res[0]
131 | inputs = []
132 | if len(res) > 1:
133 | for x in res[1:]:
134 | inputs.append(x)
135 | else:
136 | inputs = []
137 | func_list.append(func)
138 | inputs_list.append(inputs)
139 | ans = executor.forward(func_list, inputs_list, ignore_error = True)
140 | if ans is None:
141 | ans = 'no'
142 | if isinstance(ans, list) and len(ans) > 0:
143 | ans = ans[0]
144 | if ans == a:
145 | correct += 1
146 | count += 1
147 | acc = correct / count
148 | logging.info('acc: {}'.format(acc))
149 |
150 | return acc
151 |
152 |
153 |
154 | def train(args): # note: despite the name, this only loads a checkpoint and runs prediction
155 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
156 |
157 | logging.info("Create train_loader and val_loader.........")
158 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
159 | val_pt = os.path.join(args.input_dir, 'test.pt')
160 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
161 | logging.info("Create model.........")
162 | config_class, model_class, tokenizer_class = (BartConfig, BartForConditionalGeneration, BartTokenizer)
163 | tokenizer = tokenizer_class.from_pretrained(os.path.join(args.ckpt))
164 | model = model_class.from_pretrained(os.path.join(args.ckpt))
165 | model = model.to(device)
166 | logging.info(model)
167 | engine = KoPLEngine(json.load(open(os.path.join(args.input_dir, 'kb.json'))))
168 | # validate(model, val_loader, device, tokenizer, engine)
169 |
170 | predict(args, model, val_loader, device, tokenizer, engine)
171 | def main():
172 | parser = argparse.ArgumentParser()
173 | # input and output
174 | parser.add_argument('--input_dir', required=True)
175 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
176 | parser.add_argument('--ckpt', required=True)
177 |
178 | # training parameters
179 | parser.add_argument('--batch_size', default=256, type=int)
180 | parser.add_argument('--seed', type=int, default=666, help='random seed')
181 |
182 | # validating parameters
183 | # parser.add_argument('--num_return_sequences', default=1, type=int)
184 | # parser.add_argument('--top_p', default=)
185 | # model hyperparameters
186 | parser.add_argument('--dim_hidden', default=1024, type=int)
187 | parser.add_argument('--alpha', default = 1e-4, type = float)
188 | args = parser.parse_args()
189 |
190 | if not os.path.exists(args.save_dir):
191 | os.makedirs(args.save_dir)
192 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
193 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, '{}.predict.log'.format(time_)))
194 | fileHandler.setFormatter(logFormatter)
195 | rootLogger.addHandler(fileHandler)
196 | # args display
197 | for k, v in vars(args).items():
198 | logging.info(k+':'+str(v))
199 |
200 | seed_everything(666)
201 |
202 | train(args)
203 |
204 |
205 | if __name__ == '__main__':
206 | main()
207 |
208 |
--------------------------------------------------------------------------------
/Bart_Program/preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | Preprocess KQA Pro questions and KoPL programs into BART inputs; each program is serialized with the special tokens <func> and <arg>.
3 | """
4 |
5 | import os
6 | import json
7 | import pickle
8 | import argparse
9 | import numpy as np
10 | from nltk import word_tokenize
11 | from collections import Counter
12 | from itertools import chain
13 | from tqdm import tqdm
14 | import re
15 |
16 | from utils.misc import init_vocab
17 | from transformers import BartTokenizer
18 |
19 | new_tokens = ['<func>', '<arg>']
20 |
21 | def get_program_seq(program):
22 | seq = []
23 | for item in program:
24 | func = item['function']
25 | inputs = item['inputs']
26 | args = ''
27 | for input in inputs:
28 | args += ' <arg> ' + input
29 | seq.append(func + args)
30 | seq = ' '.join(seq)
31 | return seq
32 |
33 | def encode_dataset(dataset, vocab, tokenizer, test = False):
34 | questions = []
35 | programs = []
36 | for item in tqdm(dataset):
37 | question = item['question']
38 | questions.append(question)
39 | if not test:
40 | program = item['program']
41 | program = get_program_seq(program)
42 | programs.append(program)
43 | sequences = questions + programs
44 | print('tokenizing')
45 | encoded_inputs = tokenizer(sequences, padding = True)
46 | print('tokenize ended.')
47 | print(encoded_inputs.keys())
48 | print(encoded_inputs['input_ids'][0])
49 | print(tokenizer.decode(encoded_inputs['input_ids'][0]))
50 | print(tokenizer.decode(encoded_inputs['input_ids'][-1]))
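# padding=True pads every sequence to the longest in the batch, so any
# entry (the first and last are checked below) gives the shared max length.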
51 | max_seq_length = len(encoded_inputs['input_ids'][0])
52 | assert max_seq_length == len(encoded_inputs['input_ids'][-1])
53 | print(max_seq_length)
54 | questions = []
55 | programs = []
56 | choices = []
57 | answers = []
58 | for item in tqdm(dataset):
59 | question = item['question']
60 | questions.append(question)
61 | _ = [vocab['answer_token_to_idx'][w] for w in item['choices']]
62 | choices.append(_)
63 | if not test:
64 | program = item['program']
65 | program = get_program_seq(program)
66 | programs.append(program)
67 | answers.append(vocab['answer_token_to_idx'].get(item['answer']))
68 |
69 | input_ids = tokenizer.batch_encode_plus(questions, max_length = max_seq_length, pad_to_max_length = True, truncation = True)
70 | source_ids = np.array(input_ids['input_ids'], dtype = np.int32)
71 | source_mask = np.array(input_ids['attention_mask'], dtype = np.int32)
72 | if not test:
73 | target_ids = tokenizer.batch_encode_plus(programs, max_length = max_seq_length, pad_to_max_length = True, truncation = True)
74 | target_ids = np.array(target_ids['input_ids'], dtype = np.int32)
75 | else:
76 | target_ids = np.array([], dtype = np.int32)
77 | choices = np.array(choices, dtype = np.int32)
78 | answers = np.array(answers, dtype = np.int32)
79 | return source_ids, source_mask, target_ids, choices, answers
80 |
81 |
82 |
83 | def main():
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('--input_dir', required=True)
86 | parser.add_argument('--output_dir', required=True)
87 | parser.add_argument('--model_name_or_path', required=True)
88 | args = parser.parse_args()
89 |
90 | print('Build kb vocabulary')
91 | vocab = {
92 | 'answer_token_to_idx': {}
93 | }
94 | print('Load questions')
95 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json')))
96 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json')))
97 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json')))
98 | for question in chain(train_set, val_set, test_set):
99 | for a in question['choices']:
100 | if not a in vocab['answer_token_to_idx']:
101 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
102 |
103 | if not os.path.isdir(args.output_dir):
104 | os.mkdir(args.output_dir)
105 | fn = os.path.join(args.output_dir, 'vocab.json')
106 | print('Dump vocab to {}'.format(fn))
107 | with open(fn, 'w') as f:
108 | json.dump(vocab, f, indent=2)
109 | for k in vocab:
110 | print('{}:{}'.format(k, len(vocab[k])))
111 | tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
112 | for token in new_tokens:
113 | # NOTE: in some newer versions of transformers, special_tokens needs to be set to False
114 | tokenizer.add_tokens(token, special_tokens = True)
115 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)):
116 | print('Encode {} set'.format(name))
117 | outputs = encode_dataset(dataset, vocab, tokenizer, name=='test')
118 | assert len(outputs) == 5
119 | print('shape of input_ids of questions, attention_mask of questions, input_ids of programs, choices and answers:')
120 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f:
121 | for o in outputs:
122 | print(o.shape)
123 | pickle.dump(o, f)
124 | if __name__ == '__main__':
125 | main()
126 |
--------------------------------------------------------------------------------
/Bart_Program/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import json
7 | from tqdm import tqdm
8 | from datetime import date
9 | from utils.misc import MetricLogger, seed_everything, ProgressBar
10 | from .data import DataLoader
11 | from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer
12 | import torch.optim as optim
13 | import logging
14 | import time
15 | from utils.lr_scheduler import get_linear_schedule_with_warmup
16 | from Bart_Program.predict import validate
17 | from kopl.kopl import KoPLEngine
18 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
19 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
20 | rootLogger = logging.getLogger()
21 | import warnings
22 | warnings.simplefilter("ignore") # hide warnings caused by invalid queries
23 |
24 | new_tokens = ['<func>', '<arg>']
25 |
26 | def train(args):
27 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
28 |
29 | logging.info("Create train_loader and val_loader.........")
30 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
31 | train_pt = os.path.join(args.input_dir, 'train.pt')
32 | val_pt = os.path.join(args.input_dir, 'val.pt')
33 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True)
34 | val_loader = DataLoader(vocab_json, val_pt, 64)
35 |
36 | engine = KoPLEngine(json.load(open(os.path.join(args.input_dir, 'kb.json'))))
37 | logging.info("Create model.........")
38 | config_class, model_class, tokenizer_class = (BartConfig, BartForConditionalGeneration, BartTokenizer)
39 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
40 | model = model_class.from_pretrained(args.model_name_or_path)
41 | added_tokens_num = tokenizer.add_tokens(new_tokens, special_tokens = True)
42 | print('added_tokens_num:', added_tokens_num)
43 | if added_tokens_num > 0:
44 | model.resize_token_embeddings(len(tokenizer))
45 |
46 | model = model.to(device)
47 | logging.info(model)
48 | t_total = len(train_loader) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay)
49 | no_decay = ["bias", "LayerNorm.weight"]
50 | bart_param_optimizer = list(model.named_parameters())
51 | optimizer_grouped_parameters = [
52 | {'params': [p for n, p in bart_param_optimizer if not any(nd in n for nd in no_decay)],
53 | 'weight_decay': args.weight_decay, 'lr': args.learning_rate},
54 | {'params': [p for n, p in bart_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0,
55 | 'lr': args.learning_rate}
56 | ]
57 | args.warmup_steps = int(t_total * args.warmup_proportion)
58 | optimizer = optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
59 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
60 | num_training_steps=t_total)
61 | # Check if saved optimizer or scheduler states exist
62 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
63 | os.path.join(args.model_name_or_path, "scheduler.pt")):
64 | # Load in optimizer and scheduler states
65 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
66 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
67 |
68 | # Train!
69 | logging.info("***** Running training *****")
70 | logging.info(" Num examples = %d", len(train_loader.dataset))
71 | logging.info(" Num Epochs = %d", args.num_train_epochs)
72 | logging.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
73 | logging.info(" Total optimization steps = %d", t_total)
74 |
75 | global_step = 0
76 | steps_trained_in_current_epoch = 0
77 | # Check if continuing training from a checkpoint
78 | if os.path.exists(args.model_name_or_path) and "checkpoint" in args.model_name_or_path:
79 | # set global_step to the global_step of the last saved checkpoint from the model path
80 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
81 | epochs_trained = global_step // (len(train_loader) // args.gradient_accumulation_steps)
82 | steps_trained_in_current_epoch = global_step % (len(train_loader) // args.gradient_accumulation_steps)
83 | logging.info(" Continuing training from checkpoint, will skip to saved global_step")
84 | logging.info(" Continuing training from epoch %d", epochs_trained)
85 | logging.info(" Continuing training from global step %d", global_step)
86 | logging.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
87 | logging.info('Checking...')
88 | logging.info("===================Dev==================")
89 | validate(model, val_loader, device, tokenizer, engine)
90 | tr_loss, logging_loss = 0.0, 0.0
91 | model.zero_grad()
92 | for _ in range(int(args.num_train_epochs)):
93 | pbar = ProgressBar(n_total=len(train_loader), desc='Training')
94 | for step, batch in enumerate(train_loader):
95 | # Skip past any already trained steps if resuming training
96 | if steps_trained_in_current_epoch > 0:
97 | steps_trained_in_current_epoch -= 1
98 | continue
99 | model.train()
100 | batch = tuple(t.to(device) for t in batch)
101 | pad_token_id = tokenizer.pad_token_id
102 | source_ids, source_mask, y = batch[0], batch[1], batch[-2]
103 | y_ids = y[:, :-1].contiguous()
104 | lm_labels = y[:, 1:].clone()
105 | lm_labels[y[:, 1:] == pad_token_id] = -100
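# The three lines above implement teacher forcing: the decoder input is the
# target shifted right (y_ids), the labels are the target shifted left
# (lm_labels), and -100 masks pad positions out of the cross-entropy loss.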
106 |
107 | inputs = {
108 | "input_ids": source_ids.to(device),
109 | "attention_mask": source_mask.to(device),
110 | "decoder_input_ids": y_ids.to(device),
111 | "labels": lm_labels.to(device),
112 | }
113 | outputs = model(**inputs)
114 | loss = outputs[0]
115 | loss.backward()
116 | pbar(step, {'loss': loss.item()})
117 | tr_loss += loss.item()
118 | if (step + 1) % args.gradient_accumulation_steps == 0:
119 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
120 | optimizer.step()
121 | scheduler.step() # Update learning rate schedule
122 | model.zero_grad()
123 | global_step += 1
124 | validate(model, val_loader, device, tokenizer, engine)
125 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
126 | if not os.path.exists(output_dir):
127 | os.makedirs(output_dir)
128 | model_to_save = (
129 | model.module if hasattr(model, "module") else model
130 | ) # Take care of distributed/parallel training
131 | model_to_save.save_pretrained(output_dir)
132 | tokenizer.save_pretrained(output_dir)
133 | torch.save(args, os.path.join(output_dir, "training_args.bin"))
134 | logging.info("Saving model checkpoint to %s", output_dir)
135 | # tokenizer.save_vocabulary(output_dir)
136 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
137 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
138 | logging.info("Saving optimizer and scheduler states to %s", output_dir)
139 | logging.info("\n")
140 | if 'cuda' in str(device):
141 | torch.cuda.empty_cache()
142 | return global_step, tr_loss / global_step
143 |
144 |
145 | def main():
146 | parser = argparse.ArgumentParser()
147 | # input and output
148 | parser.add_argument('--input_dir', required=True)
149 | parser.add_argument('--output_dir', required=True)
150 |
151 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
152 | parser.add_argument('--model_name_or_path', required = True)
153 | parser.add_argument('--ckpt')
154 |
155 | # training parameters
156 | parser.add_argument('--weight_decay', default=1e-5, type=float)
157 | parser.add_argument('--batch_size', default=16, type=int)
158 | parser.add_argument('--seed', type=int, default=666, help='random seed')
159 | parser.add_argument('--learning_rate', default=3e-5, type = float)
160 | parser.add_argument('--num_train_epochs', default=25, type = int)
161 | parser.add_argument('--save_steps', default=448, type = int)
162 | parser.add_argument('--logging_steps', default=448, type = int)
163 | parser.add_argument('--warmup_proportion', default=0.1, type = float,
164 | help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training.")
165 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
166 | help="Epsilon for Adam optimizer.")
167 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
168 | help="Number of updates steps to accumulate before performing a backward/update pass.", )
169 | parser.add_argument("--max_grad_norm", default=1.0, type=float,
170 | help="Max gradient norm.")
171 |
172 | # validating parameters
173 | # parser.add_argument('--num_return_sequences', default=1, type=int)
174 | # parser.add_argument('--top_p', default=)
175 | # model hyperparameters
176 | parser.add_argument('--dim_hidden', default=1024, type=int)
177 | parser.add_argument('--alpha', default = 1e-4, type = float)
178 | args = parser.parse_args()
179 |
180 | if not os.path.exists(args.save_dir):
181 | os.makedirs(args.save_dir)
182 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
183 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, '{}.log'.format(time_)))
184 | fileHandler.setFormatter(logFormatter)
185 | rootLogger.addHandler(fileHandler)
186 | # args display
187 | for k, v in vars(args).items():
188 | logging.info(k+':'+str(v))
189 |
190 | seed_everything(666)
191 |
192 | train(args)
193 |
194 |
195 | if __name__ == '__main__':
196 | main()
197 |
198 |
--------------------------------------------------------------------------------
/Bart_SPARQL/README.md:
--------------------------------------------------------------------------------
1 | ## Requirements
2 | - python3.7
3 | - rdflib==4.2.2 or 6.1.1
4 | - transformers
5 | ---
6 | **Note for rdflib 4.2.2:**
7 | After installing rdflib via `pip`, `anaconda`, or another tool, you need to patch two of its bugs.
8 |
9 | First, find your rdflib location. One way is to run the following in IPython:
10 | ```
11 | import rdflib
12 | rdflib.__file__
13 | ```
14 | It returns `~/anaconda3/lib/python3.7/site-packages/rdflib/__init__.py` on my machine, so I enter the folder `~/anaconda3/lib/python3.7/site-packages/rdflib`.
15 |
16 | Then open `plugins/sparql/parser.py`, find *Line 68*, replace its code with
17 | ```
18 | if i + 1 < l and (not isinstance(terms[i + 1], str) or terms[i + 1] not in ".,;"):
19 | ```
20 | Remember to keep the original indentation.
21 | Note that *Line 67* is the comment `# is this bnode the subject of more triplets?`; if your line numbers differ from mine, you can locate the target line by this comment.
22 |
23 | Finally, open `plugins/serializers/turtle.py`, find *Line 328*, and change `use_plain=True` to `use_plain=False`.
24 |
25 |
26 | **Note for rdflib 6.1.1:**
27 | If you get an error "can't set attribute" with rdflib==4.2.2, try rdflib==6.1.1 instead.
28 |
29 | ---
30 |
31 | - SPARQLWrapper==1.8.4
32 |
33 | ---
34 | **Note:**
35 | When installing `SPARQLWrapper` with `pip`, it may automatically install another package, `keepalive`. You can check whether it is in your environment with
36 | ```
37 | pip show keepalive
38 | ```
39 |
40 | If it is installed, it causes problems when executing a large number of SPARQL queries: the available ports get exhausted. So we need to manually disable the `keepalive` package; it is okay to remove it directly.
41 | ```
42 | pip uninstall keepalive
43 | ```
44 |
45 | ---
46 |
47 | - Virtuoso backend, refer to the next section
48 |
49 | ## How to install virtuoso backend
50 | The Virtuoso backend starts a web service; we import our KB into it and then execute SPARQL queries via network requests. We installed Virtuoso on an Ubuntu 16.04 system. The specific steps follow.
51 |
52 | 1. Download and install virtuoso into our system.
53 | ```
54 | git clone https://github.com/openlink/virtuoso-opensource.git Virtuoso-Opensource
55 | cd Virtuoso-Opensource
56 | git checkout stable/7
57 | sudo apt-get install libtool gawk gperf autoconf automake libtool flex bison m4 make openssl libssl-dev
58 | sudo ./autogen.sh
59 | sudo ./configure
60 | sudo make
61 | sudo make install
62 | ```
63 |
64 | 2. Create a new user for virtuoso service
65 | ```
66 | sudo useradd virtuoso --home /usr/local/virtuoso-opensource
67 | sudo chown -R virtuoso /usr/local/virtuoso-opensource
68 | ```
69 |
70 | 3. Modify some necessary configs:
71 | ```
72 | cd /usr/local/virtuoso-opensource/var/lib/virtuoso/db
73 | sudo vim virtuoso.ini
74 | ```
75 | Find the item `CheckpointInterval` and change its value from the default 60 to 0, to disable the automatic checkpoint process, which can cause 404 errors.
76 |
77 | 4. Start up the virtuoso service:
78 | ```
79 | sudo -H -u virtuoso ../../../../bin/virtuoso-t -f &
80 | ```
81 | Now you can access the service via the default port 8890.
82 | Enter `[ip]:8890` in a browser and you will see the Virtuoso service page.
83 |
84 | [note] Virtuoso may report an error "There is no configuration file virtuoso.ini" at startup. In that case, run
85 | ```
86 | sudo vim /etc/rc.conf
87 | ```
88 | and add the line: `virtuoso_config="/usr/local/virtuoso-opensource/var/lib/virtuoso/db/virtuoso.ini"`
89 |
90 |
91 | 5. Now we can import our KB into Virtuoso. Before that, we need to convert the KB to `ttl` format and move it to the proper position:
92 | ```
93 | python -m Bart_SPARQL.sparql_engine --kb_path dataset/kb.json --ttl_path dataset/kb.ttl
94 | sudo chmod 777 dataset/kb.ttl
95 | sudo mv dataset/kb.ttl /usr/local/virtuoso-opensource/share/virtuoso/vad
96 | ```
97 |
98 | 6. Enter the interactive terminal of virtuoso:
99 | ```
100 | cd /usr/local/virtuoso-opensource/bin
101 | sudo ./isql
102 | ```
103 |
104 | 7. Import our kb by executing these commands in terminal:
105 | ```
106 | SPARQL CREATE GRAPH <[graph_name]>;
107 | SPARQL CLEAR GRAPH <[graph_name]>;
108 | delete from db.dba.load_list;
109 | ld_dir('/usr/local/virtuoso-opensource/share/virtuoso/vad', 'kb.ttl', '[graph_name]');
110 | rdf_loader_run();
111 | select * from DB.DBA.load_list;
112 | exit;
113 | ```
114 | `[graph_name]` can be any legal string, such as *KQAPro*.
115 | The import succeeded if `rdf_loader_run()` takes about 10 seconds.
116 |
117 |
118 | ## How to run
119 | 1. Follow the last section: start up the Virtuoso service and import `kb.ttl`. Then open `sparql_engine.py` and find the lines
120 | ```
121 | virtuoso_address = "http://127.0.0.1:8890/sparql"
122 | virtuoso_graph_uri = 'sjx'
123 | ```
124 | Change `virtuoso_address` to your service url (you can visit it in your browser to check whether it is valid) and change `virtuoso_graph_uri` to your `[graph_name]` (a sketch of querying the endpoint directly follows this list).
125 | 2. Preprocess the training data, and copy `./dataset/kb.json` into `<output_dir>`
126 | ```
127 | python -m Bart_SPARQL.preprocess --input_dir ./dataset --output_dir <output_dir> --model_name_or_path <model_name_or_path>
128 | cp ./dataset/kb.json <output_dir>
129 | ```
130 | 3. Train
131 | ```
132 | python -m Bart_SPARQL.train --input_dir <output_dir> --output_dir <checkpoint_dir> --model_name_or_path <model_name_or_path> --save_dir <save_dir>
133 | ```
134 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order.
135 | ```
136 | python -m Bart_SPARQL.predict --input_dir <output_dir> --ckpt <checkpoint> --save_dir <save_dir>
137 |
138 | ```
139 |
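As a sanity check of step 1, you can query the endpoint directly with `SPARQLWrapper`. A minimal sketch; the endpoint URL and `[graph_name]` are placeholders for your own values:

```python
from SPARQLWrapper import SPARQLWrapper, JSON

virtuoso_address = "http://127.0.0.1:8890/sparql"  # your service url
sparql = SPARQLWrapper(virtuoso_address)
# Count the triples in your graph; substitute your own [graph_name].
sparql.setQuery("SELECT (COUNT(*) AS ?n) FROM <[graph_name]> WHERE { ?s ?p ?o . }")
sparql.setReturnFormat(JSON)
results = sparql.query().convert()
print(results['results']['bindings'][0]['n']['value'])
```
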
140 | ## Checkpoints
141 | 1. The pretrained Bart-base checkpoint without finetuning can be downloaded here [bart-base](https://cloud.tsinghua.edu.cn/f/3b59ec6c43034cfc8841/?dl=1)
142 | 2. The checkpoint for finetuned Bart_SPARQL can be downloaded here [finetuned](https://cloud.tsinghua.edu.cn/f/1b9746dcd96b4fca870d/?dl=1)
143 |
--------------------------------------------------------------------------------
/Bart_SPARQL/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import torch
4 | from utils.misc import invert_dict
5 |
6 | def load_vocab(path):
7 | vocab = json.load(open(path))
8 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])
9 | return vocab
10 |
11 | def collate(batch):
12 | batch = list(zip(*batch))
13 | source_ids = torch.stack(batch[0])
14 | source_mask = torch.stack(batch[1])
15 | choices = torch.stack(batch[2])
16 | if batch[-1][0] is None:
17 | target_ids, answer = None, None
18 | else:
19 | target_ids = torch.stack(batch[3])
20 | answer = torch.cat(batch[4])
21 | return source_ids, source_mask, choices, target_ids, answer
22 |
23 |
24 | class Dataset(torch.utils.data.Dataset):
25 | def __init__(self, inputs):
26 | self.source_ids, self.source_mask, self.target_ids, self.choices, self.answers = inputs
27 | self.is_test = len(self.answers)==0
28 |
29 |
30 | def __getitem__(self, index):
31 | source_ids = torch.LongTensor(self.source_ids[index])
32 | source_mask = torch.LongTensor(self.source_mask[index])
33 | choices = torch.LongTensor(self.choices[index])
34 | if self.is_test:
35 | target_ids = None
36 | answer = None
37 | else:
38 | target_ids = torch.LongTensor(self.target_ids[index])
39 | answer = torch.LongTensor([self.answers[index]])
40 | return source_ids, source_mask, choices, target_ids, answer
41 |
42 |
43 | def __len__(self):
44 | return len(self.source_ids)
45 |
46 |
47 | class DataLoader(torch.utils.data.DataLoader):
48 | def __init__(self, vocab_json, question_pt, batch_size, training=False):
49 | vocab = load_vocab(vocab_json)
50 | if training:
51 | print('#vocab of answer: %d' % (len(vocab['answer_token_to_idx'])))
52 |
53 | inputs = []
54 | with open(question_pt, 'rb') as f:
55 | for _ in range(5):
56 | inputs.append(pickle.load(f))
57 | dataset = Dataset(inputs)
58 |
59 | super().__init__(
60 | dataset,
61 | batch_size=batch_size,
62 | shuffle=training,
63 | collate_fn=collate,
64 | )
65 | self.vocab = vocab
--------------------------------------------------------------------------------
/Bart_SPARQL/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import argparse
5 | import numpy as np
6 | from nltk import word_tokenize
7 | from collections import Counter
8 | from itertools import chain
9 | from tqdm import tqdm
10 | import re
11 |
12 | from utils.misc import init_vocab
13 | from transformers import BartTokenizer
14 |
15 |
16 |
17 | def encode_dataset(dataset, vocab, tokenizer, test = False):
18 | questions = []
19 | sparqls = []
20 | for item in tqdm(dataset):
21 | question = item['question']
22 | questions.append(question)
23 | if not test:
24 | sparql = item['sparql']
25 | sparqls.append(sparql)
26 | sequences = questions + sparqls
27 | encoded_inputs = tokenizer(sequences, padding = True)
28 | print(encoded_inputs.keys())
29 | print(encoded_inputs['input_ids'][0])
30 | print(tokenizer.decode(encoded_inputs['input_ids'][0]))
31 | print(tokenizer.decode(encoded_inputs['input_ids'][-1]))
32 | max_seq_length = len(encoded_inputs['input_ids'][0])
33 | assert max_seq_length == len(encoded_inputs['input_ids'][-1])
34 | print(max_seq_length)
35 | questions = []
36 | sparqls = []
37 | choices = []
38 | answers = []
39 | for item in tqdm(dataset):
40 | question = item['question']
41 | questions.append(question)
42 | _ = [vocab['answer_token_to_idx'][w] for w in item['choices']]
43 | choices.append(_)
44 | if not test:
45 | sparql = item['sparql']
46 | sparqls.append(sparql)
47 | answers.append(vocab['answer_token_to_idx'].get(item['answer']))
48 |
49 | input_ids = tokenizer.batch_encode_plus(questions, max_length = max_seq_length, pad_to_max_length = True, truncation = True)
50 | source_ids = np.array(input_ids['input_ids'], dtype = np.int32)
51 | source_mask = np.array(input_ids['attention_mask'], dtype = np.int32)
52 | if not test:
53 | target_ids = tokenizer.batch_encode_plus(sparqls, max_length = max_seq_length, pad_to_max_length = True, truncation = True)
54 | target_ids = np.array(target_ids['input_ids'], dtype = np.int32)
55 | else:
56 | target_ids = np.array([], dtype = np.int32)
57 | choices = np.array(choices, dtype = np.int32)
58 | answers = np.array(answers, dtype = np.int32)
59 | return source_ids, source_mask, target_ids, choices, answers
60 |
61 |
62 |
63 | def main():
64 | parser = argparse.ArgumentParser()
65 | parser.add_argument('--input_dir', required=True)
66 | parser.add_argument('--output_dir', required=True)
67 | parser.add_argument('--model_name_or_path', required=True)
68 | args = parser.parse_args()
69 |
70 | print('Build kb vocabulary')
71 | vocab = {
72 | 'answer_token_to_idx': {}
73 | }
74 | print('Load questions')
75 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json')))
76 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json')))
77 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json')))
78 | for question in chain(train_set, val_set, test_set):
79 | for a in question['choices']:
80 | if not a in vocab['answer_token_to_idx']:
81 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
82 |
83 | if not os.path.isdir(args.output_dir):
84 | os.mkdir(args.output_dir)
85 | fn = os.path.join(args.output_dir, 'vocab.json')
86 | print('Dump vocab to {}'.format(fn))
87 | with open(fn, 'w') as f:
88 | json.dump(vocab, f, indent=2)
89 | for k in vocab:
90 | print('{}:{}'.format(k, len(vocab[k])))
91 | tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
92 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)):
93 | print('Encode {} set'.format(name))
94 | outputs = encode_dataset(dataset, vocab, tokenizer, name=='test')
95 | assert len(outputs) == 5
96 | print('shape of input_ids of questions, attention_mask of questions, input_ids of sparqls, choices and answers:')
97 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f:
98 | for o in outputs:
99 | print(o.shape)
100 | pickle.dump(o, f)
101 | if __name__ == '__main__':
102 | main()
--------------------------------------------------------------------------------
/Bart_SPARQL/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import shutil
7 | import json
8 | from tqdm import tqdm
9 | from datetime import date
10 | from utils.misc import MetricLogger, seed_everything, ProgressBar
11 | from utils.load_kb import DataForSPARQL
12 | from .data import DataLoader
13 | from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer
14 | from .sparql_engine import get_sparql_answer
15 | import torch.optim as optim
16 | import logging
17 | import time
18 | from utils.lr_scheduler import get_linear_schedule_with_warmup
19 |
20 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
21 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
22 | rootLogger = logging.getLogger()
23 | import warnings
24 | warnings.simplefilter("ignore") # hide warnings caused by invalid SPARQL queries
25 |
26 |
27 |
28 |
29 | def train(args):
30 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
31 |
32 | logging.info("Create train_loader and val_loader.........")
33 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
34 | train_pt = os.path.join(args.input_dir, 'train.pt')
35 | val_pt = os.path.join(args.input_dir, 'val.pt')
36 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True)
37 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
38 |
39 | vocab = train_loader.vocab
40 |
41 | logging.info("Create model.........")
42 | config_class, model_class, tokenizer_class = (BartConfig, BartForConditionalGeneration, BartTokenizer)
43 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
44 | model = model_class.from_pretrained(args.model_name_or_path)
45 | model = model.to(device)
46 | logging.info(model)
47 | t_total = len(train_loader) // args.gradient_accumulation_steps * args.num_train_epochs # Prepare optimizer and schedule (linear warmup and decay)
48 | no_decay = ["bias", "LayerNorm.weight"]
49 | bart_param_optimizer = list(model.named_parameters())
50 | optimizer_grouped_parameters = [
51 | {'params': [p for n, p in bart_param_optimizer if not any(nd in n for nd in no_decay)],
52 | 'weight_decay': args.weight_decay, 'lr': args.learning_rate},
53 | {'params': [p for n, p in bart_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0,
54 | 'lr': args.learning_rate}
55 | ]
56 | args.warmup_steps = int(t_total * args.warmup_proportion)
57 | optimizer = optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
58 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
59 | num_training_steps=t_total)
60 | # Check if saved optimizer or scheduler states exist
61 | if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
62 | os.path.join(args.model_name_or_path, "scheduler.pt")):
63 | # Load in optimizer and scheduler states
64 | optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
65 | scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
66 |
67 | # Train!
68 | logging.info("***** Running training *****")
69 | logging.info(" Num examples = %d", len(train_loader.dataset))
70 | logging.info(" Num Epochs = %d", args.num_train_epochs)
71 | logging.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
72 | logging.info(" Total optimization steps = %d", t_total)
73 |
74 | global_step = 0
75 | steps_trained_in_current_epoch = 0
76 | # Check if continuing training from a checkpoint
77 | if os.path.exists(args.model_name_or_path) and "checkpoint" in args.model_name_or_path:
78 | # set global_step to the global_step of the last saved checkpoint from the model path
79 | global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
80 | epochs_trained = global_step // (len(train_loader) // args.gradient_accumulation_steps)
81 | steps_trained_in_current_epoch = global_step % (len(train_loader) // args.gradient_accumulation_steps)
82 | logging.info(" Continuing training from checkpoint, will skip to saved global_step")
83 | logging.info(" Continuing training from epoch %d", epochs_trained)
84 | logging.info(" Continuing training from global step %d", global_step)
85 | logging.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
86 | logging.info('Checking...')
87 | logging.info("===================Dev==================")
88 | # evaluate(args, model, val_loader, device)
89 | tr_loss, logging_loss = 0.0, 0.0
90 | model.zero_grad()
91 | for _ in range(int(args.num_train_epochs)):
92 | pbar = ProgressBar(n_total=len(train_loader), desc='Training')
93 | for step, batch in enumerate(train_loader):
94 | # Skip past any already trained steps if resuming training
95 | if steps_trained_in_current_epoch > 0:
96 | steps_trained_in_current_epoch -= 1
97 | continue
98 | model.train()
99 | batch = tuple(t.to(device) for t in batch)
100 | pad_token_id = tokenizer.pad_token_id
101 | source_ids, source_mask, y = batch[0], batch[1], batch[-2]
102 | y_ids = y[:, :-1].contiguous()
103 | lm_labels = y[:, 1:].clone()
104 | lm_labels[y[:, 1:] == pad_token_id] = -100
105 |
106 | inputs = {
107 | "input_ids": source_ids.to(device),
108 | "attention_mask": source_mask.to(device),
109 | "decoder_input_ids": y_ids.to(device),
110 | "labels": lm_labels.to(device)
111 | }
112 | outputs = model(**inputs)
113 | loss = outputs[0]
114 | loss.backward()
115 | pbar(step, {'loss': loss.item()})
116 | tr_loss += loss.item()
117 | if (step + 1) % args.gradient_accumulation_steps == 0:
118 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
119 | optimizer.step()
120 | scheduler.step() # Update learning rate schedule
121 | model.zero_grad()
122 | global_step += 1
123 | # if args.logging_steps > 0 and global_step % args.logging_steps == 0:
124 | # logging.info("===================Dev==================")
125 | # evaluate(args, model, val_loader, device)
126 | # logging.info("===================Test==================")
127 | # evaluate(args, model, test_loader, device)
128 | if args.save_steps > 0 and global_step % args.save_steps == 0:
129 | # Save model checkpoint
130 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
131 | if not os.path.exists(output_dir):
132 | os.makedirs(output_dir)
133 | model_to_save = (
134 | model.module if hasattr(model, "module") else model
135 | ) # Take care of distributed/parallel training
136 | model_to_save.save_pretrained(output_dir)
137 | torch.save(args, os.path.join(output_dir, "training_args.bin"))
138 | logging.info("Saving model checkpoint to %s", output_dir)
139 | tokenizer.save_vocabulary(output_dir)
140 | torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
141 | torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
142 | logging.info("Saving optimizer and scheduler states to %s", output_dir)
143 | logging.info("\n")
144 | if 'cuda' in str(device):
145 | torch.cuda.empty_cache()
146 | return global_step, tr_loss / global_step
147 |
148 |
149 | def main():
150 | parser = argparse.ArgumentParser()
151 | # input and output
152 | parser.add_argument('--input_dir', required=True)
153 | parser.add_argument('--output_dir', required=True)
154 |
155 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
156 | parser.add_argument('--model_name_or_path', required = True, help = 'pretrained language models')
157 | parser.add_argument('--ckpt')
158 |
159 | # training parameters
160 | parser.add_argument('--weight_decay', default=1e-5, type=float)
161 | parser.add_argument('--batch_size', default=8, type=int)
162 | parser.add_argument('--seed', type=int, default=666, help='random seed')
163 | parser.add_argument('--learning_rate', default=3e-5, type = float)
164 | parser.add_argument('--num_train_epochs', default=25, type = int)
165 | parser.add_argument('--save_steps', default=448, type = int)
166 | parser.add_argument('--logging_steps', default=448, type = int)
167 | parser.add_argument('--warmup_proportion', default=0.1, type = float,
168 | help="Proportion of training to perform linear learning rate warmup for,E.g., 0.1 = 10% of training.")
169 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
170 | help="Epsilon for Adam optimizer.")
171 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
172 | help="Number of updates steps to accumulate before performing a backward/update pass.", )
173 | parser.add_argument("--max_grad_norm", default=1.0, type=float,
174 | help="Max gradient norm.")
175 |
176 | # validating parameters
177 | # parser.add_argument('--num_return_sequences', default=1, type=int)
178 | # parser.add_argument('--top_p', default=)
179 | # model hyperparameters
180 | parser.add_argument('--dim_hidden', default=1024, type=int)
181 | parser.add_argument('--alpha', default = 1e-4, type = float)
182 | args = parser.parse_args()
183 |
184 | if not os.path.exists(args.save_dir):
185 | os.makedirs(args.save_dir)
186 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
187 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, '{}.log'.format(time_)))
188 | fileHandler.setFormatter(logFormatter)
189 | rootLogger.addHandler(fileHandler)
190 | # args display
191 | for k, v in vars(args).items():
192 | logging.info(k+':'+str(v))
193 |
194 | seed_everything(666)
195 |
196 | train(args)
197 |
198 |
199 | if __name__ == '__main__':
200 | main()
201 |
202 |
--------------------------------------------------------------------------------
/BlindGRU/README.md:
--------------------------------------------------------------------------------
1 | ## Requirements
2 | - python3
3 | - pytorch>=1.2.0
4 | - nltk
5 |
6 | ## How to run
7 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading:
8 | ```
9 | python -m utils.pickle_glove --input <glove.840B.300d.txt> --output <glove.pt>
10 | ```
11 | This step can be skipped if you have already obtained the GloVe pickle file for another model.
12 | 2. Preprocess the training data
13 | ```
14 | python -m BlindGRU.preprocess --input_dir ./dataset --output_dir <output_dir>
15 | ```
16 | 3. Train
17 | ```
18 | python -m BlindGRU.train --input_dir <output_dir> --save_dir <save_dir> --glove_pt <glove.pt>
19 | ```
20 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order (a model-usage sketch follows this list).
21 | ```
22 | python -m BlindGRU.predict --input_dir <output_dir> --save_dir <save_dir>
23 | ```
24 |
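For reference, a minimal sketch of how `BlindGRU/model.py`'s `GRUClassifier` is called; the vocabularies here are toy values (real ones come from the generated `vocab.json`). The model maps a padded batch of question token ids directly to answer logits, never consulting the KB:

```python
import torch
from BlindGRU.model import GRUClassifier

vocab = {
    'word_token_to_idx': {'<PAD>': 0, '<UNK>': 1, 'who': 2, 'is': 3},
    'answer_token_to_idx': {'yes': 0, 'no': 1},
}
model = GRUClassifier(vocab, dim_word=300, dim_hidden=1024)
model.eval()  # disable dropout for a deterministic forward pass
questions = torch.LongTensor([[2, 3, 1, 0]])  # one question, 0-padded
logits = model(questions)                     # [batch, num_answers]
print(logits.argmax(dim=1))
```
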
--------------------------------------------------------------------------------
/BlindGRU/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import torch
4 | from utils.misc import invert_dict
5 |
6 |
7 | def load_vocab(path):
8 | vocab = json.load(open(path))
9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx'])
10 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])
11 | return vocab
12 |
13 | def collate(batch):
14 | batch = list(zip(*batch))
15 | question = torch.stack(batch[0])
16 | choices = torch.stack(batch[1])
17 | if batch[-1][0] is None:
18 | answer = None
19 | else:
20 | answer = torch.cat(batch[2])
21 | return question, choices, answer
22 |
23 |
24 | class Dataset(torch.utils.data.Dataset):
25 | def __init__(self, inputs):
26 | self.questions, self.choices, self.answers = inputs
27 | self.is_test = len(self.answers)==0
28 |
29 |
30 | def __getitem__(self, index):
31 | question = torch.LongTensor(self.questions[index])
32 | choices = torch.LongTensor(self.choices[index])
33 | if self.is_test:
34 | answer = None
35 | else:
36 | answer = torch.LongTensor([self.answers[index]])
37 | return question, choices, answer
38 |
39 |
40 | def __len__(self):
41 | return len(self.questions)
42 |
43 |
44 | class DataLoader(torch.utils.data.DataLoader):
45 | def __init__(self, vocab_json, question_pt, batch_size, training=False):
46 | vocab = load_vocab(vocab_json)
47 | if training:
48 | print('#vocab of word/answer: %d/%d' %
49 | (len(vocab['word_token_to_idx']), len(vocab['answer_token_to_idx'])))
50 |
51 | inputs = []
52 | with open(question_pt, 'rb') as f:
53 | for _ in range(3):
54 | inputs.append(pickle.load(f))
55 | dataset = Dataset(inputs)
56 |
57 | super().__init__(
58 | dataset,
59 | batch_size=batch_size,
60 | shuffle=training,
61 | collate_fn=collate,
62 | )
63 | self.vocab = vocab
64 |
65 |
--------------------------------------------------------------------------------
/BlindGRU/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from utils.BiGRU import BiGRU
5 |
6 | class GRUClassifier(nn.Module):
7 | def __init__(self, vocab, dim_word, dim_hidden):
8 | super().__init__()
9 |
10 | num_class = len(vocab['answer_token_to_idx'])
11 | num_words = len(vocab['word_token_to_idx'])
12 |
13 | self.word_embeddings = nn.Embedding(num_words, dim_word)
14 | self.word_dropout = nn.Dropout(0.3)
15 | self.question_encoder = BiGRU(dim_word, dim_hidden, num_layers=2, dropout=0.2)
16 |
17 | self.classifier = nn.Sequential(
18 | nn.Linear(dim_hidden, 1024),
19 | nn.ReLU(),
20 | nn.Linear(1024, num_class)
21 | )
22 |
23 | for m in self.modules():
24 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
25 | nn.init.kaiming_normal_(m.weight)
26 | if m.bias is not None:
27 | m.bias.data.zero_()
28 |
29 | def forward(self, questions):
30 | """
31 | Args:
32 | - questions (LongTensor) [bsz, max_len]
33 | """
34 | question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD>
35 | # print(question_lens)
36 | question_input = self.word_dropout(self.word_embeddings(questions))
37 | _, question_embeddings, _ = self.question_encoder(question_input, question_lens)
38 | logits = self.classifier(question_embeddings)
39 | return logits
40 |
--------------------------------------------------------------------------------
/BlindGRU/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import shutil
7 | from tqdm import tqdm
8 |
9 | from .data import DataLoader
10 | from .model import GRUClassifier
11 |
12 |
13 | def predict(args):
14 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
15 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
16 | test_pt = os.path.join(args.input_dir, 'test.pt')
17 | test_loader = DataLoader(vocab_json, test_pt, 128)
18 | vocab = test_loader.vocab
19 |
20 | model = GRUClassifier(vocab, args.dim_word, args.dim_hidden)
21 | model.load_state_dict(torch.load(os.path.join(args.save_dir, 'model.pt')))
22 | model = model.to(device)
23 | model.eval()
24 |
25 | def write(f, predict):
26 | predict = predict.squeeze().tolist()
27 | for i in predict:
28 | f.write(vocab['answer_idx_to_token'][i] + '\n')
29 |
30 | f1 = open(os.path.join(args.save_dir, 'predict.txt'), 'w')
31 | f2 = open(os.path.join(args.save_dir, 'choice_predict.txt'), 'w')
32 | with torch.no_grad():
33 | for batch in tqdm(test_loader, total=len(test_loader)):
34 | question, choices = [x.to(device) for x in batch[:2]]
35 | logit = model(question)
36 | predict = logit.max(1)[1]
37 | write(f1, predict)
38 | choiced_logit = torch.gather(logit, 1, choices) # [bsz, num_choices]
39 | choiced_predict = torch.gather(choices, 1, choiced_logit.max(1)[1].unsqueeze(-1)) # [bsz, 1]
40 | write(f2, choiced_predict)
41 | f1.close()
42 | f2.close()
43 |
44 |
45 |
46 | def main():
47 | parser = argparse.ArgumentParser()
48 | # input and output
49 | parser.add_argument('--input_dir', required=True)
50 | parser.add_argument('--save_dir', required=True, help='folder of checkpoint')
51 |
52 | # model hyperparameters
53 | parser.add_argument('--dim_word', default=300, type=int)
54 | parser.add_argument('--dim_hidden', default=1024, type=int)
55 | args = parser.parse_args()
56 |
57 | predict(args)
58 |
59 |
60 | if __name__ == '__main__':
61 | main()
62 |
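The two `torch.gather` calls in `predict` restrict the full-vocabulary prediction to the candidate answers of each question; a toy example of how they compose:

```
import torch

logit = torch.tensor([[0.1, 0.9, 0.3, 0.2]])     # scores over the whole answer vocabulary
choices = torch.tensor([[0, 2, 3]])              # candidate answer ids
choiced_logit = torch.gather(logit, 1, choices)  # tensor([[0.1000, 0.3000, 0.2000]])
choiced_predict = torch.gather(choices, 1, choiced_logit.max(1)[1].unsqueeze(-1))
print(choiced_predict)  # tensor([[2]]) -- the best answer among the choices only
```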
--------------------------------------------------------------------------------
/BlindGRU/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import argparse
5 | import numpy as np
6 | from nltk import word_tokenize
7 | from collections import Counter
8 | from itertools import chain
9 | from tqdm import tqdm
10 |
11 | from utils.misc import init_vocab
12 |
13 |
14 | def encode_dataset(dataset, vocab, test=False):
15 | questions = []
16 | choices = []
17 | answers = []
18 | for question in tqdm(dataset):
19 |         q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>'])
20 | for w in word_tokenize(question['question'].lower())]
21 | questions.append(q)
22 |
23 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']]
24 | choices.append(_)
25 |
26 | if test:
27 | continue
28 |
29 | if 'answer' in question:
30 | answers.append(vocab['answer_token_to_idx'].get(question['answer']))
31 |
32 | # question padding
33 | max_len = max(len(q) for q in questions)
34 | for q in questions:
35 | while len(q) < max_len:
36 |             q.append(vocab['word_token_to_idx']['<PAD>'])
37 |
38 | questions = np.asarray(questions, dtype=np.int32)
39 | choices = np.asarray(choices, dtype=np.int32)
40 | answers = np.asarray(answers, dtype=np.int32)
41 | return questions, choices, answers
42 |
43 |
44 |
45 | def main():
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('--input_dir', required=True)
48 | parser.add_argument('--output_dir', required=True)
49 | parser.add_argument('--min_cnt', type=int, default=1)
50 | args = parser.parse_args()
51 |
52 |
53 |
54 | vocab = {
55 | 'word_token_to_idx': init_vocab(), # include question text and function inputs
56 | 'answer_token_to_idx': {}
57 | }
58 | print('Load questions')
59 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json')))
60 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json')))
61 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json')))
62 | print('Build question vocabulary')
63 | word_counter = Counter()
64 | for question in train_set:
65 | tokens = word_tokenize(question['question'].lower())
66 | word_counter.update(tokens)
67 | # add candidate answers
68 | for a in question['choices']:
69 | if a not in vocab['answer_token_to_idx']:
70 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
71 | # filter low-frequency words
72 | for w, c in word_counter.items():
73 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']:
74 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx'])
75 | # add candidate answers of val and test set
76 | for question in chain(val_set, test_set):
77 | for a in question['choices']:
78 | if a not in vocab['answer_token_to_idx']:
79 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
80 |
81 |
82 | if not os.path.isdir(args.output_dir):
83 | os.mkdir(args.output_dir)
84 | fn = os.path.join(args.output_dir, 'vocab.json')
85 | print('Dump vocab to {}'.format(fn))
86 | with open(fn, 'w') as f:
87 | json.dump(vocab, f, indent=2)
88 | for k in vocab:
89 | print('{}:{}'.format(k, len(vocab[k])))
90 |
91 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)):
92 | print('Encode {} set'.format(name))
93 | outputs = encode_dataset(dataset, vocab, name=='test')
94 | print('shape of questions, choices, answers:')
95 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f:
96 | for o in outputs:
97 | print(o.shape)
98 | pickle.dump(o, f)
99 |
100 |
101 |
102 |
103 |
104 | if __name__ == '__main__':
105 | main()
106 |
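`preprocess.py` pickles the arrays one after another into a single `.pt` file, so they must be unpickled in the same order, which is exactly what `data.py` does; a minimal round-trip sketch with dummy arrays:

```
import pickle
import numpy as np

arrays = [np.zeros((2, 5)), np.zeros((2, 3)), np.zeros((2,))]  # questions, choices, answers
with open('toy.pt', 'wb') as f:
    for a in arrays:
        pickle.dump(a, f)
with open('toy.pt', 'rb') as f:
    questions, choices, answers = [pickle.load(f) for _ in range(3)]
```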
--------------------------------------------------------------------------------
/BlindGRU/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import shutil
7 | from tqdm import tqdm
8 |
9 | from .data import DataLoader
10 | from .model import GRUClassifier
11 | from utils.misc import MetricLogger, load_glove
12 |
13 | import logging
14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
15 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
16 | rootLogger = logging.getLogger()
17 |
18 |
19 | def validate(model, data, device):
20 | model.eval()
21 | count, correct = 0, 0
22 | with torch.no_grad():
23 | for batch in tqdm(data, total=len(data)):
24 | question, choices, answer = [x.to(device) for x in batch]
25 | logit = model(question)
26 | predict = logit.max(1)[1]
27 | correct += torch.eq(predict, answer).long().sum().item()
28 | count += len(answer)
29 |
30 | acc = correct / count
31 | logging.info('\nValid Accuracy: %.4f\n' % acc)
32 | return acc
33 |
34 |
35 | def train(args):
36 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
37 |
38 | logging.info("Create train_loader and val_loader.........")
39 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
40 | train_pt = os.path.join(args.input_dir, 'train.pt')
41 | val_pt = os.path.join(args.input_dir, 'val.pt')
42 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True)
43 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
44 | vocab = train_loader.vocab
45 |
46 | logging.info("Create model.........")
47 | model = GRUClassifier(vocab, args.dim_word, args.dim_hidden)
48 | logging.info("Load pretrained word vectors.........")
49 | pretrained = load_glove(args.glove_pt, vocab['word_idx_to_token'])
50 | with torch.no_grad():
51 | model.word_embeddings.weight.set_(torch.Tensor(pretrained))
52 | model = model.to(device)
53 | logging.info(model)
54 |
55 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
56 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 50], gamma=0.1)
57 | criterion = nn.CrossEntropyLoss().to(device)
58 |
59 | validate(model, val_loader, device)
60 | meters = MetricLogger(delimiter=" ")
61 | best_acc = 0
62 | logging.info("Start training........")
63 | for epoch in range(args.num_epoch):
64 | model.train()
65 | for iteration, batch in enumerate(train_loader):
66 | iteration = iteration + 1
67 |
68 | question, choices, answer = [x.to(device) for x in batch]
69 | logits = model(question)
70 | loss = criterion(logits, answer)
71 | optimizer.zero_grad()
72 | loss.backward()
73 | optimizer.step()
74 | meters.update(loss=loss.item())
75 |
76 |             if iteration % max(1, len(train_loader) // 100) == 0: # max() guards small loaders against modulo by zero
77 | logging.info(
78 | meters.delimiter.join(
79 | [
80 | "progress: {progress:.3f}",
81 | "{meters}",
82 | "lr: {lr:.6f}",
83 | ]
84 | ).format(
85 | progress=epoch + iteration / len(train_loader),
86 | meters=str(meters),
87 | lr=optimizer.param_groups[0]["lr"],
88 | )
89 | )
90 |
91 | acc = validate(model, val_loader, device)
92 | scheduler.step()
93 | if acc and acc > best_acc:
94 | best_acc = acc
95 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc))
96 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt'))
97 |
98 |
99 | def main():
100 | parser = argparse.ArgumentParser()
101 | # input and output
102 | parser.add_argument('--input_dir', required=True)
103 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
104 | parser.add_argument('--glove_pt', required=True)
105 |
106 | # training parameters
107 | parser.add_argument('--lr', default=0.001, type=float)
108 | parser.add_argument('--weight_decay', default=1e-5, type=float)
109 | parser.add_argument('--num_epoch', default=100, type=int)
110 | parser.add_argument('--batch_size', default=128, type=int)
111 | parser.add_argument('--seed', type=int, default=666, help='random seed')
112 |
113 | # model hyperparameters
114 | parser.add_argument('--dim_word', default=300, type=int)
115 | parser.add_argument('--dim_hidden', default=1024, type=int)
116 | args = parser.parse_args()
117 |
118 | # make logging.info display into both shell and file
119 | if os.path.isdir(args.save_dir):
120 | shutil.rmtree(args.save_dir)
121 | os.mkdir(args.save_dir)
122 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt'))
123 | fileHandler.setFormatter(logFormatter)
124 | rootLogger.addHandler(fileHandler)
125 | # args display
126 | for k, v in vars(args).items():
127 | logging.info(k+':'+str(v))
128 |
129 | # set random seed
130 | torch.manual_seed(args.seed)
131 |
132 | train(args)
133 |
134 |
135 | if __name__ == '__main__':
136 | main()
137 |
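The `MultiStepLR` schedule in `train` decays the learning rate tenfold at epochs 5 and 50; a tiny sketch of the resulting schedule:

```
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.001)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 50], gamma=0.1)
for epoch in range(6):
    optimizer.step()
    scheduler.step()
# lr is 1e-3 for epochs 0-4, 1e-4 from epoch 5 on, and 1e-5 from epoch 50 on
print(optimizer.param_groups[0]['lr'])  # ~0.0001
```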
--------------------------------------------------------------------------------
/KVMemNN/README.md:
--------------------------------------------------------------------------------
1 | ## Requirements
2 | - python3
3 | - pytorch>=1.2.0
4 | - nltk
5 |
6 | ## How to run
7 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading:
8 | ```
9 | python -m utils.pickle_glove --input <path/to/glove.840B.300d.txt> --output <path/to/glove.pt>
10 | ```
11 | This step can be skipped if you have already created the GloVe pickle file for another model.
12 | 2. Preprocess the training data
13 | ```
14 | python -m KVMemNN.preprocess --input_dir ./dataset --output_dir <output_dir>
15 | ```
16 | 3. Train
17 | ```
18 | python -m KVMemNN.train --input_dir <output_dir> --save_dir <save_dir> --glove_pt <glove_pt>
19 | ```
20 | 4. Predict answers for the test set. This produces a file named `predict.txt` in `--save_dir`, containing the predicted answers of the test questions in order.
21 | ```
22 | python -m KVMemNN.predict --input_dir <output_dir> --save_dir <save_dir>
23 | ```
24 |
--------------------------------------------------------------------------------
/KVMemNN/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import torch
4 | from utils.misc import invert_dict
5 |
6 |
7 | def load_vocab(path):
8 | vocab = json.load(open(path))
9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx'])
10 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])
11 | return vocab
12 |
13 | def collate(batch):
14 | batch = list(zip(*batch))
15 | question, choices, keys, values = list(map(torch.stack, batch[:4]))
16 | if batch[-1][0] is None:
17 | answer = None
18 | else:
19 | answer = torch.cat(batch[-1])
20 | return question, choices, keys, values, answer
21 |
22 |
23 | class Dataset(torch.utils.data.Dataset):
24 | def __init__(self, all_keys, all_values, inputs):
25 | self.all_keys = all_keys
26 | self.all_values = all_values
27 | self.questions, self.key_indexes, self.choices, self.answers = inputs
28 | self.is_test = len(self.answers)==0
29 |
30 |
31 | def __getitem__(self, index):
32 | question = torch.LongTensor(self.questions[index])
33 | key_index = self.key_indexes[index]
34 | keys = torch.LongTensor(self.all_keys[key_index])
35 | values = torch.LongTensor(self.all_values[key_index])
36 | choices = torch.LongTensor(self.choices[index])
37 | if self.is_test:
38 | answer = None
39 | else:
40 | answer = torch.LongTensor([self.answers[index]])
41 | return question, choices, keys, values, answer
42 |
43 |
44 | def __len__(self):
45 | return len(self.questions)
46 |
47 |
48 | class DataLoader(torch.utils.data.DataLoader):
49 | def __init__(self, vocab_json, kb_pt, question_pt, batch_size, training=False):
50 | vocab = load_vocab(vocab_json)
51 |
52 | inputs = []
53 | with open(question_pt, 'rb') as f:
54 | for _ in range(4):
55 | inputs.append(pickle.load(f))
56 | with open(kb_pt, 'rb') as f:
57 | all_keys = pickle.load(f)
58 | all_values = pickle.load(f)
59 | dataset = Dataset(all_keys, all_values, inputs)
60 |
61 | super().__init__(
62 | dataset,
63 | batch_size=batch_size,
64 | shuffle=training,
65 | collate_fn=collate,
66 | )
67 | self.vocab = vocab
68 |
69 |
--------------------------------------------------------------------------------
/KVMemNN/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from utils.BiGRU import BiGRU, GRU
5 |
6 | class KVMemNN(nn.Module):
7 | def __init__(self, num_hop, dim_emb, vocab):
8 | super().__init__()
9 | self.num_hop = num_hop
10 | num_vocab = len(vocab['word_token_to_idx'])
11 | num_class = len(vocab['answer_token_to_idx'])
12 |
13 | self.embeddings = nn.Embedding(num_vocab, dim_emb)
14 | self.question_encoder = BiGRU(dim_emb, dim_emb, num_layers=2, dropout=0.2)
15 | self.word_dropout = nn.Dropout(0.3)
16 |         self.linears = [] # plain list; add_module below registers each layer so its parameters are tracked
17 | for i in range(num_hop):
18 | lin = nn.Linear(dim_emb, dim_emb)
19 | self.linears.append(lin)
20 | self.add_module('linear_{}'.format(i), lin)
21 |
22 | self.classifier = nn.Sequential(
23 | nn.Linear(dim_emb, 1024),
24 | nn.ReLU(),
25 | nn.Linear(1024, num_class)
26 | )
27 | for m in self.modules():
28 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
29 | nn.init.kaiming_normal_(m.weight)
30 | if m.bias is not None:
31 | m.bias.data.zero_()
32 |
33 | def forward(self, questions, keys, values):
34 | """
35 | Args:
36 | questions [bsz, max_q_len]
37 | keys [bsz, num_slot, max_k_len]
38 | values [bsz, num_slot, max_v_len]
39 | """
40 |         question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD>
41 | q_word_emb = self.word_dropout(self.embeddings(questions))
42 |         _, q_embeddings, q_hn = self.question_encoder(q_word_emb, question_lens) # encoder outputs are not used below
43 |         q = self.embeddings(questions).sum(dim=1) # bag-of-words question representation [bsz, dim_emb]
44 | k = self.embeddings(keys).sum(dim=2) # [bsz, num_slot, dim_emb]
45 | v = self.embeddings(values).sum(dim=2) # [bsz, num_slot, dim_emb]
46 |
47 | for i in range(self.num_hop):
48 | weights = torch.bmm(k, q.unsqueeze(2)).squeeze(2) # [bsz, num_slot]
49 | weights = torch.softmax(weights, dim=1)
50 | o = torch.bmm(weights.unsqueeze(1), v).squeeze(1) # [bsz, dim_emb]
51 | q = self.linears[i](q + o) # [bsz, dim_emb]
52 | logits = self.classifier(q)
53 | return logits
54 |
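Each hop in `forward` is a soft key addressing followed by a weighted value read-out; the tensor shapes involved, traced with random data:

```
import torch

bsz, num_slot, dim_emb = 2, 4, 8
q = torch.randn(bsz, dim_emb)            # question representation
k = torch.randn(bsz, num_slot, dim_emb)  # memory keys
v = torch.randn(bsz, num_slot, dim_emb)  # memory values

weights = torch.softmax(torch.bmm(k, q.unsqueeze(2)).squeeze(2), dim=1)  # [2, 4]
o = torch.bmm(weights.unsqueeze(1), v).squeeze(1)                        # [2, 8]
print(weights.shape, o.shape)
```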
--------------------------------------------------------------------------------
/KVMemNN/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import shutil
7 | from tqdm import tqdm
8 |
9 | from .data import DataLoader
10 | from .model import KVMemNN
11 |
12 |
13 | def main():
14 | parser = argparse.ArgumentParser()
15 | # input and output
16 | parser.add_argument('--input_dir', required=True)
17 | parser.add_argument('--save_dir', required=True, help='path of checkpoint')
18 |
19 | # model hyperparameters
20 | parser.add_argument('--dim_emb', default=300, type=int)
21 | parser.add_argument('--num_hop', default=3, type=int)
22 | args = parser.parse_args()
23 |
24 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
25 |
26 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
27 | test_pt = os.path.join(args.input_dir, 'test.pt')
28 | kb_pt = os.path.join(args.input_dir, 'kb.pt')
29 | test_loader = DataLoader(vocab_json, kb_pt, test_pt, 32)
30 | vocab = test_loader.vocab
31 |
32 |
33 | model = KVMemNN(
34 | args.num_hop,
35 | args.dim_emb,
36 | vocab
37 | )
38 | model.load_state_dict(torch.load(os.path.join(args.save_dir, 'model.pt')))
39 | model = model.to(device)
40 | model.eval()
41 |
42 | def write(f, predict):
43 | predict = predict.squeeze().tolist()
44 | for i in predict:
45 | f.write(vocab['answer_idx_to_token'][i] + '\n')
46 |
47 | f1 = open(os.path.join(args.save_dir, 'predict.txt'), 'w')
48 | f2 = open(os.path.join(args.save_dir, 'choice_predict.txt'), 'w')
49 | with torch.no_grad():
50 | for batch in tqdm(test_loader, total=len(test_loader)):
51 | question, choices, keys, values = [x.to(device) for x in batch[:4]]
52 | logit = model(question, keys, values) # [bsz, num_answers]
53 | predict = logit.max(1)[1]
54 | write(f1, predict)
55 |             choiced_logit = torch.gather(logit, 1, choices) # [bsz, num_choices]
56 | choiced_predict = torch.gather(choices, 1, choiced_logit.max(1)[1].unsqueeze(-1)) # [bsz, 1]
57 | write(f2, choiced_predict)
58 | f1.close()
59 | f2.close()
60 |
61 |
62 | if __name__ == '__main__':
63 | main()
64 |
--------------------------------------------------------------------------------
/KVMemNN/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import argparse
5 | import numpy as np
6 | from nltk import word_tokenize
7 | from collections import Counter, defaultdict
8 | from itertools import chain
9 | from tqdm import tqdm
10 |
11 | from utils.load_kb import load_as_key_value
12 | from utils.misc import init_vocab
13 |
14 |
15 | def create_inverted(keys):
16 | inverted_index = defaultdict(set)
17 | counter = Counter()
18 | for i in range(len(keys)):
19 | for w in keys[i]:
20 | inverted_index[w].add(i)
21 | counter[w] += 1
22 | return inverted_index
23 |
24 |
25 | def find_candidate_keys(inverted_index, stopwords, question, num_cand_keys):
26 | """
27 | find keys that are relevant to question, and then return the top num_cand_keys
28 | if not enough, pad 0
29 | """
30 | words = word_tokenize(question['question'].lower())
31 | counter = Counter()
32 | for w in words:
33 | if w in stopwords: # skip stopwords
34 | continue
35 | counter.update(inverted_index.get(w, []))
36 | indexes = [x[0] for x in counter.most_common(num_cand_keys)]
37 | if len(indexes) < num_cand_keys:
38 | indexes += [0] * (num_cand_keys - len(indexes))
39 | return indexes
40 |
41 |
42 |
43 | def encode_kb(keys, values, vocab):
44 | encoded_keys = []
45 | encoded_values = []
46 | for i in tqdm(range(len(keys))):
47 |         encoded_keys.append([vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>']) for w in keys[i]])
48 |         encoded_values.append([vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>']) for w in values[i]])
49 | keys = encoded_keys
50 | values = encoded_values
51 | max_len = max(len(k) for k in keys)
52 | for k in keys:
53 | while len(k) < max_len:
54 |             k.append(vocab['word_token_to_idx']['<PAD>'])
55 | max_len = max(len(k) for k in values)
56 | for k in values:
57 | while len(k) < max_len:
58 |             k.append(vocab['word_token_to_idx']['<PAD>'])
59 | keys = np.asarray(keys, dtype=np.int32)
60 | values = np.asarray(values, dtype=np.int32)
61 | return keys, values
62 |
63 |
64 | def encode_dataset(dataset, vocab, inverted_index, stopwords, num_cand_keys):
65 | questions = []
66 | key_indexes = []
67 | choices = []
68 | answers = []
69 | for question in tqdm(dataset):
70 |         q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>'])
71 | for w in word_tokenize(question['question'].lower())]
72 | questions.append(q)
73 |
74 | key_indexes.append(find_candidate_keys(inverted_index, stopwords, question, num_cand_keys))
75 |
76 |
77 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']]
78 | choices.append(_)
79 | if 'answer' in question:
80 | answers.append(vocab['answer_token_to_idx'].get(question['answer']))
81 |
82 | # question padding
83 | max_len = max(len(q) for q in questions)
84 | for q in questions:
85 | while len(q) < max_len:
86 |             q.append(vocab['word_token_to_idx']['<PAD>'])
87 |
88 | questions = np.asarray(questions, dtype=np.int32)
89 | key_indexes = np.asarray(key_indexes, dtype=np.int32)
90 | choices = np.asarray(choices, dtype=np.int32)
91 | answers = np.asarray(answers, dtype=np.int32)
92 | return questions, key_indexes, choices, answers
93 |
94 |
95 |
96 | def main():
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument('--input_dir', required=True)
99 | parser.add_argument('--output_dir', required=True)
100 | parser.add_argument('--min_cnt', type=int, default=1)
101 | parser.add_argument('--stop_thresh', type=int, default=1000)
102 | parser.add_argument('--num_cand_keys', type=int, default=1000)
103 | args = parser.parse_args()
104 |
105 |
106 | print('Build kb vocabulary')
107 | kb_vocab, kb_keys, kb_values = load_as_key_value(os.path.join(args.input_dir, 'kb.json'), args.min_cnt)
108 | vocab = {
109 | 'word_token_to_idx': init_vocab(),
110 | 'answer_token_to_idx': {}
111 | }
112 | print('Load questions')
113 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json')))
114 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json')))
115 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json')))
116 | print('Build question vocabulary')
117 | word_counter = Counter()
118 | for question in train_set:
119 | tokens = word_tokenize(question['question'].lower())
120 | word_counter.update(tokens)
121 | # add candidate answers
122 | for a in question['choices']:
123 | if a not in vocab['answer_token_to_idx']:
124 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
125 | # filter low-frequency words
126 | stopwords = set()
127 | for w, c in word_counter.items():
128 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']:
129 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx'])
130 | if w and c >= args.stop_thresh:
131 | stopwords.add(w)
132 | print('number of stop words (>={}): {}'.format(args.stop_thresh, len(stopwords)))
133 | # merge kb vocab
134 | for w in kb_vocab:
135 | if w not in vocab['word_token_to_idx']:
136 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx'])
137 | # add candidate answers of val and test set
138 | for question in chain(val_set, test_set):
139 | for a in question['choices']:
140 | if a not in vocab['answer_token_to_idx']:
141 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
142 |
143 |
144 | if not os.path.isdir(args.output_dir):
145 | os.mkdir(args.output_dir)
146 | fn = os.path.join(args.output_dir, 'vocab.json')
147 | print('Dump vocab to {}'.format(fn))
148 | with open(fn, 'w') as f:
149 | json.dump(vocab, f, indent=2)
150 | for k in vocab:
151 | print('{}:{}'.format(k, len(vocab[k])))
152 |
153 | print('Create inverted index for keys')
154 | inverted_index = create_inverted(kb_keys)
155 |
156 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)):
157 | print('Encode {} set'.format(name))
158 | outputs = encode_dataset(dataset, vocab, inverted_index, stopwords, args.num_cand_keys)
159 | print('shape of questions, key indexes, choices, answers:')
160 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f:
161 | for o in outputs:
162 | print(o.shape)
163 | pickle.dump(o, f)
164 |
165 | print('Encode kb')
166 | outputs = encode_kb(kb_keys, kb_values, vocab)
167 | print('shape of keys, values:')
168 | with open(os.path.join(args.output_dir, 'kb.pt'), 'wb') as f:
169 | for o in outputs:
170 | print(o.shape)
171 | pickle.dump(o, f)
172 |
173 |
174 |
175 |
176 | if __name__ == '__main__':
177 | main()
178 |
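Candidate-key retrieval in `find_candidate_keys` is a simple overlap count through the inverted index; a toy run of the same idea:

```
from collections import Counter, defaultdict

keys = [['obama', 'born', 'in'], ['france', 'capital'], ['obama', 'spouse']]
inverted_index = defaultdict(set)
for i, key in enumerate(keys):
    for w in key:
        inverted_index[w].add(i)

counter = Counter()
for w in ['where', 'was', 'obama', 'born']:  # a tokenized toy question
    counter.update(inverted_index.get(w, []))
print([i for i, _ in counter.most_common(2)])  # [0, 2] -- key 0 overlaps on two words
```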
--------------------------------------------------------------------------------
/KVMemNN/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import shutil
7 | from tqdm import tqdm
8 |
9 | from utils.misc import MetricLogger, load_glove
10 | from .data import DataLoader
11 | from .model import KVMemNN
12 |
13 | import logging
14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
15 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
16 | rootLogger = logging.getLogger()
17 |
18 | torch.set_num_threads(1) # avoid using multiple cpus
19 |
20 | def validate(model, data, device):
21 | model.eval()
22 | count, correct = 0, 0
23 | with torch.no_grad():
24 | for batch in tqdm(data, total=len(data)):
25 | question, choices, keys, values, answer = [x.to(device) for x in batch]
26 | logit = model(question, keys, values)
27 | predict = logit.max(1)[1]
28 | correct += torch.eq(predict, answer).long().sum().item()
29 | count += len(answer)
30 |
31 | acc = correct / count
32 | logging.info('\nValid Accuracy: %.4f\n' % acc)
33 | return acc
34 |
35 |
36 | def train(args):
37 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
38 |
39 | logging.info("Create train_loader and val_loader.........")
40 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
41 | train_pt = os.path.join(args.input_dir, 'train.pt')
42 | val_pt = os.path.join(args.input_dir, 'val.pt')
43 | kb_pt = os.path.join(args.input_dir, 'kb.pt')
44 | train_loader = DataLoader(vocab_json, kb_pt, train_pt, args.batch_size, training=True)
45 | val_loader = DataLoader(vocab_json, kb_pt, val_pt, args.batch_size)
46 | vocab = train_loader.vocab
47 |
48 | logging.info("Create model.........")
49 | model = KVMemNN(
50 | args.num_hop,
51 | args.dim_emb,
52 | vocab
53 | )
54 | logging.info("Load pretrained word vectors.........")
55 | pretrained = load_glove(args.glove_pt, vocab['word_idx_to_token'])
56 | with torch.no_grad():
57 | model.embeddings.weight.set_(torch.Tensor(pretrained))
58 | model = model.to(device)
59 | logging.info(model)
60 |
61 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
62 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 50], gamma=0.1)
63 | criterion = nn.CrossEntropyLoss().to(device)
64 |
65 | validate(model, val_loader, device)
66 | meters = MetricLogger(delimiter=" ")
67 | best_acc = 0
68 | logging.info("Start training........")
69 | for epoch in range(args.num_epoch):
70 | model.train()
71 | for iteration, batch in enumerate(train_loader):
72 | iteration = iteration + 1
73 |
74 | question, choices, keys, values, answer = [x.to(device) for x in batch]
75 | logits = model(question, keys, values)
76 | loss = criterion(logits, answer)
77 | optimizer.zero_grad()
78 | loss.backward()
79 | optimizer.step()
80 | meters.update(loss=loss.item())
81 |
82 |             if iteration % max(1, len(train_loader) // 100) == 0: # max() guards small loaders against modulo by zero
83 | logging.info(
84 | meters.delimiter.join(
85 | [
86 | "progress: {progress:.3f}",
87 | "{meters}",
88 | "lr: {lr:.6f}",
89 | ]
90 | ).format(
91 | progress=epoch + iteration / len(train_loader),
92 | meters=str(meters),
93 | lr=optimizer.param_groups[0]["lr"],
94 | )
95 | )
96 |
97 | acc = validate(model, val_loader, device)
98 | scheduler.step()
99 | if acc and acc > best_acc:
100 | best_acc = acc
101 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc))
102 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt'))
103 |
104 |
105 | def main():
106 | parser = argparse.ArgumentParser()
107 | # input and output
108 | parser.add_argument('--input_dir', required=True)
109 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
110 | parser.add_argument('--glove_pt', required=True)
111 |
112 | # training parameters
113 | parser.add_argument('--lr', default=0.001, type=float)
114 | parser.add_argument('--weight_decay', default=1e-5, type=float)
115 | parser.add_argument('--num_epoch', default=100, type=int)
116 | parser.add_argument('--batch_size', default=32, type=int)
117 | parser.add_argument('--seed', type=int, default=666, help='random seed')
118 | # model hyperparameters
119 | parser.add_argument('--dim_emb', default=300, type=int)
120 | parser.add_argument('--num_hop', default=3, type=int)
121 | args = parser.parse_args()
122 |
123 | # make logging.info display into both shell and file
124 | if not os.path.exists(args.save_dir):
125 | os.makedirs(args.save_dir)
126 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt'))
127 | fileHandler.setFormatter(logFormatter)
128 | rootLogger.addHandler(fileHandler)
129 | # args display
130 | for k, v in vars(args).items():
131 | logging.info(k+':'+str(v))
132 |
133 | # set random seed
134 | torch.manual_seed(args.seed)
135 |
136 | train(args)
137 |
138 |
139 | if __name__ == '__main__':
140 | main()
141 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 THU-KEG
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Program/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import torch
4 | from utils.misc import invert_dict
5 |
6 |
7 | def load_vocab(path):
8 | vocab = json.load(open(path))
9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx'])
10 | vocab['function_idx_to_token'] = invert_dict(vocab['function_token_to_idx'])
11 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])
12 | return vocab
13 |
14 | def collate(batch):
15 | batch = list(zip(*batch))
16 | question = torch.stack(batch[0])
17 | choices = torch.stack(batch[1])
18 | if batch[-1][0] is None:
19 | program, prog_depends, prog_inputs, answer = None, None, None, None
20 | else:
21 | program, prog_depends, prog_inputs = list(map(torch.stack, batch[2:5]))
22 | answer = torch.cat(batch[5])
23 | return question, choices, program, prog_depends, prog_inputs, answer
24 |
25 |
26 | class Dataset(torch.utils.data.Dataset):
27 | def __init__(self, inputs):
28 | self.questions, self.functions, self.func_depends, self.func_inputs, \
29 | self.choices, self.answers = inputs
30 | self.is_test = len(self.answers)==0
31 |
32 |
33 | def __getitem__(self, index):
34 | question = torch.LongTensor(self.questions[index])
35 | choices = torch.LongTensor(self.choices[index])
36 | if self.is_test:
37 | program = None
38 | prog_depends = None
39 | prog_inputs = None
40 | answer = None
41 | else:
42 | program = torch.LongTensor(self.functions[index])
43 | prog_depends = torch.LongTensor(self.func_depends[index])
44 | prog_inputs = torch.LongTensor(self.func_inputs[index])
45 | answer = torch.LongTensor([self.answers[index]])
46 | # dependency is not necessary because it can be inferred based on functions
47 | return question, choices, program, prog_depends, prog_inputs, answer
48 |
49 |
50 | def __len__(self):
51 | return len(self.questions)
52 |
53 |
54 | class DataLoader(torch.utils.data.DataLoader):
55 | def __init__(self, vocab_json, question_pt, batch_size, training=False):
56 | vocab = load_vocab(vocab_json)
57 | if training:
58 | print('#vocab of word: %d' % len(vocab['word_token_to_idx']))
59 | print('#vocab of answer: %d' % len(vocab['answer_token_to_idx']))
60 |
61 | inputs = []
62 | with open(question_pt, 'rb') as f:
63 | for _ in range(6):
64 | inputs.append(pickle.load(f))
65 | dataset = Dataset(inputs)
66 |
67 | super().__init__(
68 | dataset,
69 | batch_size=batch_size,
70 | shuffle=training,
71 | collate_fn=collate,
72 | )
73 | self.vocab = vocab
74 |
--------------------------------------------------------------------------------
/Program/parser.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from utils.BiGRU import GRU, BiGRU
5 |
6 | class Parser(nn.Module):
7 | def __init__(self, vocab, dim_word, dim_hidden, max_dec_len=20, max_inp=3):
8 | super().__init__()
9 | num_func = len(vocab['function_token_to_idx'])
10 | num_words = len(vocab['word_token_to_idx'])
11 | self.vocab = vocab
12 | self.dim_word = dim_word
13 | self.dim_hidden = dim_hidden
14 | self.max_dec_len = max_dec_len
15 | self.max_inp = max_inp
16 |
17 | self.word_embeddings = nn.Embedding(num_words, dim_word)
18 | self.word_dropout = nn.Dropout(0.2)
19 | self.question_encoder = GRU(dim_word, dim_hidden, num_layers=2, dropout=0.2)
20 |
21 | self.func_embeddings = nn.Embedding(num_func, dim_word)
22 | self.decoder = GRU(dim_word, dim_hidden, num_layers=2, dropout=0.2)
23 |
24 | self.func_classifier = nn.Sequential(
25 | nn.Linear(dim_hidden, 1024),
26 | nn.ReLU(),
27 | nn.Linear(1024, num_func),
28 | )
29 |
30 | self.inp_embeddings = nn.Embedding(num_words, dim_word)
31 | self.inp_decoder = GRU(dim_word + dim_hidden, dim_hidden, num_layers=2, dropout=0.2)
32 | self.inp_classifier = nn.Sequential(
33 | nn.Linear(dim_hidden, 1024),
34 | nn.ReLU(),
35 | nn.Linear(1024, num_words),
36 | )
37 |
38 | for m in self.modules():
39 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
40 | nn.init.kaiming_normal_(m.weight)
41 | if m.bias is not None:
42 | m.bias.data.zero_()
43 |
44 | def forward(self, questions, programs=None, inputs=None):
45 | """
46 | Args:
47 | questions [bsz, max_q]
48 | programs [bsz, max_prog]
49 | inputs [bsz, max_prog, max_inp=3]
50 | Return:
51 | if programs are given, then return losses
52 | else, return predicted programs
53 | """
54 |         question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD>
55 | q_word_emb = self.word_dropout(self.word_embeddings(questions))
56 | q_word_h, q_embeddings, q_hn = self.question_encoder(q_word_emb, question_lens)
57 | # [bsz, max_q, dim_h], [bsz, dim_h], [num_layers, bsz, dim_h]
58 |
59 | if programs is None: # during inference
60 | return self.inference(q_word_h, q_embeddings, q_hn)
61 | else:
62 | return self.train_phase(q_word_h, q_embeddings, q_hn, programs, inputs)
63 |
64 |
65 | def train_phase(self, q_word_h, q_embeddings, q_hn, programs, inputs):
66 | bsz, max_prog = programs.size(0), programs.size(1)
67 | device = programs.device
68 |         program_lens = programs.size(1) - programs.eq(0).long().sum(dim=1) # 0 means <PAD>
69 | program_mask = programs.ne(0).long()
70 |
71 | p_word_emb = self.word_dropout(self.func_embeddings(programs))
72 | p_word_h, _, _ = self.decoder(p_word_emb, program_lens, h_0=q_hn) # [bsz, max_prog, dim_h]
73 | # attention over question words
74 | attn = torch.softmax(torch.bmm(p_word_h, q_word_h.permute(0, 2, 1)), dim=2) # [bsz, max_prog, max_q]
75 | attn_word_h = torch.bmm(attn, q_word_h) # [bsz, max_prog, dim_h]
76 | # sum up
77 | p_word_h = p_word_h + attn_word_h # [bsz, max_prog, dim_h]
78 |
79 |
80 | criterion_CE = nn.CrossEntropyLoss().to(device)
81 | # predict function
82 | logit_func = self.func_classifier(p_word_h) # [bsz, max_prog, num_func]
83 | loss_func = criterion_CE(logit_func.permute(0, 2, 1)[:,:,:-1], programs[:,1:]) # remember to shift the gt
84 |
85 |         # remove inputs of the <START> function
86 | inputs = inputs[:,1:,:].view(bsz, -1) # [bsz, (max_prog-1)*3]
87 |         # add an extra <START> at the beginning, for convenience of inference
88 |         start_token = torch.zeros((bsz, 1)).to(device).fill_(self.vocab['word_token_to_idx']['<START>']).long()
89 | inputs = torch.cat((start_token, inputs), dim=1) # [bsz, 1+(max_prog-1)*3]
90 | inp_emb = self.word_dropout(self.inp_embeddings(inputs)) # [bsz, 1+(max_prog-1)*3, dim_w]
91 |
92 | rep_p_word_h = p_word_h.view(bsz, max_prog, 1, -1).expand(-1, -1, 3, -1).\
93 | reshape(bsz, max_prog*3, -1).contiguous() # [bsz, max_prog*3, dim_h]
94 |         # align, so that the <START> func is used to predict the 3 inputs of the first function
95 | rep_p_word_h = rep_p_word_h[:, :1+(max_prog-1)*3]
96 | inp_h, _, _ = self.inp_decoder(torch.cat((inp_emb, rep_p_word_h), dim=2),
97 | 1+(program_lens-1)*3, h_0=q_hn) # [bsz, 1+(max_prog-1)*3, dim_h]
98 | # attention over question words
99 | attn = torch.softmax(torch.bmm(inp_h, q_word_h.permute(0, 2, 1)), dim=2)
100 | attn_word_h = torch.bmm(attn, q_word_h)
101 | # sum up
102 | inp_h = inp_h + attn_word_h # [bsz, 1+(max_prog-1)*3, dim_h]
103 | # logit
104 |         logit_inp = self.inp_classifier(inp_h) # [bsz, 1+(max_prog-1)*3, num_words]
105 | loss_inp = criterion_CE(logit_inp.permute(0, 2, 1)[:,:,:-1], inputs[:,1:]) # shift the input
106 |
107 | loss = loss_func + loss_inp
108 |
109 | return loss
110 |
111 |
112 | def inference(self, q_word_h, q_embeddings, q_hn):
113 | """
114 | Predict programs, and inputs
115 | """
116 | bsz = q_word_h.size(0)
117 | device = q_word_h.device
118 |         start_id = self.vocab['function_token_to_idx']['<START>']
119 |         end_id = self.vocab['function_token_to_idx']['<END>']
120 |
121 | latest_func = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ]
122 | last_h = q_hn
123 |         finished = torch.zeros((bsz,)).byte().to(device) # record whether <END> is produced
124 |
125 |         latest_inp = torch.LongTensor([self.vocab['word_token_to_idx']['<START>']]*bsz).to(device) # [bsz, ]
126 | last_inp_h = q_hn
127 |
128 | # store predictions at each step
129 | programs = [latest_func]
130 | inputs = [torch.zeros((bsz, self.max_inp)).long().to(device)]
131 |
132 | for i in range(self.max_dec_len):
133 | p_word_emb = self.word_dropout(self.func_embeddings(latest_func)).unsqueeze(1) # [bsz, 1, dim_w]
134 | p_word_h, last_h = self.decoder.forward_one_step(p_word_emb, last_h) # [bsz, 1, dim_h]
135 | # attention over question words
136 | attn = torch.softmax(torch.bmm(p_word_h, q_word_h.permute(0, 2, 1)), dim=2) # [bsz, 1, max_q]
137 | attn_word_h = torch.bmm(attn, q_word_h) # [bsz, 1, dim_h]
138 | # sum up
139 | p_word_h = p_word_h + attn_word_h # [bsz, 1, dim_h]
140 |
141 | # predict function
142 | logit_func = self.func_classifier(p_word_h).squeeze(1) # [bsz, num_func]
143 | latest_func = torch.argmax(logit_func, dim=1) # [bsz, ]
144 | programs.append(latest_func)
145 |
146 | # predict input
147 | pred_inp = []
148 | for _ in range(self.max_inp):
149 | inp_emb = self.word_dropout(self.inp_embeddings(latest_inp)).unsqueeze(1) # [bsz, 1, dim_w]
150 | inp_h, last_inp_h = self.inp_decoder.forward_one_step(
151 | torch.cat((inp_emb, p_word_h), dim=2),
152 | last_inp_h) # [bsz, 1, dim_h]
153 | attn = torch.softmax(torch.bmm(inp_h, q_word_h.permute(0, 2, 1)), dim=2)
154 | attn_word_h = torch.bmm(attn, q_word_h)
155 | inp_h = inp_h + attn_word_h # [bsz, 1, dim_h]
156 |
157 | logit_inp = self.inp_classifier(inp_h).squeeze(1) # [bsz, num_word]
158 | latest_inp = torch.argmax(logit_inp, dim=1) # [bsz, ]
159 | pred_inp.append(latest_inp)
160 | pred_inp = torch.stack(pred_inp, dim=1) # [bsz, 3]
161 | inputs.append(pred_inp)
162 |
163 | finished = finished | latest_func.eq(end_id).byte()
164 | if finished.sum().item() == bsz:
165 | # print('finished at step {}'.format(i))
166 | break
167 |
168 | programs = torch.stack(programs, dim=1) # [bsz, max_prog]
169 | inputs = torch.stack(inputs, dim=1) # [bsz, max_prog, 3]
170 | return programs, inputs
171 |
172 |
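Both losses in `train_phase` use the standard seq2seq shift: the hidden state at position t predicts the token at position t+1, so the logits drop their last step and the targets drop the leading `<START>`. A toy version of that alignment:

```
import torch
import torch.nn as nn

bsz, max_prog, num_func = 2, 4, 10
logit_func = torch.randn(bsz, max_prog, num_func)
programs = torch.randint(0, num_func, (bsz, max_prog))  # position 0 is <START>

criterion = nn.CrossEntropyLoss()
# CrossEntropyLoss expects [bsz, num_class, seq_len], hence the permute
loss = criterion(logit_func.permute(0, 2, 1)[:, :, :-1], programs[:, 1:])
```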
--------------------------------------------------------------------------------
/Program/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import shutil
5 | from tqdm import tqdm
6 | import numpy as np
7 |
8 | from .data import DataLoader
9 | from .parser import Parser
10 | from .executor_rule import RuleExecutor
11 |
12 | def main():
13 | parser = argparse.ArgumentParser()
14 | # input and output
15 | parser.add_argument('--input_dir', required=True)
16 | parser.add_argument('--save_dir', required=True, help='path of checkpoint')
17 | # model hyperparameters
18 | parser.add_argument('--dim_word', default=300, type=int)
19 | parser.add_argument('--dim_hidden', default=1024, type=int)
20 | args = parser.parse_args()
21 |
22 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
23 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
24 | test_pt = os.path.join(args.input_dir, 'test.pt')
25 | test_loader = DataLoader(vocab_json, test_pt, 128)
26 | vocab = test_loader.vocab
27 |
28 | rule_executor = RuleExecutor(vocab, os.path.join(args.input_dir, 'kb.json'))
29 | model = Parser(vocab, args.dim_word, args.dim_hidden)
30 |
31 | print("load ckpt from {}".format(args.save_dir))
32 | model.load_state_dict(
33 |         torch.load(os.path.join(args.save_dir, 'model.pt'), map_location='cpu')) # load to CPU first, then move to device
34 | model = model.to(device)
35 | model.eval()
36 |
37 | with open(os.path.join(args.save_dir, 'predict.txt'), 'w') as f:
38 | with torch.no_grad():
39 | for batch in tqdm(test_loader, total=len(test_loader)):
40 | question, choices = [x.to(device) for x in batch[:2]]
41 | pred_program, pred_inputs = model(question)
42 |
43 | pred_program, pred_inputs = [x.cpu().numpy() for x in (pred_program, pred_inputs)]
44 | for i in range(len(pred_program)):
45 | pred = rule_executor.forward(pred_program[i], pred_inputs[i], ignore_error=True)
46 | f.write(str(pred) + '\n')
47 | print("save predictions into {}".format(os.path.join(args.save_dir, 'predict.txt')))
48 |
49 | if __name__ == '__main__':
50 | main()
51 |
--------------------------------------------------------------------------------
/Program/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import argparse
5 | import numpy as np
6 | from nltk import word_tokenize
7 | from collections import Counter, defaultdict
8 | from itertools import chain
9 | from tqdm import tqdm
10 |
11 | from utils.misc import init_vocab
12 |
13 | max_dep = 2
14 | max_inp = 3
15 |
16 |
17 | def encode_dataset(dataset, vocab, test=False):
18 | questions = []
19 | functions = []
20 | func_depends = []
21 | func_inputs = []
22 | choices = []
23 | answers = []
24 | for question in tqdm(dataset):
25 |         q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>'])
26 | for w in word_tokenize(question['question'].lower())]
27 | questions.append(q)
28 |
29 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']]
30 | choices.append(_)
31 |
32 | if test:
33 | continue
34 |
35 | func, dep, inp = [], [], []
36 |         # wrap program with <START> and <END> flags
37 |         program = [{'function':'<START>','dependencies':[-1,-1],'inputs':['<START>']}] + \
38 |             question['program'] + \
39 |             [{'function':'<END>','dependencies':[-1,-1],'inputs':['<END>']}]
40 | for f in program:
41 | func.append(vocab['function_token_to_idx'][f['function']])
42 | dep.append(f['dependencies'])
43 |             inp.append([vocab['word_token_to_idx'].get(i, vocab['word_token_to_idx']['<UNK>'])
44 | for i in f['inputs']])
45 |
46 | functions.append(func)
47 | func_depends.append(dep)
48 | func_inputs.append(inp)
49 |
50 | if 'answer' in question:
51 | answers.append(vocab['answer_token_to_idx'].get(question['answer']))
52 |
53 | # question padding
54 | max_len = max(len(q) for q in questions)
55 | for i in range(len(questions)):
56 | while len(questions[i]) < max_len:
57 |             questions[i].append(vocab['word_token_to_idx']['<PAD>'])
58 |
59 | if not test:
60 | # function padding
61 | max_len = max(len(f) for f in functions)
62 | for i in range(len(functions)):
63 | while len(functions[i]) < max_len:
64 |                 functions[i].append(vocab['function_token_to_idx']['<PAD>'])
65 | func_depends[i].append([-1, -1])
66 | func_inputs[i].append([])
67 | for j in range(max_len):
68 | while len(func_depends[i][j]) < max_dep:
69 | func_depends[i][j].append(-1) # use -1 to pad dependency
70 | while len(func_inputs[i][j]) < max_inp:
71 |                     func_inputs[i][j].append(vocab['word_token_to_idx']['<PAD>'])
72 |
73 | questions = np.asarray(questions, dtype=np.int32)
74 | functions = np.asarray(functions, dtype=np.int32)
75 | func_depends = np.asarray(func_depends, dtype=np.int32)
76 |     # Because we wrap a <START> before the program, dependencies should shift to the right
77 | # After that, all dependencies >= 0 and 0 means padding
78 | func_depends = func_depends + 1
79 |
80 | func_inputs = np.asarray(func_inputs, dtype=np.int32)
81 | choices = np.asarray(choices, dtype=np.int32)
82 | answers = np.asarray(answers, dtype=np.int32)
83 | return questions, functions, func_depends, func_inputs, choices, answers
84 |
85 |
86 |
87 | def main():
88 | parser = argparse.ArgumentParser()
89 | parser.add_argument('--input_dir', required=True)
90 | parser.add_argument('--output_dir', required=True)
91 | parser.add_argument('--min_cnt', type=int, default=1)
92 | args = parser.parse_args()
93 |
94 |
95 | vocab = {
96 | 'word_token_to_idx': init_vocab(),
97 | 'function_token_to_idx': init_vocab(),
98 | 'answer_token_to_idx': {}
99 | }
100 | print('Load questions')
101 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json')))
102 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json')))
103 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json')))
104 | print('Build question vocabulary')
105 | word_counter = Counter()
106 | for question in train_set:
107 | tokens = word_tokenize(question['question'].lower())
108 | word_counter.update(tokens)
109 | # add candidate answers
110 | for a in question['choices']:
111 | if a not in vocab['answer_token_to_idx']:
112 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
113 | # add functions
114 | for f in question['program']:
115 | a = f['function']
116 | if a not in vocab['function_token_to_idx']:
117 | vocab['function_token_to_idx'][a] = len(vocab['function_token_to_idx'])
118 | word_counter.update(f['inputs'])
119 | # filter low-frequency words
120 | for w, c in word_counter.items():
121 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']:
122 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx'])
123 | # add candidate answers of val and test set
124 | for question in chain(val_set, test_set):
125 | for a in question['choices']:
126 | if a not in vocab['answer_token_to_idx']:
127 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
128 |
129 |
130 | if not os.path.isdir(args.output_dir):
131 | os.mkdir(args.output_dir)
132 | fn = os.path.join(args.output_dir, 'vocab.json')
133 | print('Dump vocab to {}'.format(fn))
134 | with open(fn, 'w') as f:
135 | json.dump(vocab, f, indent=2)
136 | for k in vocab:
137 | print('{}:{}'.format(k, len(vocab[k])))
138 |
139 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)):
140 | print('Encode {} set'.format(name))
141 | outputs = encode_dataset(dataset, vocab, test=name=='test')
142 | assert len(outputs) == 6
143 | print('shape of questions, functions, func_depends, func_inputs, choices, answers:')
144 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f:
145 | for o in outputs:
146 | print(o.shape)
147 | pickle.dump(o, f)
148 |
149 |
150 |
151 |
152 | if __name__ == '__main__':
153 | main()
154 |
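The `func_depends + 1` shift makes every dependency non-negative, so that 0 can double as the padding value; a small check:

```
import numpy as np

func_depends = np.array([[-1, -1], [0, -1], [1, 0]])  # raw dependencies, -1 = none
func_depends = func_depends + 1  # shift right after prepending <START>
print(func_depends)
# [[0 0]
#  [1 0]
#  [2 1]]
```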
--------------------------------------------------------------------------------
/Program/readme.md:
--------------------------------------------------------------------------------
1 | ## Requirements
2 | - python3
3 | - pytorch>=1.2.0
4 | - nltk
5 |
6 | ## How to run
7 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading:
8 | ```
9 | python -m utils.pickle_glove --input <path/to/glove.840B.300d.txt> --output <path/to/glove.pt>
10 | ```
11 | This step can be skipped if you have already created the GloVe pickle file for another model.
12 |
13 | 2. Preprocess the training data, and copy `./dataset/kb.json` into `<output_dir>`
14 | ```
15 | python -m Program.preprocess --input_dir ./dataset --output_dir <output_dir>
16 | cp ./dataset/kb.json <output_dir>
17 | ```
18 | 3. Train
19 | ```
20 | python -m Program.train --input_dir <output_dir> --save_dir <save_dir> --glove_pt <glove_pt>
21 | ```
22 | 4. Predict answers for the test set. This produces a file named `predict.txt` in `--save_dir`, containing the predicted answers of the test questions in order.
23 | ```
24 | python -m Program.predict --input_dir <output_dir> --save_dir <save_dir>
25 | ```
26 |
--------------------------------------------------------------------------------
/Program/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import shutil
7 | from tqdm import tqdm
8 | import numpy as np
9 |
10 | from utils.misc import MetricLogger, load_glove
11 | from .data import DataLoader
12 | from .parser import Parser
13 | from .executor_rule import RuleExecutor
14 |
15 | import logging
16 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
17 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
18 | rootLogger = logging.getLogger()
19 |
20 |
21 | def validate_executor(executor, data):
22 | # validate whether the executor is correct
23 | correct = 0
24 | count = 0
25 | for batch in tqdm(data, total=len(data)):
26 | question, choices, gt_program, gt_dep, gt_inputs, answer = batch
27 | gt_program, gt_dep, gt_inputs = [x.cpu().numpy() for x in (gt_program, gt_dep, gt_inputs)]
28 | answer = [data.vocab['answer_idx_to_token'][a.item()] for a in answer]
29 | preds = []
30 | for i in range(len(gt_program)):
31 | pred = executor.forward(gt_program[i], gt_inputs[i], ignore_error=True)
32 | if pred == answer[i]:
33 | correct += 1
34 | else:
35 | print(pred, answer[i])
36 |             pred = executor.forward(gt_program[i], gt_inputs[i], ignore_error=True, show_details=True)
37 |             from IPython import embed; embed() # drop into an interactive shell to inspect the mismatch (requires IPython)
38 | count += 1
39 | if count >= 10000:
40 | break
41 | print('{}/{}/{:.4f}'.format(correct, count, correct/count))
42 |
43 |
44 | def validate(model, data, device, executor=None):
45 | model.eval()
46 |     end_id = data.vocab['function_token_to_idx']['<END>']
47 | match_prog_num = 0
48 | match_dep_num = 0
49 | match_inp_num = 0
50 | match_all_num = 0
51 | correct = 0
52 | count = 0
53 | with torch.no_grad():
54 | for batch in tqdm(data, total=len(data)):
55 | question, choices, gt_program, gt_dep, gt_inputs, answer = [x.to(device) for x in batch]
56 | pred_program, pred_inputs = model(question)
57 |
58 | gt_program, gt_inputs = [x.cpu().numpy() for x in (gt_program, gt_inputs)]
59 | pred_program, pred_inputs = [x.cpu().numpy() for x in (pred_program, pred_inputs)]
60 |
61 | for i in range(len(gt_program)):
62 |                 l = min(len(gt_program[i]), len(pred_program[i])) # fallback length in case <END> is never jointly reached
63 | # print(gt_program[i])
64 | # print(gt_inputs[i])
65 | # print('---')
66 | # print(pred_program[i])
67 | # print(pred_inputs[i])
68 | # print('==========')
69 |
70 | match = True
71 | for j in range(min(len(gt_program[i]), len(pred_program[i]))):
72 | if gt_program[i, j] != pred_program[i, j]:
73 | match = False
74 | break
75 | if gt_program[i, j] == end_id and pred_program[i, j] == end_id:
76 | l = j
77 | break
78 | if match:
79 | match_prog_num += 1
80 | if np.all(gt_inputs[i,1:l,:]==pred_inputs[i,1:l,:]):
81 | match_inp_num += 1
82 |
83 | count += len(gt_program)
84 |
85 | if executor:
86 | answer = [data.vocab['answer_idx_to_token'][a.item()] for a in answer]
87 | for i in range(len(gt_program)):
88 | pred = executor.forward(pred_program[i], pred_inputs[i], ignore_error=True)
89 | if pred == answer[i]:
90 | correct += 1
91 |
92 | logging.info('\nValid match program: {:.4f}, inputs: {:.4f}\n'.format(
93 | match_prog_num / count,
94 | match_inp_num / count,
95 | ))
96 | if executor:
97 | logging.info('Accuracy: {:.4f}\n'.format(correct / count))
98 | return correct / count
99 | else:
100 | return None
101 |
102 |
103 | def train(args):
104 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
105 |
106 | logging.info("Create train_loader and val_loader.........")
107 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
108 | train_pt = os.path.join(args.input_dir, 'train.pt')
109 | val_pt = os.path.join(args.input_dir, 'val.pt')
110 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True)
111 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
112 | vocab = train_loader.vocab
113 |
114 | rule_executor = RuleExecutor(vocab, os.path.join(args.input_dir, 'kb.json'))
115 |
116 | logging.info("Create model.........")
117 | model = Parser(vocab, args.dim_word, args.dim_hidden)
118 | logging.info("Load pretrained word vectors.........")
119 | pretrained = load_glove(args.glove_pt, vocab['word_idx_to_token'])
120 | with torch.no_grad():
121 | model.word_embeddings.weight.set_(torch.Tensor(pretrained))
122 | model = model.to(device)
123 | logging.info(model)
124 | if args.ckpt and os.path.exists(args.ckpt):
125 | logging.info("load ckpt from {}".format(args.ckpt))
126 |         model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
127 |
128 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
129 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 50], gamma=0.1)
130 |
131 | # validate_executor(rule_executor, train_loader) # accuracy of val_loader is about 80% due to OOV issue
132 | validate(model, val_loader, device)
133 |
134 | meters = MetricLogger(delimiter=" ")
135 | best_acc = 0
136 | logging.info("Start training........")
137 | for epoch in range(args.num_epoch):
138 | model.train()
139 | for iteration, batch in enumerate(train_loader):
140 | iteration = iteration + 1
141 |
142 | question, choices, program, prog_depends, prog_inputs, answer = [x.to(device) for x in batch]
143 | loss = model(question, program, prog_inputs)
144 | optimizer.zero_grad()
145 | loss.backward()
146 | optimizer.step()
147 | meters.update(loss=loss.item())
148 |
149 |             if iteration % max(1, len(train_loader) // 100) == 0: # max() guards small loaders against modulo by zero
150 | logging.info(
151 | meters.delimiter.join(
152 | [
153 | "progress: {progress:.3f}",
154 | "{meters}",
155 | "lr: {lr:.6f}",
156 | ]
157 | ).format(
158 | progress=epoch + iteration / len(train_loader),
159 | meters=str(meters),
160 | lr=optimizer.param_groups[0]["lr"],
161 | )
162 | )
163 |
164 | scheduler.step()
165 | if epoch == args.num_epoch-1 or (epoch+1)%5 == 0:
166 | acc = validate(model, val_loader, device, rule_executor)
167 | else:
168 | acc = validate(model, val_loader, device)
169 | if acc and acc > best_acc:
170 | best_acc = acc
171 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc))
172 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt'))
173 |
174 |
175 | def main():
176 | parser = argparse.ArgumentParser()
177 | # input and output
178 | parser.add_argument('--input_dir', required=True)
179 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
180 | parser.add_argument('--glove_pt', required=True)
181 | parser.add_argument('--ckpt')
182 |
183 | # training parameters
184 | parser.add_argument('--lr', default=0.001, type=float)
185 | parser.add_argument('--weight_decay', default=1e-5, type=float)
186 | parser.add_argument('--num_epoch', default=100, type=int)
187 | parser.add_argument('--batch_size', default=64, type=int)
188 | parser.add_argument('--seed', type=int, default=666, help='random seed')
189 | # model hyperparameters
190 | parser.add_argument('--dim_word', default=300, type=int)
191 | parser.add_argument('--dim_hidden', default=1024, type=int)
192 | args = parser.parse_args()
193 |
194 | # make logging.info display into both shell and file
195 | if os.path.isdir(args.save_dir):
196 | shutil.rmtree(args.save_dir)
197 | os.mkdir(args.save_dir)
198 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt'))
199 | fileHandler.setFormatter(logFormatter)
200 | rootLogger.addHandler(fileHandler)
201 | # args display
202 | for k, v in vars(args).items():
203 | logging.info(k+':'+str(v))
204 |
205 | # set random seed
206 | torch.manual_seed(args.seed)
207 |
208 | train(args)
209 |
210 |
211 | if __name__ == '__main__':
212 | main()
213 |
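214 | # Note on the schedule above: executing programs with RuleExecutor over the
215 | # whole val set is slow, so `validate` runs with the executor (and hence
216 | # reports answer accuracy) only every 5 epochs and at the final epoch; in the
217 | # remaining epochs it returns None and only exact-match statistics are logged,
218 | # so the best checkpoint is updated solely on executor epochs.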
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KQA Pro Baselines
2 | [KQA Pro](https://arxiv.org/abs/2007.03875) is a large-scale dataset for complex question answering over a knowledge base, which provides strong supervision in the form of a SPARQL query and a program for each question.
3 | [Here is its homepage](http://thukeg.gitee.io/kqa-pro/). The dataset is licensed under a Creative Commons Attribution-ShareAlike 4.0 International License.
4 |
5 | This repo implements several baselines for the dataset:
6 |
7 | - Blind GRU. It predicts the answer from the input question alone, ignoring the knowledge base. We use it to measure dataset bias.
8 | - [KVMemNN](https://www.aclweb.org/anthology/D16-1147/) (Key-Value Memory Networks)
9 | - [RGCN](https://arxiv.org/abs/1703.06103) (Relational Graph Convolutional Networks)
10 | - [SRN](https://dl.acm.org/doi/10.1145/3336191.3371812) (Stepwise Reasoning Network)
11 | - RNN seq2seq SPARQL parser
12 | - RNN seq2seq program parser
13 | - [BART](https://arxiv.org/abs/1910.13461) seq2seq SPARQL parser
14 | - [BART](https://arxiv.org/abs/1910.13461) seq2seq program parser
15 |
16 | Instructions for running these models are given in their README files.
17 | Before trying them, first download the [dataset](https://cloud.tsinghua.edu.cn/f/04ce81541e704a648b03/?dl=1) and unzip it into the folder `./dataset`.
18 | The file tree should look like
19 | ```
20 | .
21 | +-- dataset
22 | | +-- kb.json
23 | | +-- train.json
24 | | +-- val.json
25 | | +-- test.json
26 | +-- GRU
27 | | +-- preprocess.py
28 | | +-- train.py
29 | | +-- ...
30 | +-- KVMemNN
31 | +-- RGCN
32 | ...
33 | ```
34 |
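35 | After unzipping, you can sanity-check the download with a few lines of Python (a minimal sketch; the field names follow the preprocessing scripts in this repo):
36 | ```
37 | import json
38 |
39 | train = json.load(open('./dataset/train.json'))
40 | print(len(train))  # number of training questions
41 | q = train[0]
42 | print(q['question'], q['choices'], q['answer'])
43 | print(q['sparql'])  # gold SPARQL annotation of this question
44 | ```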
--------------------------------------------------------------------------------
/RGCN/README.md:
--------------------------------------------------------------------------------
1 | ## Requirements
2 | - python3
3 | - pytorch>=1.2.0
4 | - nltk
5 | - [dgl>=0.4.3](https://github.com/dmlc/dgl/)
6 |
7 | ## How to run
8 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading:
9 | ```
10 | python -m utils.pickle_glove --input <path/to/glove.840B.300d.txt> --output <path/to/glove.pt>
11 | ```
12 | This step can be skipped if you have already created the GloVe pickle file for another model.
13 | 2. Preprocess the training data, and copy `./dataset/kb.json` into `<output_dir>`
14 | ```
15 | python -m RGCN.preprocess --input_dir ./dataset --output_dir <output_dir>
16 | ```
17 | 3. Train
18 | ```
19 | python -m RGCN.train --input_dir <output_dir> --save_dir <path/to/checkpoints> --glove_pt <path/to/glove.pt>
20 | ```
21 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order.
22 | ```
23 | python -m RGCN.predict --input_dir <output_dir> --save_dir <path/to/checkpoints>
24 | ```
25 |
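26 | After step 2, `<output_dir>` should contain `vocab.json`, `kb.pt`, `train.pt`, `val.pt` and `test.pt`, plus the copied `kb.json`. A minimal existence check (a sketch; `<output_dir>` is a placeholder for your actual path):
27 | ```
28 | import os
29 | expected = ['vocab.json', 'kb.json', 'kb.pt', 'train.pt', 'val.pt', 'test.pt']
30 | print([f for f in expected if not os.path.exists(os.path.join('<output_dir>', f))])  # [] if all present
31 | ```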
--------------------------------------------------------------------------------
/RGCN/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import torch
4 | from utils.misc import invert_dict
5 |
6 |
7 | def load_vocab(path):
8 | vocab = json.load(open(path))
9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx'])
10 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])
11 | vocab['kb_idx_to_token'] = invert_dict(vocab['kb_token_to_idx'])
12 | vocab['predicate_idx_to_token'] = invert_dict(vocab['predicate_token_to_idx'])
13 | return vocab
14 |
15 | def collate(batch):
16 | batch = list(zip(*batch))
17 | question, choices = list(map(torch.stack, batch[:2]))
18 | if batch[-1][0] is None:
19 | answer = None
20 | else:
21 | answer = torch.cat(batch[-1])
22 | return question, choices, answer
23 |
24 |
25 | class Dataset(torch.utils.data.Dataset):
26 | def __init__(self, inputs):
27 | self.questions, self.choices, self.answers = inputs
28 | self.is_test = len(self.answers)==0
29 |
30 |
31 | def __getitem__(self, index):
32 | question = torch.LongTensor(self.questions[index])
33 | choices = torch.LongTensor(self.choices[index])
34 | if self.is_test:
35 | answer = None
36 | else:
37 | answer = torch.LongTensor([self.answers[index]])
38 | return question, choices, answer
39 |
40 |
41 | def __len__(self):
42 | return len(self.questions)
43 |
44 |
45 | class DataLoader(torch.utils.data.DataLoader):
46 | def __init__(self, vocab_json, kb_pt, question_pt, batch_size, training=False):
47 | vocab = load_vocab(vocab_json)
48 |
49 | inputs = []
50 | with open(question_pt, 'rb') as f:
51 | for _ in range(3):
52 | inputs.append(pickle.load(f))
53 | with open(kb_pt, 'rb') as f:
54 | self.node_descs = torch.LongTensor(pickle.load(f))
55 | self.triples = torch.LongTensor(pickle.load(f))
56 |
57 | dataset = Dataset(inputs)
58 |
59 | super().__init__(
60 | dataset,
61 | batch_size=batch_size,
62 | shuffle=training,
63 | collate_fn=collate,
64 | )
65 | self.vocab = vocab
66 |
67 |
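68 | # Usage sketch (paths are hypothetical; see RGCN/train.py for the real calls):
69 | #   loader = DataLoader('vocab.json', 'kb.pt', 'train.pt', batch_size=64, training=True)
70 | #   for question, choices, answer in loader:
71 | #       ...  # LongTensors; answer is None when loading the test set
72 | # Besides batches, the loader exposes loader.vocab, loader.node_descs and
73 | # loader.triples, which hold the whole kb graph for the model.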
--------------------------------------------------------------------------------
/RGCN/model.py:
--------------------------------------------------------------------------------
1 | """
2 | Refer to https://github.com/dmlc/dgl/tree/master/examples/pytorch/rgcn
3 | """
4 | import math
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from dgl import DGLGraph
9 | from dgl.nn.pytorch import RelGraphConv
10 |
11 | from utils.BiGRU import BiGRU
12 |
13 | class RGCN(nn.Module):
14 | def __init__(self, in_dim, h_dim, out_dim, num_rels, num_bases,
15 | num_hidden_layers=2, dropout=0,
16 | use_self_loop=False, use_cuda=True):
17 | super().__init__()
18 | self.in_dim = in_dim
19 | self.h_dim = h_dim
20 | self.out_dim = out_dim
21 | self.num_rels = num_rels
22 | self.num_bases = None if num_bases < 0 else num_bases
23 | self.num_hidden_layers = num_hidden_layers
24 | self.dropout = dropout
25 | self.use_self_loop = use_self_loop
26 | self.use_cuda = use_cuda
27 |
28 | # create rgcn layers
29 | self.build_model()
30 |
31 | def build_model(self):
32 | self.layers = nn.ModuleList()
33 | # i2h
34 | i2h = self.build_input_layer()
35 | if i2h is not None:
36 | self.layers.append(i2h)
37 | # h2h
38 | for idx in range(self.num_hidden_layers):
39 | h2h = self.build_hidden_layer(idx)
40 | self.layers.append(h2h)
41 | # h2o
42 | h2o = self.build_output_layer()
43 | if h2o is not None:
44 | self.layers.append(h2o)
45 |
46 | def build_input_layer(self):
47 | return None
48 |
49 | def build_hidden_layer(self, idx):
50 | return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "basis",
51 | self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
52 | dropout=self.dropout)
53 |
54 | def build_output_layer(self):
55 | return None
56 | # return RelGraphConv(self.h_dim, self.out_dim, self.num_rels, "basis",
57 | # self.num_bases, activation=None,
58 | # self_loop=self.use_self_loop)
59 |
60 | def forward(self, g, h, r, norm=None):
61 | for layer in self.layers:
62 | h = layer(g, h, r, norm)
63 | return h
64 |
65 |
66 | class QuesAnsByRGCN(nn.Module):
67 | def __init__(self, vocab, node_descs, edge_triples,
68 | dim_word, dim_hidden, dim_g, num_bases=1, num_hidden_layers=1):
69 | """
70 | Args:
71 | - edge_triples (np.array) [#triple, 3]
72 | """
73 | super().__init__()
74 | num_rels = len(vocab['predicate_token_to_idx'])
75 | num_desc_word = len(vocab['kb_token_to_idx'])
76 | num_question_word = len(vocab['word_token_to_idx'])
77 | num_class = len(vocab['answer_token_to_idx'])
78 |
79 | self.rgcn = RGCN(dim_g, dim_g, dim_g, num_rels, num_bases, num_hidden_layers)
80 | edge_src = edge_triples[:,0]
81 | edge_type = edge_triples[:,1]
82 | edge_dst = edge_triples[:,2]
83 | self.edge_type = edge_type
84 | self.num_nodes = len(node_descs)
85 | self.node_descs = node_descs # [#node, max_desc]
86 | self.dim_g = dim_g
87 |
88 | self.desc_embeddings = nn.Embedding(num_desc_word, dim_g)
89 | nn.init.normal_(self.desc_embeddings.weight, mean=0, std=1/math.sqrt(dim_g))
90 |
91 | self.input_embeddings = nn.Embedding(num_question_word, dim_word)
92 | nn.init.normal_(self.input_embeddings.weight, mean=0, std=1/math.sqrt(dim_word))
93 |
94 | self.word_dropout = nn.Dropout(0.3)
95 | self.question_encoder = BiGRU(dim_word, dim_hidden, num_layers=1, dropout=0.0)
96 |
97 | # create graph
98 | self.g = DGLGraph()
99 | self.g.add_nodes(self.num_nodes)
100 | self.g.add_edges(edge_src, edge_dst)
101 |
102 | self.lin_h_to_g = nn.Linear(dim_hidden, dim_g)
103 | self.classifier = nn.Sequential(
104 | nn.Linear(dim_g + dim_hidden, 1024),
105 | nn.ReLU(),
106 | nn.Linear(1024, num_class)
107 | )
108 |
109 | for m in self.modules():
110 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
111 | nn.init.kaiming_normal_(m.weight)
112 | if m.bias is not None:
113 | m.bias.data.zero_()
114 |
115 |
116 | def forward(self, questions, only_q=False):
117 |         question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD>
118 | question_input = self.word_dropout(self.input_embeddings(questions))
119 | _, question_embeddings, _ = self.question_encoder(question_input, question_lens)
120 | # [bsz, dim_h]
121 |
122 | if only_q:
123 | bsz = question_embeddings.size(0)
124 | device = question_embeddings.device
125 | empty = torch.zeros((bsz, self.dim_g)).to(device)
126 | feat = torch.cat((empty, question_embeddings), dim=1)
127 | logits = self.classifier(feat)
128 | return logits
129 |
130 |
131 | agg_feats = []
132 | bsz = len(questions)
133 | for i in range(bsz):
134 | # construct initial node features
135 | q = question_embeddings[i].view(1, 1, -1) # [1, 1, dim_h]
136 | node_desc_emb = self.word_dropout(self.desc_embeddings(self.node_descs))
137 | # [#node, max_desc, dim_g]
138 | q_g = self.lin_h_to_g(q) # [1, 1, dim_g]
139 | attn = torch.softmax(torch.sum(node_desc_emb * q_g, dim=2), dim=1) # [#node, max_desc]
140 | node_feat = torch.sum(attn.unsqueeze(2) * node_desc_emb, dim=1) # [#node, dim_g]
141 |
142 | # rgcn
143 | node_feat = self.rgcn(self.g, node_feat, self.edge_type) # [#node, dim_g]
144 |
145 | # answer feature
146 | q_g = q_g.view(1, -1) # [1, dim_g]
147 | attn = torch.softmax(torch.sum(node_feat * q_g, dim=1, keepdim=True), dim=0) # [#node, 1]
148 | node_agg = torch.sum(node_feat * attn, dim=0) # [dim_g]
149 | node_agg = torch.cat((node_agg, q.view(-1)), dim=0) # [dim_g+dim_h]
150 | agg_feats.append(node_agg)
151 |
152 |         agg_feats = torch.stack(agg_feats) # [bsz, dim_g+dim_h]
153 | logits = self.classifier(agg_feats)
154 | return logits
155 |
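156 | # Usage sketch (sizes follow the defaults in RGCN/train.py):
157 | #   model = QuesAnsByRGCN(vocab, node_descs, triples,
158 | #                         dim_word=300, dim_hidden=512, dim_g=32)
159 | #   logits = model(questions)               # [bsz, num_answers]
160 | #   logits = model(questions, only_q=True)  # question-only, skips the RGCN pass
161 | # train.py passes only_q=True for the first two epochs to warm up the question
162 | # encoder and classifier before the expensive per-sample graph propagation.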
--------------------------------------------------------------------------------
/RGCN/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import json
5 | from tqdm import tqdm
6 |
7 | from .data import DataLoader
8 | from .model import QuesAnsByRGCN
9 |
10 |
11 | def test(args):
12 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
13 |
14 | print('load test data')
15 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
16 | test_pt = os.path.join(args.input_dir, 'test.pt')
17 | kb_pt = os.path.join(args.input_dir, 'kb.pt')
18 | data = DataLoader(vocab_json, kb_pt, test_pt, 4)
19 | vocab = data.vocab
20 |
21 | print('load model')
22 | node_descs = data.node_descs.to(device)
23 | node_descs = node_descs[:, :args.max_desc]
24 | triples = data.triples.to(device)
25 | triples = triples[:args.max_triple]
26 | model = QuesAnsByRGCN(vocab,
27 | node_descs, triples,
28 | args.dim_word, args.dim_hidden, args.dim_g)
29 | model = model.to(device)
30 |     model.load_state_dict(torch.load(os.path.join(args.save_dir, 'model.pt'), map_location=device))
31 |     model.eval()
32 |
33 | fn_open = open(os.path.join(args.save_dir, 'predict.txt'), 'w')
34 | fn_choice = open(os.path.join(args.save_dir, 'choice_predict.txt'), 'w')
35 | for batch in tqdm(data, total=len(data)):
36 | question, choices, answer = batch
37 | question = question.to(device)
38 | logit = model(question)
39 | logit = logit.detach().cpu()
40 |
41 | for l, c in zip(logit, choices):
42 | a = l.max(0)[1].item()
43 | a = vocab['answer_idx_to_token'][a]
44 | fn_open.write(a + '\n')
45 | # mask for multi-choice
46 | l = torch.softmax(l, 0)
47 | mask = torch.ones((len(l),)).bool()
48 | mask[c] = 0
49 | l[mask] = 0
50 | a = l.max(0)[1].item()
51 | a = vocab['answer_idx_to_token'][a]
52 | fn_choice.write(a + '\n')
53 | fn_open.close()
54 | fn_choice.close()
55 |
56 |
57 | def main():
58 | parser = argparse.ArgumentParser()
59 | # input and output
60 | parser.add_argument('--input_dir', required=True)
61 | parser.add_argument('--save_dir', required=True, help='path to store predictions')
62 |
63 | # model hyperparameters
64 | parser.add_argument('--dim_word', default=300, type=int)
65 | parser.add_argument('--dim_hidden', default=512, type=int)
66 | parser.add_argument('--dim_g', default=32, type=int)
67 | parser.add_argument('--max_desc', default=20, type=int)
68 | parser.add_argument('--max_triple', default=200000, type=int)
69 | args = parser.parse_args()
70 |
71 | test(args)
72 |
73 |
74 | if __name__ == '__main__':
75 | main()
76 |
--------------------------------------------------------------------------------
/RGCN/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pickle
4 | import numpy as np
5 | from nltk import word_tokenize
6 | from collections import Counter
7 | from itertools import chain
8 | from tqdm import tqdm
9 | import argparse
10 |
11 | from utils.load_kb import load_as_graph
12 | from utils.misc import init_vocab
13 |
14 |
15 | def encode_dataset(dataset, vocab):
16 | questions = []
17 | choices = []
18 | answers = []
19 | for question in tqdm(dataset):
20 |         q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>'])
21 | for w in word_tokenize(question['question'].lower())]
22 | questions.append(q)
23 |
24 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']]
25 | choices.append(_)
26 |
27 | if 'answer' in question:
28 | answers.append(vocab['answer_token_to_idx'].get(question['answer']))
29 |
30 | # question padding
31 | max_len = max(len(q) for q in questions)
32 | for q in questions:
33 | while len(q) < max_len:
34 |             q.append(vocab['word_token_to_idx']['<PAD>'])
35 |
36 | questions = np.asarray(questions, dtype=np.int32)
37 | choices = np.asarray(choices, dtype=np.int32)
38 | answers = np.asarray(answers, dtype=np.int32)
39 | return questions, choices, answers
40 |
41 |
42 |
43 |
44 | def main():
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument('--input_dir', required=True)
47 | parser.add_argument('--output_dir', required=True)
48 | parser.add_argument('--min_cnt', type=int, default=1)
49 | parser.add_argument('--max_desc', type=int, default=200)
50 | args = parser.parse_args()
51 |
52 |
53 | print('Load and encode kb...')
54 | kb_vocab, node_descs, triples, nodeid2idx, pred2idx = \
55 | load_as_graph(os.path.join(args.input_dir, 'kb.json'), args.max_desc)
56 | node_descs = np.asarray(node_descs)
57 | triples = np.asarray(triples)
58 | print("shape of node_descs and triples:", node_descs.shape, triples.shape)
59 | print(node_descs[-10:])
60 | print(triples[:10])
61 |
62 | with open(os.path.join(args.output_dir, 'kb.pt'), 'wb') as f:
63 | pickle.dump(node_descs, f)
64 | pickle.dump(triples, f)
65 |
66 |
67 | vocab = {
68 | 'kb_token_to_idx': kb_vocab,
69 | 'predicate_token_to_idx': pred2idx,
70 | 'word_token_to_idx': init_vocab(),
71 | 'answer_token_to_idx': {}
72 | }
73 |
74 | print('Load questions')
75 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json')))
76 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json')))
77 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json')))
78 | print('Build question vocabulary')
79 | word_counter = Counter()
80 | for question in train_set:
81 | tokens = word_tokenize(question['question'].lower())
82 | word_counter.update(tokens)
83 | # add candidate answers
84 | for a in question['choices']:
85 | if a not in vocab['answer_token_to_idx']:
86 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
87 | for w, c in word_counter.items():
88 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']:
89 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx'])
90 | # add candidate answers of val and test set
91 | for question in chain(val_set, test_set):
92 | for a in question['choices']:
93 | if a not in vocab['answer_token_to_idx']:
94 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
95 |
96 | if not os.path.isdir(args.output_dir):
97 | os.mkdir(args.output_dir)
98 | fn = os.path.join(args.output_dir, 'vocab.json')
99 | print('Dump vocab to {}'.format(fn))
100 | with open(fn, 'w') as f:
101 | json.dump(vocab, f, indent=2)
102 | for k in vocab:
103 | print('{}:{}'.format(k, len(vocab[k])))
104 |
105 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)):
106 | print('Encode {} set'.format(name))
107 | outputs = encode_dataset(dataset, vocab)
108 | print('shape of questions, choices, answers:')
109 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f:
110 | for o in outputs:
111 | print(o.shape)
112 | pickle.dump(o, f)
113 |
114 |
115 |
116 |
117 | if __name__ == '__main__':
118 | main()
119 |
120 |
--------------------------------------------------------------------------------
/RGCN/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.optim as optim
4 | import numpy as np
5 | import random
6 | from tqdm import tqdm
7 | import os
8 | import pickle
9 | import argparse
10 | import shutil
11 |
12 | from utils.misc import MetricLogger, load_glove
13 | from .data import DataLoader
14 | from .model import QuesAnsByRGCN
15 |
16 | import logging
17 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
18 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
19 | rootLogger = logging.getLogger()
20 |
21 |
22 | def validate(model, data, device):
23 | model.eval()
24 | count, correct = 0, 0
25 | with torch.no_grad():
26 | for batch in tqdm(data, total=len(data)):
27 | question, choices, answer = [x.to(device) for x in batch]
28 | logit = model(question)
29 | predict = logit.max(1)[1]
30 | correct += torch.eq(predict, answer).long().sum().item()
31 | count += len(answer)
32 |
33 | acc = correct / count
34 | logging.info('\nValid Accuracy: %.4f\n' % acc)
35 | return acc
36 |
37 |
38 | def train(args):
39 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
40 |
41 | logging.info("Create train_loader and val_loader.........")
42 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
43 | train_pt = os.path.join(args.input_dir, 'train.pt')
44 | val_pt = os.path.join(args.input_dir, 'val.pt')
45 | kb_pt = os.path.join(args.input_dir, 'kb.pt')
46 | train_loader = DataLoader(vocab_json, kb_pt, train_pt, args.batch_size, training=True)
47 | train_loader_large_bsz = DataLoader(vocab_json, kb_pt, train_pt, 128, training=True)
48 | val_loader = DataLoader(vocab_json, kb_pt, val_pt, 128)
49 | vocab = train_loader.vocab
50 |
51 | logging.info("Create model.........")
52 | node_descs = train_loader.node_descs.to(device)
53 | node_descs = node_descs[:, :args.max_desc]
54 | triples = train_loader.triples.to(device)
55 | triples = triples[:args.max_triple]
56 | model = QuesAnsByRGCN(vocab,
57 | node_descs, triples,
58 | args.dim_word, args.dim_hidden, args.dim_g)
59 | logging.info("Load pretrained word vectors.........")
60 | pretrained = load_glove(args.glove_pt, vocab['word_idx_to_token'])
61 | with torch.no_grad():
62 | model.input_embeddings.weight.set_(torch.Tensor(pretrained))
63 | model = model.to(device)
64 | logging.info(model)
65 |
66 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
67 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 20], gamma=0.1)
68 | criterion = nn.CrossEntropyLoss().to(device)
69 |
70 | # validate(model, val_loader, device)
71 | meters = MetricLogger(delimiter=" ")
72 | best_acc = 0
73 | logging.info("Start training........")
74 | for epoch in range(args.num_epoch):
75 | model.train()
76 | if epoch < 2:
77 | _train_loader = train_loader_large_bsz
78 | only_q = True
79 | else:
80 | _train_loader = train_loader
81 | only_q = False
82 | for iteration, batch in enumerate(_train_loader):
83 | iteration = iteration + 1
84 |
85 | question, choices, answer = [x.to(device) for x in batch]
86 | logits = model(question, only_q)
87 | loss = criterion(logits, answer)
88 | optimizer.zero_grad()
89 | loss.backward()
90 | optimizer.step()
91 | meters.update(loss=loss.item())
92 |
93 |             if iteration % max(1, len(_train_loader) // 1000) == 0:
94 | logging.info(
95 | meters.delimiter.join(
96 | [
97 | "progress: {progress:.3f}",
98 | "{meters}",
99 | "lr: {lr:.6f}",
100 | ]
101 | ).format(
102 | progress=epoch + iteration / len(_train_loader),
103 | meters=str(meters),
104 | lr=optimizer.param_groups[0]["lr"],
105 | )
106 | )
107 |
108 | if epoch == args.num_epoch-1 or (epoch+1)%2 == 0:
109 | acc = validate(model, val_loader, device)
110 | else:
111 | acc = None
112 | scheduler.step()
113 | if acc and acc > best_acc:
114 | best_acc = acc
115 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc))
116 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt'))
117 |
118 |
119 |
120 | def main():
121 | parser = argparse.ArgumentParser()
122 | # input and output
123 | parser.add_argument('--input_dir', required=True)
124 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
125 | parser.add_argument('--glove_pt', required=True)
126 |
127 | # training parameters
128 | parser.add_argument('--lr', default=0.001, type=float)
129 | parser.add_argument('--weight_decay', default=1e-5, type=float)
130 | parser.add_argument('--num_epoch', default=40, type=int)
131 | parser.add_argument('--batch_size', default=6, type=int)
132 | parser.add_argument('--seed', type=int, default=666, help='random seed')
133 | # model hyperparameters
134 | parser.add_argument('--dim_word', default=300, type=int)
135 | parser.add_argument('--dim_hidden', default=512, type=int)
136 | parser.add_argument('--dim_g', default=32, type=int)
137 | parser.add_argument('--max_desc', default=20, type=int)
138 | parser.add_argument('--max_triple', default=200000, type=int)
139 | args = parser.parse_args()
140 |
141 | # make logging.info display into both shell and file
142 | if not os.path.exists(args.save_dir):
143 | os.makedirs(args.save_dir)
144 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt'))
145 | fileHandler.setFormatter(logFormatter)
146 | rootLogger.addHandler(fileHandler)
147 | # args display
148 | for k, v in vars(args).items():
149 | logging.info(k + ':' + str(v))
150 |
151 | # set random seed
152 | torch.manual_seed(args.seed)
153 |
154 | train(args)
155 |
156 | if __name__ == '__main__':
157 | main()
158 |
159 |
--------------------------------------------------------------------------------
/SPARQL/README.md:
--------------------------------------------------------------------------------
1 | ## Requirements
2 | - python3
3 | - rdflib=4.2.2
4 | ---
5 | **Note:**
6 | After installing rdflib via `pip`, `anaconda`, or some other tool, we need to patch two of its bugs.
7 |
8 | First, find your rdflib location. One way is to run the following code in ipython:
9 | ```
10 | import rdflib
11 | rdflib.__file__
12 | ```
13 | On my computer it returns `~/anaconda3/lib/python3.7/site-packages/rdflib/__init__.py`, so I enter the folder `~/anaconda3/lib/python3.7/site-packages/rdflib`.
14 |
15 | Then open `plugins/sparql/parser.py`, find *Line 68*, and replace its code with
16 | ```
17 | if i + 1 < l and (not isinstance(terms[i + 1], str) or terms[i + 1] not in ".,;"):
18 | ```
19 | Remember to keep the original indentation.
20 | Note that *Line 67* is the comment `# is this bnode the subject of more triplets?`. If your line numbers differ from mine, you can locate the target line by this comment.
21 |
22 | Finally, open `plugins/serializers/turtle.py`, find *Line 328*, and change `use_plain=True` to `use_plain=False`.
23 |
24 | ---
25 |
26 | - SPARQLWrapper=1.8.4
27 |
28 | ---
29 | **Note:**
30 | When installing `SPARQLWrapper` with `pip`, it may automatically install another package, `keepalive`. You can check whether it is in your environment by
31 | ```
32 | pip show keepalive
33 | ```
34 |
35 | If it is installed, it will cause problems when we execute a large number of SPARQL queries; specifically, the available ports will be exhausted. So we need to disable the `keepalive` package manually. It is okay to simply remove it:
36 | ```
37 | pip uninstall keepalive
38 | ```
39 |
40 | ---
41 |
42 | - Virtuoso backend; refer to the next section
43 |
44 | ## How to install virtuoso backend
45 | The virtuoso backend starts up a web service; we can import our kb into it and then execute SPARQL queries via network requests. We installed virtuoso on an Ubuntu 16.04 system. The specific steps follow.
46 |
47 | 1. Download and install virtuoso into our system.
48 | ```
49 | git clone https://github.com/openlink/virtuoso-opensource.git Virtuoso-Opensource
50 | cd Virtuoso-Opensource
51 | git checkout stable/7
52 | sudo apt-get install libtool gawk gperf autoconf automake libtool flex bison m4 make openssl libssl-dev
53 | sudo ./autogen.sh
54 | sudo ./configure
55 | sudo make
56 | sudo make install
57 | ```
58 |
59 | 2. Create a new user for virtuoso service
60 | ```
61 | sudo useradd virtuoso --home /usr/local/virtuoso-opensource
62 | sudo chown -R virtuoso /usr/local/virtuoso-opensource
63 | ```
64 |
65 | 3. Modify some necessary configs:
66 | ```
67 | cd /usr/local/virtuoso-opensource/var/lib/virtuoso/db
68 | sudo vim virtuoso.ini
69 | ```
70 | Find the item `CheckpointInterval` and change its value from the default 60 to 0, to avoid the automatic checkpoint process, which would cause 404 errors.
71 |
72 | 4. Start up the virtuoso service:
73 | ```
74 | sudo -H -u virtuoso ../../../../bin/virtuoso-t -f &
75 | ```
76 | Now you can access the service via the default port 8890.
77 | Enter `[ip]:8890` in a browser and you will see the virtuoso service page.
78 |
79 | 5. Now we can import our kb into virtuoso. Before that, we need to convert our kb to `ttl` format and move it to the proper location:
80 | ```
81 | python sparql_engine.py --kb_path ./dataset/kb.json --ttl_path ./dataset/kb.ttl
82 | sudo chmod 777 ./dataset/kb.ttl
83 | sudo mv ./dataset/kb.ttl /usr/local/virtuoso-opensource/share/virtuoso/vad
84 | ```
85 |
86 | 6. Enter the interactive terminal of virtuoso:
87 | ```
88 | cd /usr/local/virtuoso-opensource/bin
89 | sudo ./isql
90 | ```
91 |
92 | 7. Import our kb by executing these commands in terminal:
93 | ```
94 | SPARQL CREATE GRAPH <[graph_name]>;
95 | SPARQL CLEAR GRAPH <[graph_name]>;
96 | delete from db.dba.load_list;
97 | ld_dir('/usr/local/virtuoso-opensource/share/virtuoso/vad', 'kb.ttl', '[graph_name]');
98 | rdf_loader_run();
99 | select * from DB.DBA.load_list;
100 | exit;
101 | ```
102 | `[graph_name]` could be any legal string, such as *KQAPro*.
103 | The import has succeeded if `rdf_loader_run()` takes about 10 seconds.
104 |
105 |
106 | ## How to run
107 | 1. Following the last section, start up the virtuoso service and import `kb.ttl`. Then open `sparql_engine.py` and find the lines
108 | ```
109 | virtuoso_address = "http://127.0.0.1:8890/sparql"
110 | virtuoso_graph_uri = 'sjx'
111 | ```
112 | Change `virtuoso_address` to your service url (you can visit it in your browser to check whether it is valid) and change `virtuoso_graph_uri` to your `[graph_name]`. A quick connectivity check is sketched at the end of this section.
113 | 2. Preprocess the training data
114 | ```
115 | python -m SPARQL.preprocess --input_dir ./dataset --output_dir <output_dir>
116 | cp ./dataset/kb.json <output_dir>
117 | ```
118 | 3. Train
119 | ```
120 | python -m SPARQL.train --input_dir <output_dir> --save_dir <path/to/checkpoints>
121 | ```
122 | 4. Predict answers of the test set. It will produce a file named `predict.txt` in the `--save_dir`, storing the predictions of test questions in order.
123 | ```
124 | python -m SPARQL.predict --input_dir <output_dir> --save_dir <path/to/checkpoints>
125 | ```
126 |
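127 | **Connectivity check (referenced in step 1).** Before training or predicting, you can verify that the service and graph are reachable from Python. This is a minimal sketch using `SPARQLWrapper`; the address and graph name are placeholders that should match your settings in `sparql_engine.py`:
128 | ```
129 | from SPARQLWrapper import SPARQLWrapper, JSON
130 |
131 | virtuoso_address = "http://127.0.0.1:8890/sparql"  # your service url
132 | virtuoso_graph_uri = 'KQAPro'  # your [graph_name]
133 |
134 | endpoint = SPARQLWrapper(virtuoso_address)
135 | endpoint.setQuery('SELECT (COUNT(*) AS ?n) FROM <%s> WHERE { ?s ?p ?o . }' % virtuoso_graph_uri)
136 | endpoint.setReturnFormat(JSON)
137 | print(endpoint.query().convert())  # a non-zero count means the kb was imported
138 | ```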
--------------------------------------------------------------------------------
/SPARQL/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import torch
4 | from utils.misc import invert_dict
5 |
6 |
7 | def load_vocab(path):
8 | vocab = json.load(open(path))
9 | vocab['word_idx_to_token'] = invert_dict(vocab['word_token_to_idx'])
10 | vocab['sparql_idx_to_token'] = invert_dict(vocab['sparql_token_to_idx'])
11 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx'])
12 | return vocab
13 |
14 | def collate(batch):
15 | batch = list(zip(*batch))
16 | question = torch.stack(batch[0])
17 | choices = torch.stack(batch[1])
18 | if batch[-1][0] is None:
19 | sparql, answer = None, None
20 | else:
21 | sparql = torch.stack(batch[2])
22 | answer = torch.cat(batch[3])
23 | return question, choices, sparql, answer
24 |
25 |
26 | class Dataset(torch.utils.data.Dataset):
27 | def __init__(self, inputs):
28 | self.questions, self.sparqls, self.choices, self.answers = inputs
29 | self.is_test = len(self.answers)==0
30 |
31 |
32 | def __getitem__(self, index):
33 | question = torch.LongTensor(self.questions[index])
34 | choices = torch.LongTensor(self.choices[index])
35 | if self.is_test:
36 | sparql = None
37 | answer = None
38 | else:
39 | sparql = torch.LongTensor(self.sparqls[index])
40 | answer = torch.LongTensor([self.answers[index]])
41 | return question, choices, sparql, answer
42 |
43 |
44 | def __len__(self):
45 | return len(self.questions)
46 |
47 |
48 | class DataLoader(torch.utils.data.DataLoader):
49 | def __init__(self, vocab_json, question_pt, batch_size, training=False):
50 | vocab = load_vocab(vocab_json)
51 | if training:
52 | print('#vocab of word/sparql/answer: %d/%d/%d' %
53 | (len(vocab['word_token_to_idx']), len(vocab['sparql_token_to_idx']), len(vocab['answer_token_to_idx'])))
54 |
55 | inputs = []
56 | with open(question_pt, 'rb') as f:
57 | for _ in range(4):
58 | inputs.append(pickle.load(f))
59 | dataset = Dataset(inputs)
60 |
61 | super().__init__(
62 | dataset,
63 | batch_size=batch_size,
64 | shuffle=training,
65 | collate_fn=collate,
66 | )
67 | self.vocab = vocab
68 |
69 |
--------------------------------------------------------------------------------
/SPARQL/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from utils.BiGRU import GRU, BiGRU
5 |
6 | class SPARQLParser(nn.Module):
7 | def __init__(self, vocab, dim_word, dim_hidden, max_dec_len):
8 | super().__init__()
9 | num_words = len(vocab['word_token_to_idx'])
10 | num_sparql = len(vocab['sparql_token_to_idx'])
11 | self.vocab = vocab
12 | self.dim_word = dim_word
13 | self.dim_hidden = dim_hidden
14 | self.max_dec_len = max_dec_len
15 |
16 | self.word_embeddings = nn.Embedding(num_words, dim_word)
17 | self.word_dropout = nn.Dropout(0.3)
18 | self.question_encoder = GRU(dim_word, dim_hidden, num_layers=2, dropout=0.2)
19 |
20 | self.sparql_embeddings = nn.Embedding(num_sparql, dim_word)
21 | self.decoder = GRU(dim_word, dim_hidden, num_layers=2, dropout=0.2)
22 |
23 | self.sparql_classifier = nn.Sequential(
24 | nn.Linear(dim_hidden, 1024),
25 | nn.ReLU(),
26 | nn.Linear(1024, num_sparql),
27 | )
28 |
29 | self.att_lin = nn.Linear(dim_hidden, dim_hidden)
30 |
31 | for m in self.modules():
32 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
33 | nn.init.kaiming_normal_(m.weight)
34 | if m.bias is not None:
35 | m.bias.data.zero_()
36 |
37 | def forward(self, questions, sparqls=None):
38 | """
39 | Args:
40 | questions [bsz, max_q]
41 | sparqls [bsz, max_s]
42 | Return:
43 | if sparqls are given, then return losses
44 | else, return predicted sparqls
45 | """
46 |         question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD>
47 | q_word_emb = self.word_dropout(self.word_embeddings(questions))
48 | q_word_h, q_embeddings, q_hn = self.question_encoder(q_word_emb, question_lens)
49 | # [bsz, max_q, dim_h], [bsz, dim_h], [num_layers, bsz, dim_h]
50 |
51 | if sparqls is None: # during inference
52 | return self.inference(q_word_h, q_embeddings, q_hn)
53 | else:
54 | return self.train_phase(q_word_h, q_embeddings, q_hn, sparqls)
55 |
56 |
57 | def train_phase(self, q_word_h, q_embeddings, q_hn, sparqls):
58 | bsz, max_s = sparqls.size(0), sparqls.size(1)
59 | device = sparqls.device
60 |         sparql_lens = max_s - sparqls.eq(0).long().sum(dim=1) # 0 means <PAD>
61 | sparql_mask = sparqls.ne(0).long()
62 |
63 | s_word_emb = self.word_dropout(self.sparql_embeddings(sparqls))
64 | s_word_h, _, _ = self.decoder(s_word_emb, sparql_lens, h_0=q_hn) # [bsz, max_s, dim_h]
65 | # attention over question words
66 | attn = torch.softmax(torch.bmm(s_word_h, q_word_h.permute(0, 2, 1)), dim=2) # [bsz, max_s, max_q]
67 | attn_word_h = torch.bmm(attn, q_word_h) # [bsz, max_s, dim_h]
68 | # sum up
69 | s_word_h = s_word_h + attn_word_h # [bsz, max_s, dim_h]
70 |
71 | criterion = nn.CrossEntropyLoss().to(device)
72 | logit = self.sparql_classifier(s_word_h) # [bsz, max_s, num_sparql]
73 | loss = criterion(logit.permute(0, 2, 1)[:,:,:-1], sparqls[:,1:]) # remember to shift the gt
74 |
75 | return loss
76 |
77 |
78 | def inference(self, q_word_h, q_embeddings, q_hn):
79 | """
80 | Predict sparqls
81 | """
82 | bsz = q_word_h.size(0)
83 | device = q_word_h.device
84 |         start_id = self.vocab['sparql_token_to_idx']['<START>']
85 |         end_id = self.vocab['sparql_token_to_idx']['<END>']
86 |
87 | latest_sparql = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ]
88 | last_h = q_hn
89 |         finished = torch.zeros((bsz,)).byte().to(device) # record whether <END> has been produced
90 |
91 | # store predictions at each step
92 | sparqls = [latest_sparql]
93 |
94 | for i in range(self.max_dec_len):
95 | s_word_emb = self.word_dropout(self.sparql_embeddings(latest_sparql)).unsqueeze(1) # [bsz, 1, dim_w]
96 | s_word_h, last_h = self.decoder.forward_one_step(s_word_emb, last_h) # [bsz, 1, dim_h]
97 | # attention over question words
98 | attn = torch.softmax(torch.bmm(s_word_h, q_word_h.permute(0, 2, 1)), dim=2) # [bsz, 1, max_q]
99 | attn_word_h = torch.bmm(attn, q_word_h) # [bsz, 1, dim_h]
100 | # sum up
101 | s_word_h = s_word_h + attn_word_h # [bsz, 1, dim_h]
102 |
103 | logit = self.sparql_classifier(s_word_h).squeeze(1) # [bsz, num_sparql]
104 | latest_sparql = torch.argmax(logit, dim=1) # [bsz, ]
105 | sparqls.append(latest_sparql)
106 |
107 | finished = finished | latest_sparql.eq(end_id).byte()
108 | if finished.sum().item() == bsz:
109 | # print('finished at step {}'.format(i))
110 | break
111 |
112 | sparqls = torch.stack(sparqls, dim=1) # [bsz, max_s]
113 |
114 | return sparqls
115 |
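116 | # Usage sketch: with gold sparqls the forward pass returns a teacher-forcing
117 | # loss; without them it greedily decodes until every sequence in the batch has
118 | # emitted <END> or max_dec_len steps are reached:
119 | #   loss = model(questions, sparqls)  # training, scalar loss
120 | #   pred = model(questions)           # inference, [bsz, <= max_dec_len + 1]
121 | # pred[:, 0] is always <START>; map ids back via vocab['sparql_idx_to_token'].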
--------------------------------------------------------------------------------
/SPARQL/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import json
5 | from tqdm import tqdm
6 |
7 | from utils.load_kb import DataForSPARQL
8 | from .data import DataLoader
9 | from .model import SPARQLParser
10 | from .sparql_engine import get_sparql_answer
11 | from .preprocess import postprocess_sparql_tokens
12 |
13 | import warnings
14 | warnings.simplefilter("ignore") # hide warnings that caused by invalid sparql query
15 |
16 |
17 | def test(args):
18 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
19 |
20 | print('load test data')
21 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
22 | test_pt = os.path.join(args.input_dir, 'test.pt')
23 | data = DataLoader(vocab_json, test_pt, 128, training=False)
24 | vocab = data.vocab
25 | kb = DataForSPARQL(os.path.join(args.input_dir, 'kb.json'))
26 |
27 | print('load model')
28 | model = SPARQLParser(vocab, args.dim_word, args.dim_hidden, args.max_dec_len)
29 | model = model.to(device)
30 |     model.load_state_dict(torch.load(os.path.join(args.save_dir, 'model.pt'), map_location=device))
31 |
32 | f = open(os.path.join(args.save_dir, 'predict.txt'), 'w')
33 | for batch in tqdm(data, total=len(data)):
34 | question, choices, sparql, answer = batch
35 | question = question.to(device)
36 | pred_sparql = model(question)
37 |
38 | pred_sparql = pred_sparql.cpu().numpy().tolist()
39 | for s in pred_sparql:
40 | s = [vocab['sparql_idx_to_token'][i] for i in s]
41 | end_idx = len(s)
42 |         if '<END>' in s:
43 |             end_idx = s.index('<END>')
44 | s = ' '.join(s[1:end_idx])
45 | s = postprocess_sparql_tokens(s)
46 | answer = str(get_sparql_answer(s, kb))
47 | f.write(answer + '\n')
48 | f.close()
49 |
50 |
51 |
52 | def main():
53 | parser = argparse.ArgumentParser()
54 | # input and output
55 | parser.add_argument('--input_dir', required=True)
56 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
57 |
58 | # model hyperparameters
59 | parser.add_argument('--dim_word', default=300, type=int)
60 | parser.add_argument('--dim_hidden', default=1024, type=int)
61 | parser.add_argument('--max_dec_len', default=100, type=int)
62 | args = parser.parse_args()
63 |
64 | test(args)
65 |
66 |
67 | if __name__ == '__main__':
68 | main()
69 |
--------------------------------------------------------------------------------
/SPARQL/preprocess.py:
--------------------------------------------------------------------------------
1 | """
2 | postprocess_sparql_tokens below restores a valid SPARQL query from predicted tokens; predict.py and train.py use it before executing a query to extract the final answer.
3 | """
4 |
5 | import os
6 | import json
7 | import pickle
8 | import argparse
9 | import numpy as np
10 | from nltk import word_tokenize
11 | from collections import Counter
12 | from itertools import chain
13 | from tqdm import tqdm
14 | import re
15 |
16 | from utils.misc import init_vocab
17 |
18 | def tokenize_sparql(s):
19 | # separate punctuations
20 | s = s.replace('"', ' " ').replace('^^', ' ^^ ')
21 | # NOTE: after decoding, these extra space must be removed
22 | # this may cause some mistakes, but the ratio is very small, about one of thousands
23 | return s.split()
24 |
25 | def postprocess_sparql_tokens(s):
26 | # organize the predicted sparql tokens into a valid query
27 | s = s.replace(' ^^ ', '^^')
28 | skip_idxs = set()
29 | for i in range(len(s)):
30 | if s[i] == '"':
31 | if i > 2 and s[i-1]==' ' and s[i-2] not in {'>'}:
32 | skip_idxs.add(i-1)
33 | if i < len(s)-2 and s[i+1]==' ' and s[i+2] not in {'<'}:
34 | skip_idxs.add(i+1)
35 | s = ''.join([s[i] for i in range(len(s)) if i not in skip_idxs])
36 | return s
37 |
38 | def encode_dataset(dataset, vocab, test=False):
39 | questions = []
40 | sparqls = []
41 | choices = []
42 | answers = []
43 | for question in tqdm(dataset):
44 |         q = [vocab['word_token_to_idx'].get(w, vocab['word_token_to_idx']['<UNK>'])
45 | for w in word_tokenize(question['question'].lower())]
46 | questions.append(q)
47 |
48 | _ = [vocab['answer_token_to_idx'][w] for w in question['choices']]
49 | choices.append(_)
50 |
51 | if test:
52 | continue
53 |
54 |         _ = [vocab['sparql_token_to_idx'].get(w, vocab['sparql_token_to_idx']['<UNK>'])
55 | for w in tokenize_sparql(question['sparql'])]
56 |         # wrap with <START> and <END>
57 |         _ = [vocab['sparql_token_to_idx']['<START>']] + _ + [vocab['sparql_token_to_idx']['<END>']]
58 | sparqls.append(_)
59 |
60 | if 'answer' in question:
61 | answers.append(vocab['answer_token_to_idx'].get(question['answer']))
62 |
63 | # question padding
64 | max_len = max(len(q) for q in questions)
65 | for q in questions:
66 | while len(q) < max_len:
67 |             q.append(vocab['word_token_to_idx']['<PAD>'])
68 | if not test:
69 | # sparql padding
70 | max_len = max(len(s) for s in sparqls)
71 | for s in sparqls:
72 | while len(s) < max_len:
73 |                 s.append(vocab['sparql_token_to_idx']['<PAD>'])
74 |
75 | questions = np.asarray(questions, dtype=np.int32)
76 | sparqls = np.asarray(sparqls, dtype=np.int32)
77 | choices = np.asarray(choices, dtype=np.int32)
78 | answers = np.asarray(answers, dtype=np.int32)
79 | return questions, sparqls, choices, answers
80 |
81 |
82 |
83 | def main():
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('--input_dir', required=True)
86 | parser.add_argument('--output_dir', required=True)
87 | parser.add_argument('--min_cnt', type=int, default=1)
88 | args = parser.parse_args()
89 |
90 |
91 | print('Build kb vocabulary')
92 | vocab = {
93 | 'word_token_to_idx': init_vocab(),
94 | 'sparql_token_to_idx': init_vocab(),
95 | 'answer_token_to_idx': {}
96 | }
97 | print('Load questions')
98 | train_set = json.load(open(os.path.join(args.input_dir, 'train.json')))
99 | val_set = json.load(open(os.path.join(args.input_dir, 'val.json')))
100 | test_set = json.load(open(os.path.join(args.input_dir, 'test.json')))
101 | print('Build question vocabulary')
102 | word_counter = Counter()
103 | for question in train_set:
104 | tokens = word_tokenize(question['question'].lower())
105 | word_counter.update(tokens)
106 | # add candidate answers
107 | for a in question['choices']:
108 | if a not in vocab['answer_token_to_idx']:
109 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
110 | # add sparql
111 | for a in tokenize_sparql(question['sparql']):
112 | if a not in vocab['sparql_token_to_idx']:
113 | vocab['sparql_token_to_idx'][a] = len(vocab['sparql_token_to_idx'])
114 |
115 | # filter low-frequency words
116 | for w, c in word_counter.items():
117 | if w and c >= args.min_cnt and w not in vocab['word_token_to_idx']:
118 | vocab['word_token_to_idx'][w] = len(vocab['word_token_to_idx'])
119 | # add candidate answers of val and test set
120 | for question in chain(val_set, test_set):
121 | for a in question['choices']:
122 | if a not in vocab['answer_token_to_idx']:
123 | vocab['answer_token_to_idx'][a] = len(vocab['answer_token_to_idx'])
124 |
125 |
126 | if not os.path.isdir(args.output_dir):
127 | os.mkdir(args.output_dir)
128 | fn = os.path.join(args.output_dir, 'vocab.json')
129 | print('Dump vocab to {}'.format(fn))
130 | with open(fn, 'w') as f:
131 | json.dump(vocab, f, indent=2)
132 | for k in vocab:
133 | print('{}:{}'.format(k, len(vocab[k])))
134 |
135 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)):
136 | print('Encode {} set'.format(name))
137 | outputs = encode_dataset(dataset, vocab, name=='test')
138 | assert len(outputs) == 4
139 | print('shape of questions, sparqls, choices, answers:')
140 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f:
141 | for o in outputs:
142 | print(o.shape)
143 | pickle.dump(o, f)
144 |
145 |
146 |
147 |
148 |
149 | if __name__ == '__main__':
150 | main()
151 |
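152 |
153 |
154 | def _round_trip_demo():
155 |     # A minimal sketch (the query string is hypothetical): tokenize_sparql
156 |     # separates '"' and '^^' with spaces so they become standalone tokens, and
157 |     # postprocess_sparql_tokens removes exactly those extra spaces again.
158 |     s = 'SELECT DISTINCT ?v WHERE { ?e <height> "100"^^xsd:double . }'
159 |     assert postprocess_sparql_tokens(' '.join(tokenize_sparql(s))) == s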
--------------------------------------------------------------------------------
/SPARQL/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import shutil
7 | import json
8 | from tqdm import tqdm
9 | from datetime import date
10 |
11 | from utils.misc import MetricLogger
12 | from utils.load_kb import DataForSPARQL
13 | from .data import DataLoader
14 | from .model import SPARQLParser
15 | from .sparql_engine import get_sparql_answer
16 | from .preprocess import postprocess_sparql_tokens
17 |
18 | import logging
19 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
20 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
21 | rootLogger = logging.getLogger()
22 | import warnings
23 | warnings.simplefilter("ignore") # hide warnings that caused by invalid sparql query
24 |
25 | def whether_equal(answer, pred):
26 | """
27 | check whether the two arguments are equal as attribute value
28 | """
29 | def truncate_float(x):
30 | # convert answer from '100.0 meters' to '100 meters'
31 | try:
32 | v, *u = x.split()
33 | v = float(v)
34 | if v - int(v) < 1e-5:
35 | v = int(v)
36 | if len(u) == 0:
37 | x = str(v)
38 | else:
39 | x = '{} {}'.format(str(v), ' '.join(u))
40 | except:
41 | pass
42 | return x
43 |
44 | def equal_as_date(x, y):
45 | # check whether x and y are equal as type of date or year
46 | try:
47 | x_split = x.split('-')
48 | y_split = y.split('-')
49 | if len(x_split) == 3:
50 | x = date(int(x_split[0]), int(x_split[1]), int(x_split[2]))
51 | else:
52 | x = int(x)
53 | if len(y_split) == 3:
54 | y = date(int(y_split[0]), int(y_split[1]), int(y_split[2]))
55 | else:
56 | y = int(y)
57 | if isinstance(x, date) and isinstance(y, date):
58 | return x == y
59 | else:
60 | x = x.year if isinstance(x, date) else x
61 | y = y.year if isinstance(y, date) else y
62 | return x == y
63 | except:
64 | return False
65 |
66 | answer = truncate_float(answer)
67 | pred = truncate_float(pred)
68 | if equal_as_date(answer, pred):
69 | return True
70 | else:
71 | return answer == pred
72 |
73 |
74 | def validate(args, kb, model, data, device):
75 | model.eval()
76 | count, correct = 0, 0
77 | with torch.no_grad():
78 | for batch in tqdm(data, total=len(data)):
79 | question, choices, sparql, answer = [x.to(device) for x in batch]
80 | pred_sparql = model(question)
81 |
82 | answer, pred_sparql = [x.cpu().numpy().tolist() for x in (answer, pred_sparql)]
83 | for a, s in zip(answer, pred_sparql):
84 | given_answer = data.vocab['answer_idx_to_token'][a]
85 | s = [data.vocab['sparql_idx_to_token'][i] for i in s]
86 | end_idx = len(s)
87 |                 if '<END>' in s:
88 |                     end_idx = s.index('<END>')
89 | s = ' '.join(s[1:end_idx])
90 | s = postprocess_sparql_tokens(s)
91 | pred_answer = get_sparql_answer(s, kb)
92 | is_match = whether_equal(given_answer, pred_answer)
93 | if is_match:
94 | correct += 1
95 | count += len(answer)
96 | acc = correct / count
97 | logging.info('\nValid Accuracy: %.4f\n' % acc)
98 | return acc
99 |
100 | def test_sparql(args):
101 | # check whether the SPARQL engine is correct, with the training set
102 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
103 | train_pt = os.path.join(args.input_dir, 'train.pt')
104 | data = DataLoader(vocab_json, train_pt, args.batch_size, training=False)
105 | kb = DataForSPARQL(os.path.join(args.input_dir, 'kb.json'))
106 |
107 | count, correct = 0, 0
108 | for batch in tqdm(data, total=len(data)):
109 | question, choices, sparql, answer = batch
110 | pred_sparql = sparql
111 |
112 | answer = answer.cpu().numpy().tolist()
113 | pred_sparql = pred_sparql.cpu().numpy().tolist()
114 | for a, s in zip(answer, pred_sparql):
115 | given_answer = data.vocab['answer_idx_to_token'][a]
116 | s = [data.vocab['sparql_idx_to_token'][i] for i in s]
117 | end_idx = len(s)
118 |             if '<END>' in s:
119 |                 end_idx = s.index('<END>')
120 | s = ' '.join(s[1:end_idx])
121 | s = postprocess_sparql_tokens(s)
122 | pred_answer = get_sparql_answer(s, kb)
123 | is_match = whether_equal(given_answer, pred_answer)
124 | count += 1
125 | if is_match:
126 | correct += 1
127 | else:
128 | print(given_answer, pred_answer)
129 |
130 | def train(args):
131 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
132 |
133 | logging.info("Create train_loader and val_loader.........")
134 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
135 | train_pt = os.path.join(args.input_dir, 'train.pt')
136 | val_pt = os.path.join(args.input_dir, 'val.pt')
137 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True)
138 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
139 | vocab = train_loader.vocab
140 | kb = DataForSPARQL(os.path.join(args.input_dir, 'kb.json'))
141 |
142 | logging.info("Create model.........")
143 | model = SPARQLParser(vocab, args.dim_word, args.dim_hidden, args.max_dec_len)
144 | model = model.to(device)
145 | logging.info(model)
146 |
147 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
148 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[5, 50], gamma=0.1)
149 |
150 | # validate(args, kb, model, val_loader, device)
151 | meters = MetricLogger(delimiter=" ")
152 | best_acc = 0
153 | logging.info("Start training........")
154 | for epoch in range(args.num_epoch):
155 | model.train()
156 | for iteration, batch in enumerate(train_loader):
157 | iteration = iteration + 1
158 |
159 | question, choices, sparql, answer = [x.to(device) for x in batch]
160 | loss = model(question, sparql)
161 | optimizer.zero_grad()
162 | loss.backward()
163 | optimizer.step()
164 | meters.update(loss=loss.item())
165 |
166 |             if iteration % max(1, len(train_loader) // 100) == 0:
167 | logging.info(
168 | meters.delimiter.join(
169 | [
170 | "progress: {progress:.3f}",
171 | "{meters}",
172 | "lr: {lr:.6f}",
173 | ]
174 | ).format(
175 | progress=epoch + iteration / len(train_loader),
176 | meters=str(meters),
177 | lr=optimizer.param_groups[0]["lr"],
178 | )
179 | )
180 |
181 | acc = validate(args, kb, model, val_loader, device)
182 | scheduler.step()
183 | if acc and acc > best_acc:
184 | best_acc = acc
185 | logging.info("\nupdate best ckpt with acc: {:.4f}".format(best_acc))
186 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model.pt'))
187 |
188 |
189 | def main():
190 | parser = argparse.ArgumentParser()
191 | # input and output
192 | parser.add_argument('--input_dir', required=True)
193 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
194 |
195 | # training parameters
196 | parser.add_argument('--lr', default=0.001, type=float)
197 | parser.add_argument('--weight_decay', default=1e-5, type=float)
198 | parser.add_argument('--num_epoch', default=100, type=int)
199 | parser.add_argument('--batch_size', default=64, type=int)
200 | parser.add_argument('--seed', type=int, default=666, help='random seed')
201 | # model hyperparameters
202 | parser.add_argument('--dim_word', default=300, type=int)
203 | parser.add_argument('--dim_hidden', default=1024, type=int)
204 | parser.add_argument('--max_dec_len', default=100, type=int)
205 | args = parser.parse_args()
206 |
207 | # make logging.info display into both shell and file
208 | if os.path.isdir(args.save_dir):
209 | shutil.rmtree(args.save_dir)
210 | os.mkdir(args.save_dir)
211 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt'))
212 | fileHandler.setFormatter(logFormatter)
213 | rootLogger.addHandler(fileHandler)
214 | # args display
215 | for k, v in vars(args).items():
216 | logging.info(k+':'+str(v))
217 |
218 | # set random seed
219 | torch.manual_seed(args.seed)
220 |
221 | train(args)
222 | # test_sparql(args)
223 |
224 |
225 | if __name__ == '__main__':
226 | main()
227 |
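228 | # Illustrative behaviour of whether_equal (the values are hypothetical):
229 | #   whether_equal('100.0 metre', '100 metre')  -> True   (trailing .0 truncated)
230 | #   whether_equal('1999-06-01', '1999')        -> True   (compared at year level)
231 | #   whether_equal('Paris', 'London')           -> False  (plain string mismatch)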
--------------------------------------------------------------------------------
/SRN/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import torch
4 | from utils.misc import invert_dict
5 |
6 |
7 | def load_vocab(path):
8 | vocab = json.load(open(path))
9 | vocab['id2word'] = invert_dict(vocab['word2id'])
10 | vocab['id2entity'] = invert_dict(vocab['entity2id'])
11 | vocab['id2relation'] = invert_dict(vocab['relation2id'])
12 | # vocab['entity2name'] = invert_dict(vocab['name2entity'])
13 | return vocab
14 |
15 | def collate(batch):
16 | batch = list(zip(*batch))
17 | question, topic_entity, answer = list(map(torch.stack, batch))
18 | return question, topic_entity, answer
19 |
20 |
21 | class Dataset(torch.utils.data.Dataset):
22 | def __init__(self, inputs):
23 | self.questions, self.topic_entities, self.answers = inputs
24 | print(self.questions.shape)
25 | print(self.topic_entities.shape)
26 | print(self.answers.shape)
27 |
28 | def __getitem__(self, index):
29 | question = torch.LongTensor(self.questions[index])
30 | topic_entity = torch.LongTensor(self.topic_entities[index])
31 | answer = torch.LongTensor(self.answers[index])
32 | return question, topic_entity, answer
33 |
34 |
35 | def __len__(self):
36 | return len(self.questions)
37 |
38 |
39 | class DataLoader(torch.utils.data.DataLoader):
40 | def __init__(self, vocab_json, question_pt, batch_size, training=False):
41 | vocab = load_vocab(vocab_json)
42 |
43 | inputs = []
44 | with open(question_pt, 'rb') as f:
45 | for _ in range(3):
46 | inputs.append(pickle.load(f))
47 | dataset = Dataset(inputs)
48 |
49 | super().__init__(
50 | dataset,
51 | batch_size=batch_size,
52 | shuffle=training,
53 | collate_fn=collate,
54 | )
55 | self.vocab = vocab
56 |
57 |
--------------------------------------------------------------------------------
/SRN/knowledge_graph.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import os
3 | import pickle
4 | from collections import defaultdict
5 | import torch
6 | import torch.nn as nn
7 | from utils.misc import *
8 |
9 | class KnowledgeGraph(nn.Module):
10 | def __init__(self, args, vocab):
11 | super(KnowledgeGraph, self).__init__()
12 | self.args = args
13 | self.entity2id, self.id2entity = vocab['entity2id'], vocab['id2entity']
14 | self.relation2id, self.id2relation = vocab['relation2id'], vocab['id2relation']
15 | self.adj_list = None
16 | self.action_space = None
17 | self.action_mask = None
18 | self.bandwidth = args.bandwidth
19 | with open(os.path.join(args.input_dir, 'adj_list.pt'), 'rb') as f:
20 | self.adj_list = pickle.load(f)
21 | self.vectorize_action_space()
22 | self.relation_embeddings = nn.Embedding(self.num_relations, args.dim_hidden)
23 | nn.init.xavier_normal_(self.relation_embeddings.weight)
24 |
25 |
26 | def vectorize_action_space(self):
27 | def load_pgrk_score():
28 | pgrk_scores = defaultdict(float)
29 | with open(os.path.join(self.args.input_dir, 'pgrk.txt')) as f:
30 | for line in f:
31 | e, score = line.strip().split(':')
32 |                     pgrk_scores[int(e)] = float(score)
33 | return pgrk_scores
34 |
35 | page_rank_scores = load_pgrk_score()
36 |
37 | def get_action_space(e1):
38 | action_space = []
39 | if e1 in self.adj_list:
40 | for r in self.adj_list[e1]:
41 | targets = self.adj_list[e1][r]
42 | for e2 in targets:
43 | action_space.append((r, e2))
44 | if len(action_space) + 1 >= self.bandwidth:
45 | # Base graph pruning
46 | sorted_action_space = \
47 | sorted(action_space, key=lambda x: page_rank_scores[x[1]], reverse=True)
48 | action_space = sorted_action_space[:self.bandwidth]
49 | action_space.insert(0, (NO_OP_RELATION_ID, e1))
50 | return action_space
51 |
52 | def vectorize_action_space(action_space_list, action_space_size):
53 | bucket_size = len(action_space_list)
54 | r_space = torch.zeros(bucket_size, action_space_size) + self.dummy_r
55 | e_space = torch.zeros(bucket_size, action_space_size) + self.dummy_e
56 | action_mask = torch.zeros(bucket_size, action_space_size)
57 | for i, action_space in enumerate(action_space_list):
58 | for j, (r, e) in enumerate(action_space):
59 | r_space[i, j] = r
60 | e_space[i, j] = e
61 | action_mask[i, j] = 1
62 | return (r_space.long(), e_space.long()), action_mask
63 |
64 | self.action_space_buckets = {}
65 | action_space_buckets_discrete = defaultdict(list)
66 | self.entity2bucketid = torch.zeros(self.num_entities, 2).long()
67 | num_facts_saved_in_action_table = 0
68 | for e1 in range(self.num_entities):
69 | action_space = get_action_space(e1)
70 | key = int(len(action_space) / self.args.bucket_interval) + 1
71 | self.entity2bucketid[e1, 0] = key
72 | self.entity2bucketid[e1, 1] = len(action_space_buckets_discrete[key])
73 | action_space_buckets_discrete[key].append(action_space)
74 | num_facts_saved_in_action_table += len(action_space)
75 | print('Sanity check: {} facts saved in action table'.format(num_facts_saved_in_action_table - self.num_entities))
76 | for key in action_space_buckets_discrete:
77 | self.action_space_buckets[key] = vectorize_action_space(action_space_buckets_discrete[key], key * self.args.bucket_interval)
78 |             print('Vectorized action space bucket {} (size {})'.format(key, len(self.action_space_buckets[key][-1])))
79 |         print('Sanity check: {} action space buckets in total'.format(len(self.action_space_buckets)))
80 |
81 |
82 | @property
83 | def num_entities(self):
84 | return len(self.entity2id)
85 |
86 | @property
87 | def num_relations(self):
88 | return len(self.relation2id)
89 |
90 | @property
91 | def self_edge(self):
92 | return NO_OP_RELATION_ID
93 |
94 | @property
95 | def self_e(self):
96 | return NO_OP_ENTITY_ID
97 |
98 | @property
99 | def dummy_r(self):
100 | return DUMMY_RELATION_ID
101 |
102 | @property
103 | def dummy_e(self):
104 | return DUMMY_ENTITY_ID
105 |
106 | @property
107 | def dummy_start_r(self):
108 | return START_RELATION_ID
109 |
--------------------------------------------------------------------------------
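
A note on the bucketed layout built above: `entity2bucketid` stores, for every entity, the bucket key and the row index inside that bucket, and `action_space_buckets[key]` holds the padded relation/entity tensors plus a validity mask. A minimal lookup sketch (`kg` and `e` are hypothetical names for a built `KnowledgeGraph` and an entity id):

```
key, idx = kg.entity2bucketid[e].tolist()        # bucket key, row inside that bucket
(r_space, e_space), mask = kg.action_space_buckets[key]
valid = mask[idx].bool()
relations = r_space[idx][valid]                  # outgoing relation ids; row 0 is the NO_OP self-loop
neighbors = e_space[idx][valid]                  # corresponding target entity ids
```
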
/SRN/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import shutil
7 | from tqdm import tqdm
8 |
9 | from utils.misc import MetricLogger
10 | from SRN.data import DataLoader
11 | from SRN.model import SRN
12 |
13 | import logging
14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
15 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
16 | rootLogger = logging.getLogger()
17 |
18 | torch.set_num_threads(1) # avoid using multiple cpus
19 |
20 | def validate(args, vocab, model, data, device):
21 | def write(f, predict):
22 | predict = predict.squeeze().tolist()
23 | for i in predict:
24 | f.write(vocab['id2entity'][i] + '\n')
25 | model.eval()
26 | count, correct = 0, 0
27 |     f_pred = open(os.path.join(args.save_dir, 'predict.txt'), 'w')
28 | with torch.no_grad():
29 | for batch in tqdm(data, total=len(data)):
30 | questions, topic_entities, answers = [x.to(device) for x in batch]
31 | predict = model(questions, topic_entities)
32 |
33 | pred_e2s = predict['pred_e2s']
34 | pred_e2_scores = predict['pred_e2_scores']
35 | search_traces = predict['search_traces']
36 | pred_top_e2 = pred_e2s[:, 0].unsqueeze(-1) # [bsz, beam_size] => [bsz] => [bsz, 1]
37 |             write(f_pred, pred_top_e2)
38 | correct += torch.any(pred_top_e2 == answers, dim=1).float().sum().item()
39 | count += len(answers)
40 | acc = correct / count
41 |     f_pred.close()
42 | logging.info('\nValid Accuracy: %.4f\n' % acc)
43 | return acc
44 |
45 | def predict(args):  # loads a checkpoint and evaluates it on the test split
46 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
47 |
48 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
49 | train_pt = os.path.join(args.input_dir, 'train.pt')
50 | val_pt = os.path.join(args.input_dir, 'val.pt')
51 | test_pt = os.path.join(args.input_dir, 'test.pt')
52 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True)
53 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
54 | test_loader = DataLoader(vocab_json, test_pt, args.batch_size)
55 | vocab = train_loader.vocab
56 |
57 | model = SRN(args, args.dim_word, args.dim_hidden, vocab)
58 | model.load_state_dict(torch.load(args.ckpt))
59 | model = model.to(device)
60 | validate(args, vocab, model, test_loader, device)
61 |
62 |
63 |
64 |
65 | def main():
66 | parser = argparse.ArgumentParser()
67 | # input and output
68 | parser.add_argument('--input_dir', required=True)
69 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
70 | parser.add_argument('--glove_pt', default='/data/csl/resources/word2vec/glove.840B.300d.py36.pt')
71 |
72 | # training parameters
73 | parser.add_argument('--lr', default=0.001, type=float)
74 | parser.add_argument('--weight_decay', default=1e-5, type=float)
75 | parser.add_argument('--num_epoch', default=60, type=int)
76 | parser.add_argument('--batch_size', default=512, type=int)
77 | parser.add_argument('--seed', type=int, default=666, help='random seed')
78 | # model hyperparameters
79 | parser.add_argument('--dim_emb', default=300, type=int)
80 | parser.add_argument('--num_rollout_steps', default=3, type=int)
81 | parser.add_argument('--num_rollouts', default=10, type=int)
82 | parser.add_argument('--dim_word', default=300, type=int)
83 | parser.add_argument('--dim_hidden', default=300, type=int)
84 |     parser.add_argument('--bucket_interval', default=3, type=int)
85 |     parser.add_argument('--opt', default='adam', type=str)
86 |     parser.add_argument('--bandwidth', default=100, type=int)
87 |     parser.add_argument('--gamma', default=0.95, type=float)
88 |     parser.add_argument('--eta', default=0.95, type=float)
89 |     parser.add_argument('--beta', default=0, type=float)
90 |     parser.add_argument('--beam_size', default=32, type=int)
91 |     parser.add_argument('--log_name', default='log.txt', type=str)
92 |     parser.add_argument('--model_name', default='model.pt', type=str)
93 |     parser.add_argument('--rel', action='store_true')
94 | parser.add_argument('--ckpt', required=True)
95 | args = parser.parse_args()
96 |
97 | # set random seed
98 | torch.manual_seed(args.seed)
99 |
100 |     predict(args)
101 |
102 |
103 | if __name__ == '__main__':
104 | main()
105 |
--------------------------------------------------------------------------------
/SRN/readme.md:
--------------------------------------------------------------------------------
1 | ## Requirements
2 | - python3
3 | - pytorch>=1.2.0
4 | - nltk
5 |
6 | ## How to run
7 | 1. Download [GloVe 300d vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip), unzip it to get the file `glove.840B.300d.txt`, and then convert it to a pickle file for faster loading:
8 | ```
9 | python -m utils.pickle_glove --input <path/to/glove.840B.300d.txt> --output <path/to/glove.pt>
10 | ```
11 | This step can be skipped if you have obtained the glove pickle file in other models.
12 |
13 | 2. Preprocess the training data
14 | ```
15 | python -m SRN.preprocess --input_dir ./dataset --output_dir <path/to/processed/data>
16 | ```
17 | 3. Train
18 | ```
19 | python -m SRN.train --input_dir <path/to/processed/data> --save_dir <path/to/checkpoints> --glove_pt <path/to/glove.pt>
20 | ```
21 |
--------------------------------------------------------------------------------
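
The readme stops at training; for reference, prediction with a saved checkpoint can be invoked as below (flags taken from the argparse in `SRN/predict.py`; paths are placeholders):

```
python -m SRN.predict --input_dir <path/to/processed/data> --save_dir <path/to/checkpoints> --ckpt <path/to/checkpoint.pt>
```
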
/SRN/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import argparse
6 | import shutil
7 | from tqdm import tqdm
8 |
9 | from utils.misc import MetricLogger, load_glove
10 | from SRN.data import DataLoader
11 | from SRN.model import SRN
12 | import copy
13 | import logging
14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
15 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
16 | rootLogger = logging.getLogger()
17 |
18 | torch.set_num_threads(1) # avoid using multiple cpus
19 |
20 | def validate(model, data, device):
21 | model.eval()
22 | count, correct = 0, 0
23 | with torch.no_grad():
24 | for batch in tqdm(data, total=len(data)):
25 | questions, topic_entities, answers = [x.to(device) for x in batch]
26 | predict = model(questions, topic_entities)
27 | pred_e2s = predict['pred_e2s']
28 | pred_e2_scores = predict['pred_e2_scores']
29 | search_traces = predict['search_traces']
30 | pred_top_e2 = pred_e2s[:, 0].unsqueeze(-1) # [bsz, beam_size] => [bsz] => [bsz, 1]
31 | correct += torch.any(pred_top_e2 == answers, dim=1).float().sum().item()
32 | count += len(answers)
33 | acc = correct / count
34 | logging.info('\nValid Accuracy: %.4f' % acc)
35 | return acc
36 |
37 | def train(args):
38 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
39 |
40 | logging.info("Create train_loader, val_loader.........")
41 | vocab_json = os.path.join(args.input_dir, 'vocab.json')
42 | train_pt = os.path.join(args.input_dir, 'train.pt')
43 | val_pt = os.path.join(args.input_dir, 'val.pt')
44 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, training=True)
45 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
46 | vocab = train_loader.vocab
47 |
48 | logging.info("Create model.........")
49 | model = SRN(args, args.dim_word, args.dim_hidden, vocab)
50 | logging.info("Load pretrained word vectors.........")
51 | pretrained = load_glove(args.glove_pt, vocab['id2word'])
52 | model.word_embeddings.weight.data = torch.Tensor(pretrained)
53 | model = model.to(device)
54 | logging.info(model)
55 | if args.opt == 'adam':
56 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
57 | elif args.opt == 'sgd':
58 | optimizer = optim.SGD(model.parameters(), args.lr, weight_decay=args.weight_decay)
59 | elif args.opt == 'adagrad':
60 | optimizer = optim.Adagrad(model.parameters(), args.lr, weight_decay=args.weight_decay)
61 | else:
62 | raise NotImplementedError
63 | # scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[3], gamma=0.1)
64 |
65 | validate(model, val_loader, device)
66 | meters = MetricLogger(delimiter=" ")
67 | logging.info("Start training........")
68 |     best_model = copy.deepcopy(model.state_dict())
69 |     best_acc, no_update = 0.0, 0  # no_update counts consecutive epochs without improvement
70 |     eps = 0.00001
71 | for epoch in range(args.num_epoch):
72 | model.train()
73 | for iteration, batch in enumerate(train_loader):
74 | iteration = iteration + 1
75 |
76 | question, topic_entity, answer = [x.to(device) for x in batch]
77 | loss, pt_loss = model(question, topic_entity, answer)
78 | optimizer.zero_grad()
79 | loss.backward()
80 | optimizer.step()
81 | meters.update(loss=pt_loss.item())
82 |
83 |             if iteration % max(1, len(train_loader) // 100) == 0:  # max(1, ...) avoids modulo-by-zero on small datasets
84 | logging.info(
85 | meters.delimiter.join(
86 | [
87 | "progress: {progress:.3f}",
88 | "{meters}",
89 | "lr: {lr:.6f}",
90 | ]
91 | ).format(
92 | progress=epoch + iteration / len(train_loader),
93 | meters=str(meters),
94 | lr=optimizer.param_groups[0]["lr"],
95 | )
96 | )
97 |                 break  # NOTE: ends the epoch after the first logged interval (~1% of the data), so validation runs very frequently
98 |
99 |
100 | acc = validate(model, val_loader, device)
101 | if acc > best_acc + eps:
102 | best_acc = acc
103 | no_update = 0
104 | best_model = copy.deepcopy(model.state_dict())
105 | logging.info("Validation accuracy increased from previous epoch {}".format(acc))
106 | torch.save(model.state_dict(), os.path.join(args.save_dir, '%s-%s-%s-%s.pt'%(args.opt, str(args.lr), str(args.bandwidth), str(epoch))))
107 | elif (acc < best_acc + eps) and (no_update < args.patience):
108 | no_update +=1
109 | logging.info("Validation accuracy decreases to %f from %f, %d more epoch to check"%(acc, best_acc, args.patience-no_update))
110 | elif no_update == args.patience:
111 | logging.info("Model has exceed patience. Saving best model and exiting")
112 | torch.save(best_model, os.path.join(args.save_dir, "best_score_model.pt"))
113 | exit()
114 |
115 | # acc = validate(model, test_loader, device)
116 | # torch.save(model.state_dict(), os.path.join(args.save_dir, '%s-%s-%d-%.2f'%(args.model_name, args.opt, args.lr, acc)))
117 | # scheduler.step()
118 |
119 |
120 | def main():
121 | parser = argparse.ArgumentParser()
122 | # input and output
123 | parser.add_argument('--input_dir', required=True)
124 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs')
125 | parser.add_argument('--glove_pt', required=True)
126 |
127 | # training parameters
128 | parser.add_argument('--lr', default=0.001, type=float)
129 | parser.add_argument('--weight_decay', default=1e-5, type=float)
130 | parser.add_argument('--num_epoch', default=100, type=int)
131 | parser.add_argument('--batch_size', default=16, type=int)
132 | parser.add_argument('--seed', type=int, default=666, help='random seed')
133 | # model hyperparameters
134 | parser.add_argument('--dim_emb', default=300, type=int)
135 | parser.add_argument('--num_rollout_steps', default=3, type=int)
136 | parser.add_argument('--num_rollouts', default=10, type=int)
137 | parser.add_argument('--dim_word', default=300, type=int)
138 | parser.add_argument('--dim_hidden', default=300, type=int)
139 |     parser.add_argument('--bucket_interval', default=3, type=int)
140 |     parser.add_argument('--opt', default='adam', type=str)
141 |     parser.add_argument('--bandwidth', default=50, type=int)
142 |     parser.add_argument('--gamma', default=0.95, type=float)
143 |     parser.add_argument('--eta', default=0.95, type=float)
144 |     parser.add_argument('--beta', default=0, type=float)
145 |     parser.add_argument('--beam_size', default=32, type=int)
146 |     parser.add_argument('--log_name', default='log.txt', type=str)
147 |     parser.add_argument('--model_name', default='model.pt', type=str)
148 |     parser.add_argument('--patience', default=10, type=int)
149 | args = parser.parse_args()
150 |
151 | # make logging.info display into both shell and file
152 | if not os.path.exists(args.save_dir):
153 | os.makedirs(args.save_dir)
154 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, args.log_name))
155 | fileHandler.setFormatter(logFormatter)
156 | rootLogger.addHandler(fileHandler)
157 | # args display
158 | for k, v in vars(args).items():
159 | logging.info(k+':'+str(v))
160 |
161 | # set random seed
162 | torch.manual_seed(args.seed)
163 |
164 | train(args)
165 |
166 |
167 | if __name__ == '__main__':
168 | main()
169 |
--------------------------------------------------------------------------------
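
`train()` above leaves a `MultiStepLR` scheduler commented out. In case it is re-enabled, this self-contained sketch shows the usual wiring (the toy model stands in for SRN; milestones/gamma follow the commented-out line):

```
import torch
import torch.optim as optim

model = torch.nn.Linear(4, 2)                      # stand-in for the SRN model
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3], gamma=0.1)
for epoch in range(6):
    optimizer.step()                               # one epoch of training would go here
    scheduler.step()                               # drops the lr from 1e-3 to 1e-4 at the milestone
    print(epoch, optimizer.param_groups[0]['lr'])
```
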
/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import json
4 | from datetime import date
5 | from collections import defaultdict, Counter
6 | from tqdm import tqdm
7 | def whether_equal(answer, pred):
8 | def truncate_float(x):
9 | # convert answer from '100.0 meters' to '100 meters'
10 | try:
11 | v, *u = x.split()
12 | v = float(v)
13 | if v - int(v) < 1e-5:
14 | v = int(v)
15 | if len(u) == 0:
16 | x = str(v)
17 | else:
18 | x = '{} {}'.format(str(v), ' '.join(u))
19 |         except ValueError:  # leave non-numeric values untouched
20 | pass
21 | return x
22 |
23 | def equal_as_date(x, y):
24 | # check whether x and y are equal as type of date or year
25 | try:
26 | x_split = x.split('-')
27 | y_split = y.split('-')
28 | if len(x_split) == 3:
29 | x = date(int(x_split[0]), int(x_split[1]), int(x_split[2]))
30 | else:
31 | x = int(x)
32 | if len(y_split) == 3:
33 | y = date(int(y_split[0]), int(y_split[1]), int(y_split[2]))
34 | else:
35 | y = int(y)
36 | if isinstance(x, date) and isinstance(y, date):
37 | return x == y
38 | else:
39 | x = x.year if isinstance(x, date) else x
40 | y = y.year if isinstance(y, date) else y
41 | return x == y
42 |         except ValueError:  # not parsable as a date or year
43 | return False
44 |
45 | answer = truncate_float(answer)
46 | pred = truncate_float(pred)
47 | if equal_as_date(answer, pred):
48 | return True
49 | else:
50 | return answer == pred
51 |
52 |
53 | def load(f):
54 | data = []
55 | for line in f:
56 | data.append(json.loads(line.strip()))
57 | return data
58 | def main():
59 | gt_folder, pred_fn = sys.argv[1], sys.argv[2]
60 |
61 | gt_fn = os.path.join(gt_folder, 'test_answer.json')
62 | gt = json.load(open(gt_fn))
63 | pred = [x.strip() for x in open(pred_fn).readlines()] # one prediction per line
64 | train_set = json.load(open(os.path.join(gt_folder, 'train.json')))
65 | train_answer_set = set(x['answer'] for x in train_set)
66 |
67 | labels = ['overall', 'multihop', 'qualifier', 'comparison', 'logical', 'count', 'verify', 'zero-shot']
68 | total = {k:0 for k in labels}
69 | correct = {k:0 for k in labels}
70 | for i in tqdm(range(len(pred))):
71 | cur_labels = ['overall']
72 | functions = [f['function'] for f in gt[i]['program']]
73 |
74 | for f in functions:
75 | if f in {'Relate'} or f.startswith('Filter'):
76 | cur_labels.append('multihop')
77 | break
78 | for f in functions:
79 | if f in {'QFilterStr', 'QFilterNum', 'QFilterYear', 'QFilterDate', 'QueryAttrUnderCondition', 'QueryAttrQualifier', 'QueryRelationQualifier'}:
80 | cur_labels.append('qualifier')
81 | break
82 | for f in functions:
83 | if f in {'SelectBetween','SelectAmong'}:
84 | cur_labels.append('comparison')
85 | break
86 | for f in functions:
87 | if f in {'And', 'Or'}:
88 | cur_labels.append('logical')
89 | break
90 | for f in functions:
91 | if f in {'Count'}:
92 | cur_labels.append('count')
93 | break
94 | for f in functions:
95 | if f in {'VerifyStr','VerifyNum','VerifyYear','VerifyDate'}:
96 | cur_labels.append('verify')
97 | break
98 |
99 | answer = gt[i]['answer']
100 | if answer not in train_answer_set:
101 | cur_labels.append('zero-shot')
102 |
103 | if whether_equal(answer, pred[i]):
104 | for k in cur_labels:
105 | correct[k] += 1
106 | else:
107 | pass
108 | for k in cur_labels:
109 | total[k] += 1
110 |
111 | for k in labels:
112 |         print('{}: {:.2f}% ({}/{})'.format(k, correct[k] / max(total[k], 1) * 100, correct[k], total[k]))
113 | if len(pred) < len(gt):
114 | print('WARNING: there are only {} predictions (need {})'.format(len(pred), len(gt)))
115 |
116 |
117 | if __name__ == '__main__':
118 | main()
119 |
--------------------------------------------------------------------------------
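
`evaluate.py` is run as `python evaluate.py <gt_folder> <pred_file>`, with one prediction per line in `<pred_file>`. The normalization done by `whether_equal` is easiest to see by example (a sketch; the values are made up):

```
from evaluate import whether_equal

print(whether_equal('100.0 metre', '100 metre'))   # True: '100.0' is truncated to '100'
print(whether_equal('1960-02-01', '1960'))         # True: a full date matches its bare year
print(whether_equal('1960-02-01', '1960-03-01'))   # False: both parse as dates and differ
```
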
/utils/BiGRU.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class GRU(nn.Module):
5 |
6 | def __init__(self, dim_word, dim_h, num_layers, dropout):
7 | super().__init__()
8 | self.encoder = nn.GRU(input_size=dim_word,
9 | hidden_size=dim_h,
10 | num_layers=num_layers,
11 | dropout=dropout,
12 | batch_first=True,
13 | bidirectional=False)
14 |
15 | def forward_one_step(self, input, last_h):
16 | """
17 | Args:
18 | - input (bsz, 1, w_dim)
19 | - last_h (num_layers, bsz, h_dim)
20 | """
21 | hidden, new_h = self.encoder(input, last_h)
22 | return hidden, new_h # (bsz, 1, h_dim), (num_layers, bsz, h_dim)
23 |
24 |
25 | def generate_sequence(self, word_lookup_func, h_0, classifier, vocab, max_step, early_stop=True):
26 | bsz = h_0.size(1)
27 | device = h_0.device
28 |         start_id, end_id, pad_id = vocab['<START>'], vocab['<END>'], vocab['<PAD>']
29 |
30 | latest = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ]
31 | results = [latest]
32 | last_h = h_0
33 |         finished = torch.zeros((bsz,)).bool().to(device) # record whether <END> is produced
34 |         for i in range(max_step-1): # exclude <START>
35 | word_emb = word_lookup_func(latest).unsqueeze(1) # [bsz, 1, dim_w]
36 | word_h, last_h = self.forward_one_step(word_emb, last_h) # [bsz, 1, dim_h]
37 |
38 | logit = classifier(word_h).squeeze(1) # [bsz, num_func]
39 | latest = torch.argmax(logit, dim=1).long() # [bsz, ]
40 |             latest[finished] = pad_id # set to <PAD> after <END>
41 | results.append(latest)
42 |
43 | finished = finished | latest.eq(end_id).bool()
44 | if early_stop and finished.sum().item() == bsz:
45 | # print('finished at step {}'.format(i))
46 | break
47 | results = torch.stack(results, dim=1) # [bsz, max_len']
48 | return results
49 |
50 |
51 | def forward(self, input, length, h_0=None):
52 | """
53 | Args:
54 | - input (bsz, len, w_dim)
55 | - length (bsz, )
56 | - h_0 (num_layers, bsz, h_dim)
57 | Return:
58 | - hidden (bsz, len, dim) : hidden state of each word
59 | - output (bsz, dim) : sentence embedding
60 | """
61 | bsz, max_len = input.size(0), input.size(1)
62 | sorted_seq_lengths, indices = torch.sort(length, descending=True)
63 | _, desorted_indices = torch.sort(indices, descending=False)
64 | input = input[indices]
65 |         packed_input = nn.utils.rnn.pack_padded_sequence(input, sorted_seq_lengths.cpu(), batch_first=True)  # lengths must live on CPU in recent PyTorch
66 | if h_0 is None:
67 | hidden, h_n = self.encoder(packed_input)
68 | else:
69 | h_0 = h_0[:, indices]
70 | hidden, h_n = self.encoder(packed_input, h_0)
71 | # h_n is (num_layers, bsz, h_dim)
72 | hidden = nn.utils.rnn.pad_packed_sequence(hidden, batch_first=True, total_length=max_len)[0] # (bsz, max_len, h_dim)
73 |
74 | output = h_n[-1, :, :] # (bsz, h_dim), take the last layer's state
75 |
76 | # recover order
77 | hidden = hidden[desorted_indices]
78 | output = output[desorted_indices]
79 | h_n = h_n[:, desorted_indices]
80 | return hidden, output, h_n
81 |
82 |
83 |
84 | class BiGRU(nn.Module):
85 |
86 | def __init__(self, dim_word, dim_h, num_layers, dropout):
87 | super().__init__()
88 | self.encoder = nn.GRU(input_size=dim_word,
89 | hidden_size=dim_h//2,
90 | num_layers=num_layers,
91 | dropout=dropout,
92 | batch_first=True,
93 | bidirectional=True)
94 |
95 | def forward(self, input, length):
96 | """
97 | Args:
98 | - input (bsz, len, w_dim)
99 | - length (bsz, )
100 | Return:
101 | - hidden (bsz, len, dim) : hidden state of each word
102 | - output (bsz, dim) : sentence embedding
103 | - h_n (num_layers * 2, bsz, dim//2)
104 | """
105 | bsz, max_len = input.size(0), input.size(1)
106 | sorted_seq_lengths, indices = torch.sort(length, descending=True)
107 | _, desorted_indices = torch.sort(indices, descending=False)
108 | input = input[indices]
109 |         packed_input = nn.utils.rnn.pack_padded_sequence(input, sorted_seq_lengths.cpu(), batch_first=True)  # lengths must live on CPU in recent PyTorch
110 | hidden, h_n = self.encoder(packed_input)
111 | # h_n is (num_layers * num_directions, bsz, h_dim//2)
112 | hidden = nn.utils.rnn.pad_packed_sequence(hidden, batch_first=True, total_length=max_len)[0] # (bsz, max_len, h_dim)
113 |
114 | output = h_n[-2:, :, :] # (2, bsz, h_dim//2), take the last layer's state
115 | output = output.permute(1, 0, 2).contiguous().view(bsz, -1) # (bsz, h_dim), merge forward and backward h_n
116 |
117 | # recover order
118 | hidden = hidden[desorted_indices]
119 | output = output[desorted_indices]
120 | h_n = h_n[:, desorted_indices]
121 | return hidden, output, h_n
122 |
--------------------------------------------------------------------------------
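
A quick smoke test of the `BiGRU` interface (a sketch; the sizes are arbitrary, and `dim_h` must be even because it is split across the two directions):

```
import torch
from utils.BiGRU import BiGRU

enc = BiGRU(dim_word=300, dim_h=256, num_layers=2, dropout=0.2)
x = torch.randn(4, 10, 300)               # (bsz, len, w_dim), padded to length 10
lengths = torch.tensor([10, 7, 5, 3])     # true lengths; any order is fine
hidden, sent, h_n = enc(x, lengths)
print(hidden.shape)                       # torch.Size([4, 10, 256])
print(sent.shape)                         # torch.Size([4, 256])
print(h_n.shape)                          # torch.Size([4, 4, 128])
```
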
/utils/misc.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, Counter, deque
2 | import torch
3 | import json
4 | import pickle
5 | import numpy as np
6 | import torch.nn as nn
7 | import random
8 | import os
9 | import time
10 | ######################################################
11 | ##################### used in SRN ####################
12 | START_RELATION = 'START_RELATION'
13 | NO_OP_RELATION = 'NO_OP_RELATION'
14 | NO_OP_ENTITY = 'NO_OP_ENTITY'
15 | DUMMY_RELATION = 'DUMMY_RELATION'
16 | DUMMY_ENTITY = 'DUMMY_ENTITY'
17 |
18 | DUMMY_RELATION_ID = 0
19 | START_RELATION_ID = 1
20 | NO_OP_RELATION_ID = 2
21 | DUMMY_ENTITY_ID = 0
22 | NO_OP_ENTITY_ID = 1
23 |
24 | EPSILON = float(np.finfo(float).eps)
25 | HUGE_INT = 1e31
26 |
27 | def format_path(path_trace, id2entity, id2relation):
28 | def get_most_recent_relation(j):
29 | relation_id = int(path_trace[j][0])
30 | if relation_id == NO_OP_RELATION_ID:
31 | return ''
32 | else:
33 | return id2relation[relation_id]
34 |
35 | def get_most_recent_entity(j):
36 | return id2entity[int(path_trace[j][1])]
37 |
38 | path_str = get_most_recent_entity(0)
39 | for j in range(1, len(path_trace)):
40 | rel = get_most_recent_relation(j)
41 | if not rel.endswith('_inv'):
42 | path_str += ' -{}-> '.format(rel)
43 | else:
44 | path_str += ' <-{}- '.format(rel[:-4])
45 | path_str += get_most_recent_entity(j)
46 | return path_str
47 |
48 | def pad_and_cat(a, padding_value, padding_dim=1):
49 | max_dim_size = max([x.size()[padding_dim] for x in a])
50 | padded_a = []
51 | for x in a:
52 | if x.size()[padding_dim] < max_dim_size:
53 |             res_len = max_dim_size - x.size()[padding_dim]
54 | pad = nn.ConstantPad1d((0, res_len), padding_value)
55 | padded_a.append(pad(x))
56 | else:
57 | padded_a.append(x)
58 | return torch.cat(padded_a, dim=0)
59 |
60 | def safe_log(x):
61 | return torch.log(x + EPSILON)
62 |
63 | def entropy(p):
64 | return torch.sum(- p * safe_log(p), 1)
65 |
66 | def init_word2id():
67 | return {
68 |         '<PAD>': 0,
69 |         '<UNK>': 1,
70 |         'E_S': 2,
71 | }
72 | def init_entity2id():
73 | return {
74 | DUMMY_ENTITY: DUMMY_ENTITY_ID,
75 | NO_OP_ENTITY: NO_OP_ENTITY_ID
76 | }
77 | def init_relation2id():
78 | return {
79 | DUMMY_RELATION: DUMMY_RELATION_ID,
80 | START_RELATION: START_RELATION_ID,
81 | NO_OP_RELATION: NO_OP_RELATION_ID
82 | }
83 |
84 | def add_item_to_x2id(item, x2id):
85 |     if item not in x2id:
86 | x2id[item] = len(x2id)
87 |
88 | def tile_along_beam(v, beam_size, dim=0):
89 | """
90 | Tile a tensor along a specified dimension for the specified beam size.
91 | :param v: Input tensor.
92 | :param beam_size: Beam size.
93 | """
94 | if dim == -1:
95 | dim = len(v.size()) - 1
96 | v = v.unsqueeze(dim + 1)
97 | v = torch.cat([v] * beam_size, dim=dim+1)
98 | new_size = []
99 | for i, d in enumerate(v.size()):
100 | if i == dim + 1:
101 | new_size[-1] *= d
102 | else:
103 | new_size.append(d)
104 | return v.view(new_size)
105 | ##################### used in SRN ####################
106 | ######################################################
107 |
108 |
109 |
110 | def init_vocab():
111 | return {
112 |         '<PAD>': 0,
113 |         '<UNK>': 1,
114 |         '<START>': 2,
115 |         '<END>': 3
116 | }
117 |
118 | def invert_dict(d):
119 | return {v: k for k, v in d.items()}
120 |
121 | def load_glove(glove_pt, idx_to_token):
122 | glove = pickle.load(open(glove_pt, 'rb'))
123 | dim = len(glove['the'])
124 | matrix = []
125 | for i in range(len(idx_to_token)):
126 | token = idx_to_token[i]
127 | tokens = token.split()
128 | if len(tokens) > 1:
129 | v = np.zeros((dim,))
130 | for token in tokens:
131 | v = v + glove.get(token, glove['the'])
132 | v = v / len(tokens)
133 | else:
134 | v = glove.get(token, glove['the'])
135 | matrix.append(v)
136 | matrix = np.asarray(matrix)
137 | return matrix
138 |
139 |
140 | class SmoothedValue(object):
141 | """Track a series of values and provide access to smoothed values over a
142 | window or the global series average.
143 | """
144 |
145 | def __init__(self, window_size=20):
146 | self.deque = deque(maxlen=window_size)
147 | self.series = []
148 | self.total = 0.0
149 | self.count = 0
150 |
151 | def update(self, value):
152 | self.deque.append(value)
153 | self.series.append(value)
154 | self.count += 1
155 | self.total += value
156 |
157 | @property
158 | def median(self):
159 | d = torch.tensor(list(self.deque))
160 | return d.median().item()
161 |
162 | @property
163 | def avg(self):
164 | d = torch.tensor(list(self.deque))
165 | return d.mean().item()
166 |
167 | @property
168 | def global_avg(self):
169 | return self.total / self.count
170 |
171 |
172 | class MetricLogger(object):
173 | def __init__(self, delimiter="\t"):
174 | self.meters = defaultdict(SmoothedValue)
175 | self.delimiter = delimiter
176 |
177 | def update(self, **kwargs):
178 | for k, v in kwargs.items():
179 | if isinstance(v, torch.Tensor):
180 | v = v.item()
181 | assert isinstance(v, (float, int))
182 | self.meters[k].update(v)
183 |
184 | def __getattr__(self, attr):
185 | if attr in self.meters:
186 | return self.meters[attr]
187 | if attr in self.__dict__:
188 | return self.__dict__[attr]
189 | raise AttributeError("'{}' object has no attribute '{}'".format(
190 | type(self).__name__, attr))
191 |
192 | def __str__(self):
193 | loss_str = []
194 | for name, meter in self.meters.items():
195 | loss_str.append(
196 | "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
197 | )
198 | return self.delimiter.join(loss_str)
199 |
200 |
201 | def seed_everything(seed=1029):
202 |     '''
203 |     Set the random seed for the whole environment
204 |     (Python's random, PYTHONHASHSEED, NumPy, and PyTorch/CUDA).
205 |     :param seed: seed value
206 |     :return:
207 |     '''
208 | random.seed(seed)
209 | os.environ['PYTHONHASHSEED'] = str(seed)
210 | np.random.seed(seed)
211 | torch.manual_seed(seed)
212 | torch.cuda.manual_seed(seed)
213 | torch.cuda.manual_seed_all(seed)
214 | # some cudnn methods can be random even after fixing the seed
215 | # unless you tell it to be deterministic
216 | torch.backends.cudnn.deterministic = True
217 |
218 |
219 | class ProgressBar(object):
220 | '''
221 | custom progress bar
222 | Example:
223 | >>> pbar = ProgressBar(n_total=30,desc='training')
224 | >>> step = 2
225 | >>> pbar(step=step)
226 | '''
227 |     def __init__(self, n_total, width=30, desc='Training'):
228 | self.width = width
229 | self.n_total = n_total
230 | self.start_time = time.time()
231 | self.desc = desc
232 |
233 |     def __call__(self, step, info=None):  # avoid a mutable default argument
234 | now = time.time()
235 | current = step + 1
236 | recv_per = current / self.n_total
237 | bar = f'[{self.desc}] {current}/{self.n_total} ['
238 | if recv_per >= 1:
239 | recv_per = 1
240 | prog_width = int(self.width * recv_per)
241 | if prog_width > 0:
242 | bar += '=' * (prog_width - 1)
243 |             if current < self.n_total:
244 | bar += ">"
245 | else:
246 | bar += '='
247 | bar += '.' * (self.width - prog_width)
248 | bar += ']'
249 | show_bar = f"\r{bar}"
250 | time_per_unit = (now - self.start_time) / current
251 | if current < self.n_total:
252 | eta = time_per_unit * (self.n_total - current)
253 | if eta > 3600:
254 | eta_format = ('%d:%02d:%02d' %
255 | (eta // 3600, (eta % 3600) // 60, eta % 60))
256 | elif eta > 60:
257 | eta_format = '%d:%02d' % (eta // 60, eta % 60)
258 | else:
259 | eta_format = '%ds' % eta
260 | time_info = f' - ETA: {eta_format}'
261 | else:
262 | if time_per_unit >= 1:
263 | time_info = f' {time_per_unit:.1f}s/step'
264 | elif time_per_unit >= 1e-3:
265 | time_info = f' {time_per_unit * 1e3:.1f}ms/step'
266 | else:
267 | time_info = f' {time_per_unit * 1e6:.1f}us/step'
268 |
269 | show_bar += time_info
270 |         if info:
271 | show_info = f'{show_bar} ' + \
272 | "-".join([f' {key}: {value:.4f} ' for key, value in info.items()])
273 | print(show_info, end='')
274 | else:
275 | print(show_bar, end='')
--------------------------------------------------------------------------------
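
Of the SRN helpers above, `tile_along_beam` is the least obvious: it repeats each slice of a tensor `beam_size` times while keeping the copies grouped, which is what beam search needs when expanding a batch. A minimal check (a sketch):

```
import torch
from utils.misc import tile_along_beam

v = torch.tensor([[1, 2], [3, 4]])
print(tile_along_beam(v, beam_size=3, dim=0))
# tensor([[1, 2],
#         [1, 2],
#         [1, 2],
#         [3, 4],
#         [3, 4],
#         [3, 4]])
```
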
/utils/pickle_glove.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import argparse
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | def main():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--input', required=True)
9 | parser.add_argument('--output', required=True)
10 | args = parser.parse_args()
11 |
12 | res = {}
13 | for line in tqdm(open(args.input, encoding="latin-1")):
14 | word, *vec = line.split()
15 | try:
16 | vec = np.asarray(list(map(float, vec)))
17 | res[word] = vec
18 |         except ValueError:
19 |             print('bad line for word: {}'.format(word))
20 |
21 | with open(args.output, 'wb') as f:
22 | pickle.dump(res, f)
23 |
24 |
25 | if __name__ == '__main__':
26 | main()
27 |
--------------------------------------------------------------------------------
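
A quick sanity check of the produced pickle (a sketch; the path is a placeholder):

```
import pickle

with open('<path/to/glove.pt>', 'rb') as f:
    glove = pickle.load(f)
print(len(glove), glove['the'].shape)   # vocabulary size and (300,)
```
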
/utils/value_class.py:
--------------------------------------------------------------------------------
1 | def comp(a, b, op):
2 | """
3 | Args:
4 | - a (ValueClass): attribute value of a certain entity
5 | - b (ValueClass): comparison target
6 |     - op: =/</>/!=
7 | Example:
8 | a is someone's birthday, 1960-02-01, b is 1960, op is '=', then return True
9 | """
10 | if b.isTime():
11 | # Note: for time, 'a=b' actually means a in b, 'a!=b' means a not in b
12 | if op == '=':
13 | return b.contains(a)
14 | elif op == '!=':
15 | return not b.contains(a)
16 | if op == '=':
17 | return a == b
18 | elif op == '<':
19 | return a < b
20 | elif op == '>':
21 | return a > b
22 | elif op == '!=':
23 | return a != b
24 |
25 | class ValueClass():
26 | def __init__(self, type, value, unit=None):
27 | """
28 | When type is
29 | - string, value is a str
30 | - quantity, value is a number and unit is required
31 |         - year, value is an int
32 | - date, value is a date object
33 | """
34 | self.type = type
35 | self.value = value
36 | self.unit = unit
37 |
38 | def isTime(self):
39 | return self.type in {'year', 'date'}
40 |
41 | def can_compare(self, other):
42 | if self.type == 'string':
43 | return other.type == 'string'
44 | elif self.type == 'quantity':
45 |             # NOTE: two quantities can be compared only when they have the same unit
46 | return other.type == 'quantity' and other.unit == self.unit
47 | else:
48 | # year can compare with date
49 | return other.type == 'year' or other.type == 'date'
50 |
51 | def contains(self, other):
52 | """
53 | check whether self contains other, which is different from __eq__ and the result is asymmetric
54 | used for conditions like whether 2001-01-01 in 2001, or whether 2001 in 2001-01-01
55 | """
56 | if self.type == 'year': # year can contain year and date
57 | other_value = other.value if other.type == 'year' else other.value.year
58 | return self.value == other_value
59 | elif self.type == 'date': # date can only contain date
60 | return other.type == 'date' and self.value == other.value
61 | else:
62 | raise Exception('not supported type: %s' % self.type)
63 |
64 |
65 | def __eq__(self, other):
66 | """
67 |         2001 and 2001-01-01 are not equal
68 | """
69 | assert self.can_compare(other)
70 | return self.type == other.type and self.value == other.value
71 |
72 | def __lt__(self, other):
73 | """
74 | Comparison between a year and a date will convert them both to year
75 | """
76 | assert self.can_compare(other)
77 | if self.type == 'string':
78 |             raise Exception('tried to compare two strings')
79 | elif self.type == 'quantity':
80 | return self.value < other.value
81 | elif self.type == 'year':
82 | other_value = other.value if other.type == 'year' else other.value.year
83 | return self.value < other_value
84 | elif self.type == 'date':
85 | if other.type == 'year':
86 | return self.value.year < other.value
87 | else:
88 | return self.value < other.value
89 |
90 | def __gt__(self, other):
91 | assert self.can_compare(other)
92 | if self.type == 'string':
93 |             raise Exception('tried to compare two strings')
94 | elif self.type == 'quantity':
95 | return self.value > other.value
96 | elif self.type == 'year':
97 | other_value = other.value if other.type == 'year' else other.value.year
98 | return self.value > other_value
99 | elif self.type == 'date':
100 | if other.type == 'year':
101 | return self.value.year > other.value
102 | else:
103 | return self.value > other.value
104 |
105 | def __str__(self):
106 | if self.type == 'string':
107 | return self.value
108 | elif self.type == 'quantity':
109 | if self.value - int(self.value) < 1e-5:
110 | v = int(self.value)
111 | else:
112 | v = self.value
113 | return '{} {}'.format(v, self.unit) if self.unit != '1' else str(v)
114 | elif self.type == 'year':
115 | return str(self.value)
116 | elif self.type == 'date':
117 | return self.value.isoformat()
118 |
--------------------------------------------------------------------------------
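
To make the containment semantics of `comp` concrete, a small sketch mirroring its docstring example (someone's birthday against a bare year):

```
from datetime import date
from utils.value_class import ValueClass, comp

birthday = ValueClass('date', date(1960, 2, 1))
year = ValueClass('year', 1960)
print(comp(birthday, year, '='))   # True: '=' against a time value means containment
print(birthday == year)            # False: __eq__ requires identical types
```
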