├── .gitignore ├── LICENSE ├── README.md ├── constants.py ├── dataset.py ├── evaluate.py ├── params.json ├── preprocess.py ├── qanet ├── __init__.py ├── context_query_attention.py ├── depthwise_separable_conv.py ├── embedding_encoder.py ├── encoder_block.py ├── highway.py ├── input_embedding.py ├── layer_norm.py ├── model_encoder.py ├── output.py ├── position_encoding.py ├── qanet.py └── self_attention.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | .vscode/* 3 | data/* 4 | tmp/* 5 | log/* 6 | analysis/* 7 | runs/* 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # celery beat schedule file 87 | celerybeat-schedule 88 | 89 | # SageMath parsed files 90 | *.sage.py 91 | 92 | # Environments 93 | .env 94 | .venv 95 | env/ 96 | venv/ 97 | ENV/ 98 | env.bak/ 99 | venv.bak/ 100 | 101 | # Spyder project settings 102 | .spyderproject 103 | .spyproject 104 | 105 | # Rope project settings 106 | .ropeproject 107 | 108 | # mkdocs documentation 109 | /site 110 | 111 | # mypy 112 | .mypy_cache/ 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Hackiey 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Introduction 2 | A PyTorch implementation of [QANet](https://arxiv.org/pdf/1804.09541.pdf). 3 | This repository is based on [NLPLearn/QANet](https://github.com/NLPLearn/QANet) and [marquezo/qanet-impl](https://github.com/marquezo/qanet-impl). 4 | 5 | It reaches **em: 70.155** and **f1: 79.432** after 22 epochs (2730 batches per epoch) with EMA. 6 | 7 | ### Requirements 8 | - PyTorch >= 0.4.0 9 | - [torcheras](https://github.com/hackiey/torcheras) 10 | - spacy 11 | - tqdm 12 | 13 | ### Usage 14 | #### Preprocess 15 | ``` 16 | $ mkdir data 17 | $ python preprocess.py 18 | ``` 19 | 20 | #### Train 21 | ``` 22 | $ mkdir log 23 | $ mkdir log/qanet 24 | $ python train.py 'some description' 25 | ``` 26 | 27 | #### Evaluate 28 | First set the log folder and epoch number in evaluate.py, then run the script. 29 | ``` 30 | $ python evaluate.py 31 | ``` 32 | 33 | ### Known issues 34 | - pickle.dump raises an "OSError: [Errno 22] Invalid argument" error on OS X when saving the "train context char" data; it works fine on Ubuntu 16.04. -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu:0') 4 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class QANetDataset(Dataset): 8 | 9 | def __init__(self, data_dir, data_type): 10 | 11 | self.context_idxs = np.array( 12 | pickle.load(open(os.path.join(data_dir, data_type+'_context_idxs.pkl'), 'rb')), dtype=np.int64) 13 | self.context_char_idxs = np.array( 14 | pickle.load(open(os.path.join(data_dir, data_type+'_context_char_idxs.pkl'), 'rb')), dtype=np.int64) 15 | self.ques_idxs = np.array( 16 | pickle.load(open(os.path.join(data_dir, data_type+'_ques_idxs.pkl'), 'rb')), dtype=np.int64) 17 | self.ques_char_idxs = np.array( 18 | pickle.load(open(os.path.join(data_dir, data_type+'_ques_char_idxs.pkl'), 'rb')), dtype=np.int64) 19 | self.y = np.array( 20 | pickle.load(open(os.path.join(data_dir, data_type+'_y.pkl'), 'rb')), dtype=np.int64) 21 | self.ids = np.array( 22 | pickle.load(open(os.path.join(data_dir, data_type+'_ids.pkl'), 'rb')), dtype=np.int64) 23 | 24 | def __getitem__(self, index): 25 | context_idxs = self.context_idxs[index] 26 | ques_idxs = self.ques_idxs[index] 27 | 28 | c_mask = np.array(context_idxs > 0, dtype=np.float32) 29 | q_mask = np.array(ques_idxs > 0, dtype=np.float32) 30 | 31 | return (context_idxs, self.context_char_idxs[index], 32 | ques_idxs, self.ques_char_idxs[index], c_mask, q_mask), \ 33 | (self.y[index], self.ids[index]) 34 | 35 | def __len__(self): 36 | return len(self.context_idxs) -------------------------------------------------------------------------------- /evaluate.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import pickle 5 | import numpy as np 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | 9 | import torcheras 10 | 11 | from dataset import QANetDataset 12 | from constants import device 13 | 14 | from qanet.qanet import QANet 15 | from utils import convert_tokens, evaluate 16 | 17 | def variable_data(sample_batched, device): 18 | x = sample_batched[0] 19 | y = sample_batched[1] 20 | if type(x) is list or type(x) is tuple: 21 | for i, _x in enumerate(x): 22 | x[i] = x[i].to(device) 23 | else: 24 | x = x.to(device) 25 | if type(y) is list or type(y) is tuple: 26 | for i, _y in enumerate(y): 27 | y[i] = y[i].to(device) 28 | else: 29 | y = y.to(device) 30 | 31 | return x, y 32 | 33 | def evaluate_scores(y_true, y_pred, test_eval): 34 | qa_id = y_true[1] 35 | c_mask, q_mask = y_pred[2:] 36 | 37 | y_p1 = F.softmax(y_pred[0], dim=-1) 38 | y_p2 = F.softmax(y_pred[1], dim=-1) 39 | 40 | p1 = [] 41 | p2 = [] 42 | p_matrix = torch.bmm(y_p1.unsqueeze(2), y_p2.unsqueeze(1)) 43 | for i in range(p_matrix.shape[0]): 44 | p = torch.triu(p_matrix[i]) 45 | indexes = torch.argmax(p).item() 46 | p1.append(indexes // p.shape[0]) 47 | p2.append(indexes % p.shape[0]) 48 | 49 | answer_dict, _ = convert_tokens( 50 | test_eval, qa_id.tolist(), p1, p2) 51 | metrics = evaluate(test_eval, answer_dict) 52 | 53 | return metrics 54 | 55 | 56 | def evaluate_model(params, dtype='test', model_folder='', model_epoch=''): 57 | test_dataset = QANetDataset('data', dtype) 58 | test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True) 59 | 60 | test_eval = pickle.load(open('data/' + dtype + '_eval.pkl', 'rb')) 61 | 62 | word_emb_mat = np.array(pickle.load(open(os.path.join(params['target_dir'], 'word_emb_mat.pkl'), 'rb')), 63 | dtype=np.float32) 64 | char_emb_mat = np.array(pickle.load(open(os.path.join(params['target_dir'], 'char_emb_mat.pkl'), 'rb')), 65 | dtype=np.float32) 66 | 67 | qanet = QANet(params, word_emb_mat, char_emb_mat).to(device) 68 | qanet = torcheras.Model(qanet, 'log/qanet') 69 | qanet.load_model(model_folder, epoch=model_epoch, ema=True) 70 | qanet = qanet.model 71 | qanet.eval() 72 | 73 | all_scores = {'em': 0, 'f1': 0} 74 | with torch.no_grad(): 75 | for i_batch, sample_batched in enumerate(test_dataloader): 76 | x, y_true = variable_data(sample_batched, device) 77 | y_pred = qanet(x) 78 | metrics = evaluate_scores(y_true, y_pred, test_eval) 79 | print(metrics) 80 | all_scores['em'] += metrics['exact_match'] 81 | all_scores['f1'] += metrics['f1'] 82 | 83 | print('em', all_scores['em'] / i_batch, 'f1', all_scores['f1'] / i_batch) 84 | 85 | if __name__ == '__main__': 86 | params = json.load(open('params.json', 'r')) 87 | 88 | model_folder = '2018_7_24_13_45_8_514568' 89 | model_epoch = 25 90 | 91 | evaluate_model(params, dtype='test', model_folder=model_folder, model_epoch=model_epoch) 92 | -------------------------------------------------------------------------------- /params.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_epochs": 200, 3 | "batch_size":32, 4 | "learning_rate": 1e-3, 5 | "beta1": 0.8, 6 | "beta2": 0.999, 7 | "weight_decay": 3e-7, 8 | 9 | "word_embed_dim": 300, 10 | 11 | "char_dim": 64, 12 | "char_embed_n_filters": 128, 13 | "char_embed_kernel_size": 5, 14 | "char_embed_pad": 2, 15 | 16 | "highway_n_layers": 2, 17 | 18 | "hidden_size": 128, 19 | 20 | 
"embed_encoder_resize_kernel_size": 5, 21 | "embed_encoder_resize_pad": 3, 22 | 23 | "embed_encoder_n_blocks": 1, 24 | "embed_encoder_n_conv": 4, 25 | "embed_encoder_kernel_size": 7, 26 | "embed_encoder_pad": 3, 27 | "embed_encoder_conv_type": "depthwise_separable", 28 | "embed_encoder_with_self_attn": false, 29 | "embed_encoder_n_heads": 1, 30 | 31 | "model_encoder_n_blocks": 7, 32 | "model_encoder_n_conv": 2, 33 | "model_encoder_kernel_size": 5, 34 | "model_encoder_pad": 2, 35 | "model_encoder_conv_type": "depthwise_separable", 36 | "model_encoder_with_self_attn": false, 37 | "model_encoder_n_heads": 1, 38 | 39 | "data_dir": "../SQuAD/data", 40 | "target_dir": "data", 41 | "train_file": "train-v1.1.json", 42 | "dev_file": "dev-v1.1.json", 43 | "test_file": "dev-v1.1.json", 44 | 45 | "train_record_file": "train_record.pkl", 46 | "dev_record_file": "dev_record.pkl", 47 | 48 | "glove_dir": "../SQuAD/data/glove", 49 | "word_embedding_file": "glove.840B.300d.txt", 50 | "glove_word_size": 2200000, 51 | "glove_dim": 300, 52 | 53 | "para_limit": 400, 54 | "ques_limit": 50, 55 | "ans_limit": 30, 56 | "char_limit": 16 57 | } 58 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | import spacy 6 | import pickle 7 | 8 | # from 9 | from collections import Counter 10 | from tqdm import tqdm 11 | 12 | nlp = spacy.blank("en") 13 | 14 | 15 | def word_tokenize(sent): 16 | doc = nlp(sent) 17 | return [token.text for token in doc] 18 | 19 | 20 | def convert_idx(text, tokens): 21 | current = 0 22 | spans = [] 23 | for token in tokens: 24 | current = text.find(token, current) 25 | if current < 0: 26 | print("Token {} cannot be found".format(token)) 27 | raise Exception() 28 | spans.append((current, current + len(token))) 29 | current += len(token) 30 | return spans 31 | 32 | 33 | def process_file(filename, data_type, word_counter, char_counter): 34 | print("Generating {} examples...".format(data_type)) 35 | examples = [] 36 | eval_examples = {} 37 | total = 0 38 | with open(filename, "r") as fh: 39 | source = json.load(fh) 40 | for article in tqdm(source["data"]): 41 | for para in article["paragraphs"]: 42 | context = para["context"].replace( 43 | "''", '" ').replace("``", '" ') 44 | # tokenize 45 | context_tokens = word_tokenize(context) 46 | context_chars = [list(token) for token in context_tokens] 47 | spans = convert_idx(context, context_tokens) 48 | for token in context_tokens: 49 | word_counter[token] += len(para["qas"]) 50 | for char in token: 51 | char_counter[char] += len(para["qas"]) 52 | for qa in para["qas"]: 53 | total += 1 54 | ques = qa["question"].replace( 55 | "''", '" ').replace("``", '" ') 56 | # tokenize 57 | ques_tokens = word_tokenize(ques) 58 | ques_chars = [list(token) for token in ques_tokens] 59 | for token in ques_tokens: 60 | word_counter[token] += 1 61 | for char in token: 62 | char_counter[char] += 1 63 | y1s, y2s = [], [] 64 | answer_texts = [] 65 | for answer in qa["answers"]: 66 | answer_text = answer["text"] 67 | answer_start = answer['answer_start'] 68 | answer_end = answer_start + len(answer_text) 69 | answer_texts.append(answer_text) 70 | answer_span = [] 71 | for idx, span in enumerate(spans): 72 | if not (answer_end <= span[0] or answer_start >= span[1]): 73 | answer_span.append(idx) 74 | y1, y2 = answer_span[0], answer_span[-1] 75 | y1s.append(y1) 76 | y2s.append(y2) 77 | 
example = {"context_tokens": context_tokens, "context_chars": context_chars, "ques_tokens": ques_tokens, 78 | "ques_chars": ques_chars, "y1s": y1s, "y2s": y2s, "id": total} 79 | examples.append(example) 80 | eval_examples[str(total)] = { 81 | "context": context, "spans": spans, "answers": answer_texts, "uuid": qa["id"]} 82 | print("{} questions in total".format(len(examples))) 83 | return examples, eval_examples 84 | 85 | 86 | def get_embedding(counter, data_type, limit=-1, emb_file=None, size=None, vec_size=None): 87 | print("Generating {} embedding...".format(data_type)) 88 | embedding_dict = {} 89 | filtered_elements = [k for k, v in counter.items() if v > limit] 90 | if emb_file is not None: 91 | assert size is not None 92 | assert vec_size is not None 93 | with open(emb_file, "r", encoding="utf-8") as fh: 94 | for line in tqdm(fh, total=size): 95 | array = line.split() 96 | word = "".join(array[0:-vec_size]) 97 | vector = list(map(float, array[-vec_size:])) 98 | if word in counter and counter[word] > limit: 99 | embedding_dict[word] = vector 100 | print("{} / {} tokens have corresponding {} embedding vector".format( 101 | len(embedding_dict), len(filtered_elements), data_type)) 102 | else: 103 | assert vec_size is not None 104 | for token in filtered_elements: 105 | embedding_dict[token] = [np.random.normal( 106 | scale=0.1) for _ in range(vec_size)] 107 | print("{} tokens have corresponding embedding vector".format( 108 | len(filtered_elements))) 109 | 110 | NULL = "--NULL--" 111 | OOV = "--OOV--" 112 | token2idx_dict = {token: idx for idx, 113 | token in enumerate(embedding_dict.keys(), 2)} 114 | token2idx_dict[NULL] = 0 115 | token2idx_dict[OOV] = 1 116 | embedding_dict[NULL] = [0. for _ in range(vec_size)] 117 | embedding_dict[OOV] = [0. for _ in range(vec_size)] 118 | idx2emb_dict = {idx: embedding_dict[token] 119 | for token, idx in token2idx_dict.items()} 120 | emb_mat = [idx2emb_dict[idx] for idx in range(len(idx2emb_dict))] 121 | return emb_mat, token2idx_dict 122 | 123 | 124 | def build_features(params, examples, data_type, word2idx_dict, char2idx_dict): 125 | 126 | para_limit = params['para_limit'] 127 | ques_limit = params['ques_limit'] 128 | ans_limit = params['ans_limit'] 129 | char_limit = params['char_limit'] 130 | 131 | def filter_func(example): 132 | return len(example["context_tokens"]) > para_limit or \ 133 | len(example["ques_tokens"]) > ques_limit or \ 134 | (example["y2s"][0] - example["y1s"][0]) > ans_limit 135 | 136 | print("Processing {} examples...".format(data_type)) 137 | # writer = tf.python_io.TFRecordWriter(out_file) 138 | total = 0 139 | total_ = 0 140 | meta = {} 141 | 142 | context_idxs_list = [] 143 | ques_idxs_list = [] 144 | context_char_idxs_list = [] 145 | ques_char_idxs_list = [] 146 | y_list = [] 147 | ids_list = [] 148 | 149 | for example in tqdm(examples): 150 | total_ += 1 151 | 152 | if filter_func(example): 153 | continue 154 | 155 | total += 1 156 | context_idxs = np.zeros([para_limit], dtype=np.int32) 157 | context_char_idxs = np.zeros([para_limit, char_limit], dtype=np.int32) 158 | ques_idxs = np.zeros([ques_limit], dtype=np.int32) 159 | ques_char_idxs = np.zeros([ques_limit, char_limit], dtype=np.int32) 160 | y1 = np.zeros([para_limit], dtype=np.float32) 161 | y2 = np.zeros([para_limit], dtype=np.float32) 162 | 163 | def _get_word(word): 164 | for each in (word, word.lower(), word.capitalize(), word.upper()): 165 | if each in word2idx_dict: 166 | return word2idx_dict[each] 167 | return 1 168 | 169 | def _get_char(char): 170 | if char 
in char2idx_dict: 171 | return char2idx_dict[char] 172 | return 1 173 | 174 | for i, token in enumerate(example["context_tokens"]): 175 | context_idxs[i] = _get_word(token) 176 | 177 | for i, token in enumerate(example["ques_tokens"]): 178 | ques_idxs[i] = _get_word(token) 179 | 180 | for i, token in enumerate(example["context_chars"]): 181 | for j, char in enumerate(token): 182 | if j == char_limit: 183 | break 184 | context_char_idxs[i, j] = _get_char(char) 185 | 186 | for i, token in enumerate(example["ques_chars"]): 187 | for j, char in enumerate(token): 188 | if j == char_limit: 189 | break 190 | ques_char_idxs[i, j] = _get_char(char) 191 | 192 | start, end = example["y1s"][-1], example["y2s"][-1] 193 | y1[start], y2[end] = 1.0, 1.0 194 | 195 | context_idxs_list.append(context_idxs) 196 | ques_idxs_list.append(ques_idxs) 197 | context_char_idxs_list.append(context_char_idxs) 198 | ques_char_idxs_list.append(ques_char_idxs) 199 | y_list.append([start, end]) 200 | ids_list.append(example['id']) 201 | 202 | print("Built {} / {} instances of features in total".format(total, total_)) 203 | meta["total"] = total 204 | 205 | pickle.dump(context_idxs_list, open(os.path.join(params['target_dir'], data_type + '_context_idxs.pkl'), 'wb')) 206 | pickle.dump(ques_idxs_list, open(os.path.join(params['target_dir'], data_type + '_ques_idxs.pkl'), 'wb')) 207 | pickle.dump(context_char_idxs_list, open(os.path.join(params['target_dir'], data_type + '_context_char_idxs.pkl'), 'wb')) 208 | pickle.dump(ques_char_idxs_list, open(os.path.join(params['target_dir'], data_type + '_ques_char_idxs.pkl'), 'wb')) 209 | pickle.dump(y_list, open(os.path.join(params['target_dir'], data_type + '_y.pkl'), 'wb')) 210 | pickle.dump(ids_list, open(os.path.join(params['target_dir'], data_type+'_ids.pkl'), 'wb')) 211 | 212 | return meta 213 | 214 | 215 | def preprocess(params): 216 | # files 217 | train_file = os.path.join(params['data_dir'], params['train_file']) 218 | dev_file = os.path.join(params['data_dir'], params['dev_file']) 219 | test_file = os.path.join(params['data_dir'], params['test_file']) 220 | word_emb_file = os.path.join(params['glove_dir'], params['word_embedding_file']) 221 | 222 | word_counter, char_counter = Counter(), Counter() 223 | 224 | # examples 225 | train_examples, train_eval = process_file(train_file, "train", word_counter, char_counter) 226 | dev_examples, dev_eval = process_file(dev_file, "dev", word_counter, char_counter) 227 | test_examples, test_eval = process_file(test_file, 'test', word_counter, char_counter) 228 | 229 | # embedding 230 | word_emb_mat, word2idx_dict = get_embedding( 231 | word_counter, "word", emb_file=word_emb_file, size=params['glove_word_size'], vec_size=params['glove_dim']) 232 | char_emb_mat, char2idx_dict = get_embedding( 233 | char_counter, "char", emb_file=None, size=None, vec_size=params['char_dim']) 234 | 235 | pickle.dump(train_examples, open(os.path.join(params['target_dir'], 'train_examples.pkl'), 'wb')) 236 | pickle.dump(train_eval, open(os.path.join(params['target_dir'], 'train_eval.pkl'), 'wb')) 237 | pickle.dump(dev_examples, open(os.path.join(params['target_dir'], 'dev_examples.pkl'), 'wb')) 238 | pickle.dump(dev_eval, open(os.path.join(params['target_dir'], 'dev_eval.pkl'), 'wb')) 239 | pickle.dump(test_examples, open(os.path.join(params['target_dir'], 'test_examples.pkl'), 'wb')) 240 | pickle.dump(test_eval, open(os.path.join(params['target_dir'], 'test_eval.pkl'), 'wb')) 241 | 242 | pickle.dump(word_emb_mat, open(os.path.join(params['target_dir'], 
'word_emb_mat.pkl'), 'wb')) 243 | pickle.dump(word2idx_dict, open(os.path.join(params['target_dir'], 'word2idx_dict.pkl'), 'wb')) 244 | pickle.dump(char_emb_mat, open(os.path.join(params['target_dir'], 'char_emb_mat.pkl'), 'wb')) 245 | pickle.dump(char2idx_dict, open(os.path.join(params['target_dir'], 'char2idx_dict.pkl'), 'wb')) 246 | 247 | pickle.dump(word_counter, open(os.path.join(params['target_dir'], 'word_counter.pkl'), 'wb')) 248 | pickle.dump(char_counter, open(os.path.join(params['target_dir'], 'char_counter.pkl'), 'wb')) 249 | 250 | # ======================== need remove =============================== 251 | # train_examples = json.load(open('data/train_examples.json', 'r')) 252 | # train_eval = json.load(open('data/train_eval.json', 'r')) 253 | # dev_examples = json.load(open('data/dev_examples.json', 'r')) 254 | # dev_eval = json.load(open('data/dev_eval.json', 'r')) 255 | 256 | # word_counter = pickle.load(open('tmp/word_counter.pkl', 'rb')) 257 | # char_counter = pickle.load(open('tmp/char_counter.pkl', 'rb')) 258 | 259 | # word_emb_mat = pickle.load(open('data/word_emb_mat.pkl', 'rb')) 260 | # word2idx_dict = pickle.load(open('data/word2idx_dict.pkl', 'rb')) 261 | 262 | # char_emb_mat = pickle.load(open('data/char_emb_mat.pkl', 'rb')) 263 | # char2idx_dict = pickle.load(open('data/char2idx_dict.pkl', 'rb')) 264 | # ==================================================================== 265 | 266 | build_features(params, train_examples, "train", word2idx_dict, char2idx_dict) 267 | dev_meta = build_features(params, dev_examples, "dev", word2idx_dict, char2idx_dict) 268 | test_meta = build_features(params, test_examples, "test", word2idx_dict, char2idx_dict) 269 | 270 | if __name__ == '__main__': 271 | params = json.load(open('params.json', 'r')) 272 | preprocess(params) 273 | print('yeah') 274 | -------------------------------------------------------------------------------- /qanet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hackiey/QAnet-pytorch/aa9225bf27498bf43f4c9321fba6d00dfd312490/qanet/__init__.py -------------------------------------------------------------------------------- /qanet/context_query_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class ContextQueryAttention(nn.Module): 6 | 7 | def __init__(self, hidden_size=128): 8 | super(ContextQueryAttention, self).__init__() 9 | 10 | self.d = hidden_size 11 | 12 | self.W0 = nn.Linear(3 * self.d, 1) 13 | nn.init.xavier_normal_(self.W0.weight) 14 | 15 | def forward(self, C, Q, c_mask, q_mask): 16 | 17 | batch_size = C.shape[0] 18 | 19 | n = C.shape[2] 20 | m = Q.shape[2] 21 | 22 | q_mask.unsqueeze(-1) 23 | 24 | # Evaluate the Similarity matrix, S 25 | S = self.similarity(C.permute(0, 2, 1), Q.permute(0, 2, 1), n, m, batch_size) 26 | 27 | S_ = F.softmax(S - 1e30*(1-q_mask.unsqueeze(-1).permute(0, 2, 1).expand(batch_size, n, m)), dim=2) 28 | S__ = F.softmax(S - 1e30*(1-c_mask.unsqueeze(-1).expand(batch_size, n, m)), dim=1) 29 | 30 | A = torch.bmm(S_, Q.permute(0, 2, 1)) 31 | # AT = A.permute(0,2,1) 32 | B = torch.matmul(torch.bmm(S_, S__.permute(0, 2, 1)), C.permute(0, 2, 1)) 33 | # BT = B.permute(0,2,1) 34 | 35 | # following the paper, this layer should return the context2query attention 36 | # and the query2context attention 37 | return A, B 38 | 39 | def similarity(self, C, Q, n, m, batch_size): 40 | 41 | C 
= F.dropout(C, p=0.1, training=self.training) 42 | Q = F.dropout(Q, p=0.1, training=self.training) 43 | 44 | # Create QSim (#batch x n*m x d) where each of the m original rows are repeated n times 45 | Q_sim = self.repeat_rows_tensor(Q, n) 46 | # Create CSim (#batch x n*m x d) where C is reapted m times 47 | C_sim = C.repeat(1, m, 1) 48 | assert Q_sim.shape == C_sim.shape 49 | QC_sim = Q_sim * C_sim 50 | 51 | # The "learned" Similarity in 1 col, put back 52 | Sim_col = self.W0(torch.cat((Q_sim, C_sim, QC_sim), dim=2)) 53 | # Put it back in right dim 54 | Sim = Sim_col.view(batch_size, m, n).permute(0, 2, 1) 55 | 56 | return Sim 57 | 58 | def repeat_rows_tensor(self, X, rep): 59 | (depth, _, col) = X.shape 60 | # Open dim after batch ("depth") 61 | X = torch.unsqueeze(X, 1) 62 | # Repeat the matrix in the dim opened ("depth") 63 | X = X.repeat(1, rep, 1, 1) 64 | # Permute depth and lines to get the repeat over lines 65 | X = X.permute(0, 2, 1, 3) 66 | # Return to input (#batch x #lines*#repeat x #cols) 67 | X = X.contiguous().view(depth, -1, col) 68 | 69 | return X 70 | 71 | -------------------------------------------------------------------------------- /qanet/depthwise_separable_conv.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class DepthwiseSeparableConv1d(nn.Module): 7 | 8 | def __init__(self, n_filters=128, kernel_size=7, padding=3): 9 | super(DepthwiseSeparableConv1d, self).__init__() 10 | 11 | self.depthwise = nn.Conv1d(n_filters, n_filters, kernel_size=kernel_size, padding=padding, groups=n_filters) 12 | self.separable = nn.Conv1d(n_filters, n_filters, kernel_size=1) 13 | 14 | def forward(self, x): 15 | x = self.depthwise(x) 16 | x = self.separable(x) 17 | 18 | return x -------------------------------------------------------------------------------- /qanet/embedding_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | from qanet.encoder_block import EncoderBlock 4 | 5 | 6 | class EmbeddingEncoder(nn.Module): 7 | def __init__(self, resize_in=500, hidden_size=128, resize_kernel=7, resize_pad=3, 8 | n_blocks=1, n_conv=4, kernel_size=7, padding=3, 9 | conv_type='depthwise_separable', n_heads=8, context_length=400, question_length=50): 10 | 11 | super(EmbeddingEncoder, self).__init__() 12 | 13 | self.n_conv = n_conv 14 | self.n_blocks = n_blocks 15 | self.total_layers = (n_conv+2)*n_blocks 16 | 17 | self.stacked_encoderBlocks = nn.ModuleList([EncoderBlock(n_conv=n_conv, 18 | kernel_size=kernel_size, 19 | padding=padding, 20 | n_filters=hidden_size, 21 | conv_type=conv_type, 22 | n_heads=n_heads) for i in range(n_blocks)]) 23 | 24 | def forward(self, context_emb, question_emb, c_mask, q_mask): 25 | for i in range(self.n_blocks): 26 | context_emb = self.stacked_encoderBlocks[i](context_emb, c_mask, i*(self.n_conv+2)+1, self.total_layers) 27 | question_emb = self.stacked_encoderBlocks[i](question_emb, q_mask, i*(self.n_conv+2)+1, self.total_layers) 28 | 29 | return context_emb, question_emb 30 | -------------------------------------------------------------------------------- /qanet/encoder_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from qanet.position_encoding import PositionEncoding 6 | from qanet.layer_norm import LayerNorm1d 7 | from qanet.depthwise_separable_conv import 
DepthwiseSeparableConv1d 8 | from qanet.self_attention import SelfAttention 9 | 10 | class EncoderBlock(nn.Module): 11 | 12 | def __init__(self, n_conv, kernel_size=7, padding=3, n_filters=128, n_heads=8, conv_type='depthwise_separable'): 13 | super(EncoderBlock, self).__init__() 14 | 15 | self.n_conv = n_conv 16 | self.n_filters = n_filters 17 | 18 | self.position_encoding = PositionEncoding(n_filters=n_filters) 19 | 20 | # self.layer_norm = LayerNorm1d(n_features=n_filters) 21 | 22 | self.layer_norm = nn.ModuleList([LayerNorm1d(n_features=n_filters) for i in range(n_conv+2)]) 23 | 24 | self.conv = nn.ModuleList([DepthwiseSeparableConv1d(n_filters, 25 | kernel_size=kernel_size, 26 | padding=padding) for i in range(n_conv)]) 27 | self.self_attention = SelfAttention(n_heads, n_filters) 28 | 29 | self.fc = nn.Conv1d(n_filters, n_filters, kernel_size=1) 30 | 31 | def layer_dropout(self, inputs, residual, dropout): 32 | if self.training: 33 | if torch.rand(1) > dropout: 34 | outputs = F.dropout(inputs, p=0.1, training=self.training) 35 | return outputs + residual 36 | else: 37 | return residual 38 | else: 39 | return inputs + residual 40 | 41 | def forward(self, x, mask, start_index, total_layers): 42 | 43 | outputs = self.position_encoding(x) 44 | 45 | # convolutional layers 46 | for i in range(self.n_conv): 47 | residual = outputs 48 | outputs = self.layer_norm[i](outputs) 49 | 50 | if i % 2 == 0: 51 | outputs = F.dropout(outputs, p=0.1, training=self.training) 52 | outputs = F.relu(self.conv[i](outputs)) 53 | 54 | # layer dropout 55 | outputs = self.layer_dropout(outputs, residual, (0.1 * start_index / total_layers)) 56 | start_index += 1 57 | 58 | # self attention 59 | residual = outputs 60 | outputs = self.layer_norm[-2](outputs) 61 | 62 | outputs = F.dropout(outputs, p=0.1, training=self.training) 63 | outputs = outputs.permute(0, 2, 1) 64 | outputs = self.self_attention(outputs, mask) 65 | outputs = outputs.permute(0, 2, 1) 66 | 67 | outputs = self.layer_dropout(outputs, residual, 0.1 * start_index / total_layers) 68 | start_index += 1 69 | 70 | # fully connected layer 71 | residual = outputs 72 | outputs = self.layer_norm[-1](outputs) 73 | outputs = F.dropout(outputs, p=0.1, training=self.training) 74 | outputs = self.fc(outputs) 75 | outputs = self.layer_dropout(outputs, residual, 0.1 * start_index / total_layers) 76 | 77 | return outputs 78 | -------------------------------------------------------------------------------- /qanet/highway.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class Highway(nn.Module): 6 | """ Version 1 : carry gate is (1 - transform gate)""" 7 | 8 | def __init__(self, input_size=500, n_layers=2): 9 | super(Highway, self).__init__() 10 | 11 | self.n_layers = n_layers 12 | 13 | self.transform = nn.ModuleList( 14 | [nn.Conv1d(input_size, input_size, kernel_size=1) for i in range(n_layers)]) 15 | self.fc = nn.ModuleList([nn.Conv1d(input_size, input_size, kernel_size=1) for i in range(n_layers)]) 16 | 17 | def forward(self, x): 18 | for i in range(self.n_layers): 19 | transformed = F.sigmoid(self.transform[i](x)) 20 | carried = F.dropout(self.fc[i](x), p=0.1, training=self.training) 21 | x = carried * transformed + x * (1 - transformed) 22 | x = F.relu(x) 23 | 24 | return x 25 | 26 | 27 | class Highway_v2(nn.Module): 28 | """ Version 2 : carry gate is independent from transform gate """ 29 | 30 | def __init__(self, input_size=500, n_layers=2): 31 | 
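# Note: unlike Highway (v1 above), whose gates are 1x1 Conv1d layers applied to (batch, channels, length)
# tensors, this variant uses nn.Linear layers and therefore expects feature-last input of shape
# (batch, length, input_size).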
super(Highway_v2, self).__init__() 32 | 33 | self.n_layers = n_layers 34 | 35 | self.transform = nn.ModuleList( 36 | [nn.Linear(in_features=input_size, out_features=input_size) for i in range(n_layers)]) 37 | self.carry = nn.ModuleList( 38 | [nn.Linear(in_features=input_size, out_features=input_size) for i in range(n_layers)]) 39 | self.fc = nn.ModuleList([nn.Linear(in_features=input_size, out_features=input_size) for i in range(n_layers)]) 40 | 41 | def forward(self, x): 42 | for i in range(self.n_layers): 43 | transformed = F.sigmoid(self.transform[i](x)) 44 | carried = F.sigmoid(self.carry[i](x)) 45 | x = transformed * self.fc[i](x) + carried * x 46 | x = F.relu(x) 47 | 48 | return x -------------------------------------------------------------------------------- /qanet/input_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from qanet.highway import Highway 6 | 7 | class WordEmbedding(nn.Module): 8 | def __init__(self, word_embeddings): 9 | super(WordEmbedding, self).__init__() 10 | 11 | self.word_embedding = nn.Embedding(num_embeddings=word_embeddings.shape[0], 12 | embedding_dim=word_embeddings.shape[1]) 13 | 14 | self.word_embedding.weight = nn.Parameter(torch.from_numpy(word_embeddings).float()) 15 | self.word_embedding.weight.requires_grad = False 16 | 17 | def forward(self, input_context, input_question): 18 | context_word_emb = self.word_embedding(input_context) 19 | context_word_emb = F.dropout(context_word_emb, p=0.1, training=self.training) 20 | 21 | question_word_emb = self.word_embedding(input_question) 22 | question_word_emb = F.dropout(question_word_emb, p=0.1, training=self.training) 23 | 24 | return context_word_emb, question_word_emb 25 | 26 | 27 | class CharacterEmbedding(nn.Module): 28 | def __init__(self, char_embeddings, embedding_dim=32, n_filters=200, kernel_size=5, padding=2): 29 | super(CharacterEmbedding, self).__init__() 30 | 31 | self.num_embeddings = len(char_embeddings) 32 | self.embedding_dim = embedding_dim 33 | self.kernel_size = (1, kernel_size) 34 | self.padding = (0, padding) 35 | 36 | self.char_embedding = nn.Embedding(num_embeddings=self.num_embeddings, embedding_dim=embedding_dim) 37 | self.char_embedding.weight = nn.Parameter(torch.from_numpy(char_embeddings).float()) 38 | 39 | self.char_conv = nn.Conv2d(in_channels=embedding_dim, 40 | out_channels=n_filters, 41 | kernel_size=self.kernel_size, 42 | padding=self.padding) 43 | 44 | def forward(self, x): 45 | batch_size = x.shape[0] 46 | word_length = x.shape[-1] 47 | 48 | x = x.view(batch_size, -1) 49 | x = self.char_embedding(x) 50 | x = x.view(batch_size, -1, word_length, self.embedding_dim) 51 | 52 | # embedding dim of characters is number of channels of conv layer 53 | x = x.permute(0, 3, 1, 2) 54 | x = F.relu(self.char_conv(x)) 55 | x = x.permute(0, 2, 3, 1) 56 | 57 | # max pooling over word length to have final tensor 58 | x, _ = torch.max(x, dim=2) 59 | 60 | x = F.dropout(x, p=0.05, training=self.training) 61 | 62 | return x 63 | 64 | 65 | class InputEmbedding(nn.Module): 66 | def __init__(self, word_embeddings, char_embeddings, word_embed_dim=300, 67 | char_embed_dim=32, char_embed_n_filters=200, 68 | char_embed_kernel_size=7, char_embed_pad=3, highway_n_layers=2, hidden_size=128): 69 | 70 | super(InputEmbedding, self).__init__() 71 | 72 | self.word_embedding = WordEmbedding(word_embeddings) 73 | self.character_embedding = CharacterEmbedding(char_embeddings, 74 
| embedding_dim=char_embed_dim, 75 | n_filters=char_embed_n_filters, 76 | kernel_size=char_embed_kernel_size, 77 | padding=char_embed_pad) 78 | 79 | self.projection = nn.Conv1d(word_embed_dim + char_embed_n_filters, hidden_size, 1) 80 | 81 | self.highway = Highway(input_size=hidden_size, n_layers=highway_n_layers) 82 | 83 | def forward(self, context_word, context_char, question_word, question_char): 84 | context_word, question_word = self.word_embedding(context_word, question_word) 85 | context_char = self.character_embedding(context_char) 86 | question_char = self.character_embedding(question_char) 87 | 88 | context = torch.cat((context_word, context_char), dim=-1) 89 | question = torch.cat((question_word, question_char), dim=-1) 90 | 91 | context = self.projection(context.permute(0, 2, 1)) 92 | question = self.projection(question.permute(0, 2, 1)) 93 | 94 | context = self.highway(context) 95 | question = self.highway(question) 96 | 97 | return context, question -------------------------------------------------------------------------------- /qanet/layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class LayerNorm1d(nn.Module): 5 | 6 | def __init__(self, n_features, eps=1e-6): 7 | super().__init__() 8 | self.gamma = nn.Parameter(torch.ones(n_features)) 9 | self.beta = nn.Parameter(torch.zeros(n_features)) 10 | self.eps = eps 11 | 12 | def forward(self, x): 13 | x = x.permute(0, 2, 1) 14 | mean = x.mean(-1, keepdim=True) 15 | std = x.std(-1, keepdim=True) 16 | x = self.gamma * (x - mean) / (std + self.eps) + self.beta 17 | return x.permute(0, 2, 1) -------------------------------------------------------------------------------- /qanet/model_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | from qanet.encoder_block import EncoderBlock 4 | 5 | class ModelEncoder(nn.Module): 6 | def __init__(self, n_blocks=7, n_conv=2, kernel_size=7, padding=3, 7 | hidden_size=128, conv_type='depthwise_separable', n_heads=8, context_length=400): 8 | 9 | super(ModelEncoder, self).__init__() 10 | 11 | self.n_conv = n_conv 12 | self.n_blocks = n_blocks 13 | self.total_layers = (n_conv + 2) * n_blocks 14 | 15 | self.stacked_encoderBlocks = nn.ModuleList([EncoderBlock(n_conv=n_conv, 16 | kernel_size=kernel_size, 17 | padding=padding, 18 | n_filters=hidden_size, 19 | conv_type=conv_type, 20 | n_heads=n_heads) for i in range(n_blocks)]) 21 | 22 | def forward(self, x, mask): 23 | 24 | for i in range(self.n_blocks): 25 | x = self.stacked_encoderBlocks[i](x, mask, i*(self.n_conv+2)+1, self.total_layers) 26 | M0 = x 27 | 28 | for i in range(self.n_blocks): 29 | x = self.stacked_encoderBlocks[i](x, mask, i*(self.n_conv+2)+1, self.total_layers) 30 | M1 = x 31 | 32 | for i in range(self.n_blocks): 33 | x = self.stacked_encoderBlocks[i](x, mask, i*(self.n_conv+2)+1, self.total_layers) 34 | 35 | M2 = x 36 | 37 | return M0, M1, M2 38 | -------------------------------------------------------------------------------- /qanet/output.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Output(nn.Module): 5 | def __init__(self, input_dim = 512): 6 | super(Output, self).__init__() 7 | 8 | self.d = input_dim 9 | 10 | self.W1 = nn.Linear(2*self.d, 1) 11 | self.W2 = nn.Linear(2*self.d, 1) 12 | 13 | nn.init.xavier_uniform_(self.W1.weight) 14 | 
nn.init.xavier_uniform_(self.W2.weight) 15 | 16 | def forward(self, M0, M1, M2): 17 | 18 | p1 = self.W1(torch.cat((M0,M1), -1)).squeeze() 19 | p2 = self.W2(torch.cat((M0,M2), -1)).squeeze() 20 | return p1, p2 21 | -------------------------------------------------------------------------------- /qanet/position_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | from constants import device 5 | 6 | 7 | class PositionEncoding(nn.Module): 8 | def __init__(self, n_filters=128, min_timescale=1.0, max_timescale=1.0e4): 9 | 10 | super(PositionEncoding, self).__init__() 11 | 12 | self.min_timescale = min_timescale 13 | self.max_timescale = max_timescale 14 | self.d = n_filters 15 | 16 | # we use the fact that cos(x) = sin(x + pi/2) to compute everything with one sin statement 17 | self.freqs = torch.Tensor( 18 | [max_timescale ** (-i / self.d) if i % 2 == 0 else max_timescale ** (-(i - 1) / self.d) for i in 19 | range(self.d)]).unsqueeze(1).to(device) 20 | self.phases = torch.Tensor([0 if i % 2 == 0 else np.pi / 2 for i in range(self.d)]).unsqueeze(1).to(device) 21 | 22 | def forward(self, x): 23 | 24 | # *************** speed up, static pos_enc ****************** 25 | l = x.shape[-1] 26 | 27 | # computing signal 28 | pos = torch.arange(l, dtype=torch.float32).repeat(self.d, 1).to(device) 29 | tmp = pos * self.freqs + self.phases 30 | pos_enc = torch.sin(tmp) 31 | x = x + pos_enc 32 | 33 | return x 34 | 35 | 36 | if __name__ == '__main__': 37 | mdl = PositionEncoding() 38 | 39 | batch_size = 8 40 | n_channels = 128 41 | n_items = 60 42 | 43 | input = torch.ones(batch_size, n_channels, n_items) 44 | 45 | out = mdl(input) 46 | print(out) 47 | -------------------------------------------------------------------------------- /qanet/qanet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from qanet.input_embedding import InputEmbedding 6 | from qanet.embedding_encoder import EmbeddingEncoder 7 | from qanet.context_query_attention import ContextQueryAttention 8 | from qanet.model_encoder import ModelEncoder 9 | from qanet.output import Output 10 | 11 | class QANet(nn.Module): 12 | ''' All-in-one wrapper for all modules ''' 13 | 14 | def __init__(self, params, word_embeddings, char_embeddings): 15 | super(QANet, self).__init__() 16 | 17 | self.batch_size = params['batch_size'] 18 | 19 | # Defining dimensions using data from the params.json file 20 | self.word_embed_dim = params['word_embed_dim'] 21 | self.char_word_len = params["char_limit"] 22 | 23 | self.context_length = params['para_limit'] 24 | self.question_length = params['ques_limit'] 25 | 26 | self.char_embed_dim = params["char_dim"] 27 | self.char_embed_n_filters = params["char_embed_n_filters"] 28 | self.char_embed_kernel_size = params["char_embed_kernel_size"] 29 | self.char_embed_pad = params["char_embed_pad"] 30 | 31 | self.highway_n_layers = params["highway_n_layers"] 32 | 33 | self.hidden_size = params["hidden_size"] 34 | 35 | self.embed_encoder_resize_kernel_size = params["embed_encoder_resize_kernel_size"] 36 | self.embed_encoder_resize_pad = params["embed_encoder_resize_pad"] 37 | 38 | self.embed_encoder_n_blocks = params["embed_encoder_n_blocks"] 39 | self.embed_encoder_n_conv = params["embed_encoder_n_conv"] 40 | self.embed_encoder_kernel_size = params["embed_encoder_kernel_size"] 41 | self.embed_encoder_pad = 
params["embed_encoder_pad"] 42 | self.embed_encoder_conv_type = params["embed_encoder_conv_type"] 43 | self.embed_encoder_with_self_attn = params["embed_encoder_with_self_attn"] 44 | self.embed_encoder_n_heads = params["embed_encoder_n_heads"] 45 | 46 | self.model_encoder_n_blocks = params["model_encoder_n_blocks"] 47 | self.model_encoder_n_conv = params["model_encoder_n_conv"] 48 | self.model_encoder_kernel_size = params["model_encoder_kernel_size"] 49 | self.model_encoder_pad = params["model_encoder_pad"] 50 | self.model_encoder_conv_type = params["model_encoder_conv_type"] 51 | self.model_encoder_with_self_attn = params["model_encoder_with_self_attn"] 52 | self.model_encoder_n_heads = params["model_encoder_n_heads"] 53 | 54 | # Initializing model layers 55 | word_embeddings = np.array(word_embeddings) 56 | self.input_embedding = InputEmbedding(word_embeddings, 57 | char_embeddings, 58 | word_embed_dim=self.word_embed_dim, 59 | char_embed_dim=self.char_embed_dim, 60 | char_embed_n_filters=self.char_embed_n_filters, 61 | char_embed_kernel_size=self.char_embed_kernel_size, 62 | char_embed_pad=self.char_embed_pad, 63 | highway_n_layers=self.highway_n_layers, 64 | hidden_size=self.hidden_size) 65 | 66 | self.embedding_encoder = EmbeddingEncoder(resize_in=self.word_embed_dim + self.char_embed_n_filters, 67 | hidden_size=self.hidden_size, 68 | resize_kernel=self.embed_encoder_resize_kernel_size, 69 | resize_pad=self.embed_encoder_resize_pad, 70 | n_blocks=self.embed_encoder_n_blocks, 71 | n_conv=self.embed_encoder_n_conv, 72 | kernel_size=self.embed_encoder_kernel_size, 73 | padding=self.embed_encoder_pad, 74 | conv_type=self.embed_encoder_conv_type, 75 | n_heads=self.embed_encoder_n_heads, 76 | context_length=self.context_length, 77 | question_length=self.question_length) 78 | 79 | self.context_query_attention = ContextQueryAttention(hidden_size=self.hidden_size) 80 | 81 | self.projection = nn.Conv1d(4 * self.hidden_size, self.hidden_size, kernel_size=1) 82 | 83 | self.model_encoder = ModelEncoder(n_blocks=self.model_encoder_n_blocks, 84 | n_conv=self.model_encoder_n_conv, 85 | kernel_size=self.model_encoder_kernel_size, 86 | padding=self.model_encoder_pad, 87 | hidden_size=self.hidden_size, 88 | conv_type=self.model_encoder_conv_type, 89 | n_heads=self.model_encoder_n_heads) 90 | self.output = Output(input_dim=self.hidden_size) 91 | 92 | def forward(self, x): 93 | context_word, context_char, question_word, question_char, c_mask, q_mask = x 94 | 95 | c_maxlen = int(c_mask.sum(1).max().item()) 96 | q_maxlen = int(q_mask.sum(1).max().item()) 97 | context_word = context_word[:, :c_maxlen] 98 | context_char = context_char[:, :c_maxlen, :] 99 | question_word = question_word[:, :q_maxlen] 100 | question_char = question_char[:, :q_maxlen, :] 101 | c_mask = c_mask[:, :c_maxlen] 102 | q_mask = q_mask[:, :q_maxlen] 103 | 104 | context_emb, question_emb = self.input_embedding(context_word, context_char, question_word, question_char) 105 | context_emb, question_emb = self.embedding_encoder(context_emb, question_emb, c_mask, q_mask) 106 | 107 | c2q_attn, q2c_attn = self.context_query_attention(context_emb, question_emb, c_mask, q_mask) 108 | mdl_emb = torch.cat((context_emb, 109 | c2q_attn.permute(0, 2, 1), 110 | context_emb*c2q_attn.permute(0, 2, 1), 111 | context_emb*q2c_attn.permute(0, 2, 1)), 1) 112 | 113 | mdl_emb = self.projection(mdl_emb) 114 | 115 | M0, M1, M2 = self.model_encoder(mdl_emb, c_mask) 116 | 117 | p1, p2 = self.output(M0.permute(0,2,1), M1.permute(0,2,1), M2.permute(0,2,1)) 118 | 119 | 
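# Mask padded context positions: subtracting 1e30 from their start/end logits pushes them towards -inf,
# so softmax (and the cross-entropy loss used in train.py) assigns them effectively zero probability.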
p1 = p1 - 1e30*(1-c_mask) 120 | p2 = p2 - 1e30*(1-c_mask) 121 | 122 | return p1, p2, c_mask, q_mask 123 | 124 | 125 | if __name__ == '__main__': 126 | import json 127 | params = json.load(open('params.json', 'r')) 128 | qanet = QANet(params) 129 | 130 | print(dir(qanet)) 131 | -------------------------------------------------------------------------------- /qanet/self_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from constants import device 6 | 7 | 8 | class SelfAttention(nn.Module): 9 | 10 | def __init__(self, n_heads=8, n_filters=128): 11 | super(SelfAttention, self).__init__() 12 | 13 | self.n_filters = n_filters 14 | self.n_heads = n_heads 15 | 16 | self.key_dim = n_filters // n_heads 17 | self.value_dim = n_filters // n_heads 18 | 19 | self.fc_query = nn.ModuleList([nn.Linear(n_filters, self.key_dim) for i in range(n_heads)]) 20 | self.fc_key = nn.ModuleList([nn.Linear(n_filters, self.key_dim) for i in range(n_heads)]) 21 | self.fc_value = nn.ModuleList([nn.Linear(n_filters, self.value_dim) for i in range(n_heads)]) 22 | self.fc_out = nn.Linear(n_heads * self.value_dim, n_filters) 23 | 24 | def forward(self, x, mask): 25 | batch_size = x.shape[0] 26 | l = x.shape[1] 27 | 28 | mask = mask.unsqueeze(-1).expand(x.shape[0], x.shape[1], x.shape[1]).permute(0,2,1) 29 | 30 | heads = torch.zeros(self.n_heads, batch_size, l, self.value_dim, device=device) 31 | 32 | for i in range(self.n_heads): 33 | Q = self.fc_query[i](x) 34 | K = self.fc_key[i](x) 35 | V = self.fc_value[i](x) 36 | 37 | # scaled dot-product attention 38 | tmp = torch.bmm(Q, K.permute(0,2,1)) 39 | tmp = tmp / np.sqrt(self.key_dim) 40 | tmp = F.softmax(tmp - 1e30*(1-mask), dim=-1) 41 | 42 | tmp = F.dropout(tmp, p=0.1, training=self.training) 43 | 44 | heads[i] = torch.bmm(tmp, V) 45 | 46 | # concatenation is the same as reshaping our tensor 47 | x = heads.permute(1,2,0,3).contiguous().view(batch_size, l, -1) 48 | x = self.fc_out(x) 49 | 50 | return x 51 | 52 | 53 | if __name__ == "__main__": 54 | batch_size = 8 55 | l = 60 56 | n_filters = 128 57 | 58 | mdl = SelfAttention() 59 | 60 | x = torch.ones(batch_size, l, n_filters) 61 | 62 | print(mdl(x)) 63 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # import ipdb 2 | import os 3 | import pickle 4 | import json 5 | import math 6 | import torch 7 | import torch.nn.functional as F 8 | import pickle 9 | import torcheras 10 | import argparse 11 | import numpy as np 12 | 13 | # from torchsummary import summary 14 | from collections import Counter 15 | from torch.utils.data import DataLoader 16 | 17 | from qanet.qanet import QANet 18 | from dataset import QANetDataset 19 | from constants import device 20 | from utils import convert_tokens, evaluate 21 | 22 | parser = argparse.ArgumentParser(description='save description') 23 | parser.add_argument('description', default='') 24 | 25 | criterion = torch.nn.CrossEntropyLoss() 26 | 27 | def loss_function(y_pred, y_true): 28 | span = y_true[0] 29 | 30 | loss = (criterion(y_pred[0], span[:, 0]) + criterion(y_pred[1], span[:, 1])) / 2 31 | # loss += criterion(y_pred[1], span[:, 1]) 32 | return loss 33 | 34 | def count_parameters(model): 35 | parameters = [p for p in model.parameters() if p.requires_grad] 36 | counts = [p.numel() for p in parameters] 37 | 38 | 
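# Print each trainable tensor's shape and element count before returning their sum.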
for p, c in zip(parameters, counts): 39 | print(p.shape, c) 40 | 41 | return sum(counts) 42 | 43 | def train(params, description): 44 | train_dataset = QANetDataset('data', 'train') 45 | dev_dataset = QANetDataset('data', 'dev') 46 | 47 | train_eval = pickle.load(open('data/train_eval.pkl', 'rb')) 48 | dev_eval = pickle.load(open('data/dev_eval.pkl', 'rb')) 49 | 50 | def evaluate_em(y_true, y_pred): 51 | qa_id = y_true[1] 52 | c_mask, q_mask = y_pred[2:] 53 | 54 | y_p1 = F.softmax(y_pred[0], dim=-1) 55 | y_p2 = F.softmax(y_pred[1], dim=-1) 56 | 57 | p1 = [] 58 | p2 = [] 59 | 60 | p_matrix = torch.bmm(y_p1.unsqueeze(2), y_p2.unsqueeze(1)) 61 | for i in range(p_matrix.shape[0]): 62 | p = torch.triu(p_matrix[i]) 63 | indexes = torch.argmax(p).item() 64 | p1.append(indexes // p.shape[0]) 65 | p2.append(indexes % p.shape[0]) 66 | 67 | if y_pred[0].requires_grad: 68 | answer_dict, _ = convert_tokens( 69 | train_eval, qa_id.tolist(), p1, p2) 70 | metrics = evaluate(train_eval, answer_dict) 71 | else: 72 | answer_dict, _ = convert_tokens( 73 | dev_eval, qa_id.tolist(), p1, p2) 74 | metrics = evaluate(dev_eval, answer_dict) 75 | 76 | return torch.Tensor([metrics['exact_match']]) 77 | 78 | train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True) 79 | dev_loader = DataLoader(dev_dataset, batch_size=params['batch_size'], shuffle=True) 80 | 81 | word_emb_mat = np.array(pickle.load(open(os.path.join(params['target_dir'], 'word_emb_mat.pkl'), 'rb')), 82 | dtype=np.float32) 83 | char_emb_mat = np.array(pickle.load(open(os.path.join(params['target_dir'], 'char_emb_mat.pkl'), 'rb')), 84 | dtype=np.float32) 85 | 86 | qanet = QANet(params, word_emb_mat, char_emb_mat).to(device) 87 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, qanet.parameters()), 88 | lr=params['learning_rate'], betas=(params['beta1'], params['beta2']), 89 | weight_decay=params['weight_decay']) 90 | crit = 1 / math.log(1000) 91 | 92 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 93 | lr_lambda=lambda ee: crit * math.log(ee + 1) if ( 94 | ee + 1) <= 1000 else 1) 95 | 96 | qanet = torcheras.Model(qanet, 'log/qanet') 97 | 98 | print(description) 99 | qanet.set_description(description) 100 | 101 | custom_objects = {'em': evaluate_em} 102 | qanet.compile(loss_function, scheduler, metrics=['em'], device=device, custom_objects=custom_objects) 103 | qanet.fit(train_loader, dev_loader, ema_decay=0.9999, grad_clip=5) 104 | 105 | 106 | if __name__ == '__main__': 107 | args = parser.parse_args() 108 | 109 | params = json.load(open('params.json', 'r')) 110 | train(params, args.description) 111 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import ipdb 2 | from collections import Counter 3 | import string 4 | import re 5 | 6 | def mask_inputs(inputs, mask, mask_value = -1e30): 7 | return inputs + mask_value * (1 - mask) 8 | 9 | def convert_tokens(eval_file, qa_id, pp1, pp2): 10 | answer_dict = {} 11 | remapped_dict = {} 12 | for qid, p1, p2 in zip(qa_id, pp1, pp2): 13 | context = eval_file[str(qid)]["context"] 14 | spans = eval_file[str(qid)]["spans"] 15 | uuid = eval_file[str(qid)]["uuid"] 16 | start_idx = spans[p1][0] 17 | end_idx = spans[p2][1] 18 | answer_dict[str(qid)] = context[start_idx: end_idx] 19 | remapped_dict[uuid] = context[start_idx: end_idx] 20 | return answer_dict, remapped_dict 21 | 22 | def evaluate(eval_file, answer_dict): 23 | f1 = exact_match = 
total = 0 24 | for key, value in answer_dict.items(): 25 | total += 1 26 | ground_truths = eval_file[key]["answers"] 27 | prediction = value 28 | exact_match += metric_max_over_ground_truths( 29 | exact_match_score, prediction, ground_truths) 30 | f1 += metric_max_over_ground_truths(f1_score, 31 | prediction, ground_truths) 32 | exact_match = 100.0 * exact_match / total 33 | f1 = 100.0 * f1 / total 34 | return {'exact_match': exact_match, 'f1': f1} 35 | 36 | def evaluate_single(eval_file, answer_dict): 37 | correct_keys = [] 38 | 39 | all_predictions = {} 40 | for key, value in answer_dict.items(): 41 | # ipdb.set_trace() 42 | ground_truths = eval_file[key]['answers'] 43 | prediction = value 44 | em = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths) 45 | 46 | if em > 0: 47 | correct_keys.append(key) 48 | 49 | all_predictions[key] = prediction 50 | 51 | return correct_keys, all_predictions 52 | 53 | def normalize_answer(s): 54 | 55 | def remove_articles(text): 56 | return re.sub(r'\b(a|an|the)\b', ' ', text) 57 | 58 | def white_space_fix(text): 59 | return ' '.join(text.split()) 60 | 61 | def remove_punc(text): 62 | exclude = set(string.punctuation) 63 | return ''.join(ch for ch in text if ch not in exclude) 64 | 65 | def lower(text): 66 | return text.lower() 67 | 68 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 69 | 70 | 71 | def f1_score(prediction, ground_truth): 72 | prediction_tokens = normalize_answer(prediction).split() 73 | ground_truth_tokens = normalize_answer(ground_truth).split() 74 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 75 | num_same = sum(common.values()) 76 | if num_same == 0: 77 | return 0 78 | precision = 1.0 * num_same / len(prediction_tokens) 79 | recall = 1.0 * num_same / len(ground_truth_tokens) 80 | f1 = (2 * precision * recall) / (precision + recall) 81 | return f1 82 | 83 | 84 | def exact_match_score(prediction, ground_truth): 85 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 86 | 87 | 88 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 89 | scores_for_ground_truths = [] 90 | for ground_truth in ground_truths: 91 | score = metric_fn(prediction, ground_truth) 92 | scores_for_ground_truths.append(score) 93 | return max(scores_for_ground_truths) 94 | --------------------------------------------------------------------------------
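A quick sanity check for the SQuAD metrics in utils.py (a hypothetical snippet, not part of the repository; note that utils.py imports ipdb at module level, so ipdb must be installed for the import to succeed):

```
# Hypothetical example data; any eval_file/answer_dict pair with matching question ids works the same way.
from utils import evaluate

eval_file = {"1": {"answers": ["Denver Broncos"]}}   # ground-truth answers keyed by question id
answer_dict = {"1": "the Denver Broncos"}            # predicted answer text for the same id

# normalize_answer lowercases and strips articles and punctuation, so this counts as an exact match.
print(evaluate(eval_file, answer_dict))  # {'exact_match': 100.0, 'f1': 100.0}
```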