├── .gitignore ├── README.md ├── assets └── model.png ├── attention.py ├── build_emb.py ├── data_generator.py ├── model.py ├── modules.py ├── run.sh ├── run_with_doc.sh ├── script.py ├── test.sh ├── train.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | datasets/webqsp/ 3 | *.tar.gz 4 | *.pt 5 | __pycache__/ 6 | tf_logs/* 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code for the ACL 2019 paper: 2 | 3 | ## Improving Question Answering over Incomplete KBs with Knowledge-Aware Reader 4 | 5 | Paper link: [https://arxiv.org/abs/1905.07098](https://arxiv.org/abs/1905.07098) 6 | 7 | Model Overview: 8 |

9 | 10 | ### Requirements 11 | * ``PyTorch 1.0.1`` 12 | * ``tensorboardX`` 13 | * ``tqdm`` 14 | * ``gluonnlp`` 15 | 16 | ### Prepare data 17 | ``` 18 | mkdir datasets && cd datasets && wget https://sites.cs.ucsb.edu/~xwhan/datasets/webqsp.tar.gz && tar -xzvf webqsp.tar.gz && cd .. 19 | ``` 20 | 21 | ### Full KB setting 22 | ``` 23 | CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_full_kb --max_num_neighbors 50 --label_smooth 0.1 --data_folder datasets/webqsp/full/ 24 | ``` 25 | 26 | ### Incomplete KB setting 27 | Note: The Hits@1 should match or be slightly better than the number reported in the paper. More tuning on threshold should give you better F1 score. 28 | #### 30% KB 29 | ``` 30 | CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_03 --max_num_neighbors 50 --use_doc --data_folder datasets/webqsp/kb_03/ --eps 0.05 31 | ``` 32 | 33 | #### 10% KB 34 | ``` 35 | CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_01 --max_num_neighbors 50 --use_doc --data_folder datasets/webqsp/kb_01/ --eps 0.05 36 | ``` 37 | #### 50% KB 38 | ``` 39 | CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_05 --num_layer 1 --max_num_neighbors 100 --use_doc --data_folder datasets/webqsp/kb_05/ --eps 0.05 --seed 3 --hidden_drop 0.05 40 | ``` 41 | 42 | ### Citation 43 | ``` 44 | @inproceedings{xiong-etal-2019-improving, 45 | title = "Improving Question Answering over Incomplete {KB}s with Knowledge-Aware Reader", 46 | author = "Xiong, Wenhan and 47 | Yu, Mo and 48 | Chang, Shiyu and 49 | Guo, Xiaoxiao and 50 | Wang, William Yang", 51 | booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics", 52 | month = jul, 53 | year = "2019", 54 | address = "Florence, Italy", 55 | publisher = "Association for Computational Linguistics", 56 | url = "https://www.aclweb.org/anthology/P19-1417", 57 | doi = "10.18653/v1/P19-1417", 58 | pages = "4258--4264", 59 | } 60 | ``` 61 | -------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xwhan/Knowledge-Aware-Reader/883bb0477563f50b391b3f2e0b71ec8a244a0cb8/assets/model.png -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import copy 6 | import math 7 | 8 | def attention(query, key, value, mask=None, dropout=None): 9 | "Compute 'Scaled Dot Product Attention'" 10 | d_k = query.size(-1) 11 | scores = torch.matmul(query, key.transpose(-2, -1)) \ 12 | / math.sqrt(d_k) 13 | if mask is not None: 14 | scores = scores.masked_fill(mask == 0, -1e9) 15 | p_attn = F.softmax(scores, dim = -1) 16 | if dropout is not None: 17 | p_attn = dropout(p_attn) 18 | return torch.matmul(p_attn, value), p_attn 19 | 20 | 21 | def clones(module, N): 22 | "Produce N identical layers." 23 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 24 | 25 | 26 | class MultiHeadedAttention(nn.Module): 27 | def __init__(self, h, d_model, dropout=0.1): 28 | "Take in model size and number of heads." 29 | super(MultiHeadedAttention, self).__init__() 30 | assert d_model % h == 0 31 | # We assume d_v always equals d_k 32 | self.d_k = d_model // h 33 | self.h = h 34 | self.linears = clones(nn.Linear(d_model, d_model), 4) 35 | self.attn = None 36 | self.dropout = nn.Dropout(p=dropout) 37 | 38 | def forward(self, query, key, value, mask=None): 39 | "Implements Figure 2" 40 | if mask is not None: 41 | # Same mask applied to all h heads. 42 | mask = mask.unsqueeze(1) 43 | nbatches = query.size(0) 44 | 45 | # 1) Do all the linear projections in batch from d_model => h x d_k 46 | query, key, value = \ 47 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 48 | for l, x in zip(self.linears, (query, key, value))] 49 | 50 | # 2) Apply attention on all the projected vectors in batch. 51 | x, self.attn = attention(query, key, value, mask=mask, 52 | dropout=self.dropout) 53 | 54 | # 3) "Concat" using a view and apply a final linear. 55 | x = x.transpose(1, 2).contiguous() \ 56 | .view(nbatches, -1, self.h * self.d_k) 57 | return self.linears[-1](x) 58 | 59 | class PositionwiseFeedForward(nn.Module): 60 | "Implements FFN equation." 61 | def __init__(self, d_model, d_ff, dropout=0.1): 62 | super(PositionwiseFeedForward, self).__init__() 63 | self.w_1 = nn.Linear(d_model, d_ff) 64 | self.w_2 = nn.Linear(d_ff, d_model) 65 | self.dropout = nn.Dropout(dropout) 66 | 67 | def forward(self, x): 68 | return self.w_2(self.dropout(F.relu(self.w_1(x)))) 69 | 70 | class PositionalEncoding(nn.Module): 71 | "Implement the PE function." 72 | def __init__(self, d_model, dropout, max_len=5000): 73 | super(PositionalEncoding, self).__init__() 74 | self.dropout = nn.Dropout(p=dropout) 75 | 76 | # Compute the positional encodings once in log space. 77 | pe = torch.zeros(max_len, d_model) 78 | position = torch.arange(0, max_len).unsqueeze(1).float() 79 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * 80 | -(math.log(10000.0) / d_model)) 81 | pe[:, 0::2] = torch.sin(position * div_term) 82 | pe[:, 1::2] = torch.cos(position * div_term) 83 | pe = pe.unsqueeze(0) 84 | self.register_buffer('pe', pe) 85 | 86 | def forward(self, x): 87 | x = x + Variable(self.pe[:, :x.size(1)], 88 | requires_grad=False) 89 | return self.dropout(x) 90 | 91 | class LayerNorm(nn.Module): 92 | "Construct a layernorm module (See citation for details)." 93 | def __init__(self, features, eps=1e-6): 94 | super(LayerNorm, self).__init__() 95 | self.a_2 = nn.Parameter(torch.ones(features)) 96 | self.b_2 = nn.Parameter(torch.zeros(features)) 97 | self.eps = eps 98 | 99 | def forward(self, x): 100 | mean = x.mean(-1, keepdim=True) 101 | std = x.std(-1, keepdim=True) 102 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 103 | 104 | class SublayerConnection(nn.Module): 105 | """ 106 | A residual connection followed by a layer norm. 107 | Note for code simplicity the norm is first as opposed to last. 108 | """ 109 | def __init__(self, size, dropout): 110 | super(SublayerConnection, self).__init__() 111 | self.norm = LayerNorm(size) 112 | self.dropout = nn.Dropout(dropout) 113 | 114 | def forward(self, x, sublayer): 115 | "Apply residual connection to any sublayer with the same size." 116 | return x + self.dropout(sublayer(self.norm(x))) 117 | 118 | class EncoderLayer(nn.Module): 119 | "Encoder is made up of self-attn and feed forward (defined below)" 120 | def __init__(self, size, self_attn, feed_forward, dropout): 121 | super(EncoderLayer, self).__init__() 122 | self.self_attn = self_attn 123 | self.feed_forward = feed_forward 124 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 125 | self.size = size 126 | 127 | def forward(self, x, mask): 128 | "Follow Figure 1 (left) for connections." 129 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) 130 | return self.sublayer[1](x, self.feed_forward) 131 | 132 | class Encoder(nn.Module): 133 | "Core encoder is a stack of N layers" 134 | def __init__(self, layer, N): 135 | super(Encoder, self).__init__() 136 | self.layers = clones(layer, N) 137 | self.norm = LayerNorm(layer.size) 138 | 139 | def forward(self, x, mask): 140 | "Pass the input (and mask) through each layer in turn." 141 | for layer in self.layers: 142 | x = layer(x, mask) 143 | return self.norm(x) 144 | 145 | class SimpleEncoder(nn.Module): 146 | """ 147 | takes (batch_size, seq_len, embed_dim) as inputs 148 | calculate MASK, POSITION_ENCODING 149 | """ 150 | def __init__(self, embed_dim, head=4, layer=1, dropout=0.1): 151 | super(SimpleEncoder, self).__init__() 152 | d_ff = 2 * embed_dim 153 | 154 | self.position = PositionalEncoding(embed_dim, dropout) 155 | attn = MultiHeadedAttention(head, embed_dim) 156 | ff = PositionwiseFeedForward(embed_dim, d_ff) 157 | self.encoder = Encoder(EncoderLayer(embed_dim, attn, ff, dropout), layer) 158 | 159 | def forward(self, x, mask): 160 | mask = mask.unsqueeze(-2) 161 | x = self.position(x) 162 | x = self.encoder(x, mask) 163 | return x 164 | 165 | if __name__ == '__main__': 166 | encoder = SimpleEncoder(350, 2, 1) 167 | inputs = torch.zeros(1000,50,350) 168 | lens = [10] * 1000 169 | encoder(inputs, lens) -------------------------------------------------------------------------------- /build_emb.py: -------------------------------------------------------------------------------- 1 | import gluonnlp as nlp 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | dataset = 'datasets/webqsp/kb_05' 6 | rel_path = dataset + '/relations.txt' 7 | 8 | word_counter = [] 9 | 10 | # load original vocab 11 | with open(dataset + '/vocab.txt') as f: 12 | for line in f.readlines(): 13 | word_counter.append(line.strip()) 14 | 15 | rel_words = [] 16 | max_num_words = 0 17 | all_relations = [] 18 | 19 | # how to split the relation 20 | if 'webqsp' in dataset: 21 | with open(rel_path) as f: 22 | first_line = True 23 | for line in tqdm(f.readlines()): 24 | if first_line: 25 | first_line = False 26 | continue 27 | line = line.strip() 28 | all_relations.append(line) 29 | line = line[1:-1] 30 | fields = line.split('.') 31 | words = fields[-2].split('_') + fields[-1].split('_') 32 | max_num_words = max(len(words), max_num_words) 33 | rel_words.append(words) 34 | word_counter += words 35 | elif 'wikimovie' in dataset: 36 | with open(rel_path) as f: 37 | for line in tqdm(f.readlines()): 38 | line = line.strip() 39 | all_relations.append(line) 40 | words = line.split('_') 41 | max_num_words = max(len(words), max_num_words) 42 | rel_words.append(words) 43 | word_counter += words 44 | 45 | print('max_num_words: ', max_num_words) 46 | 47 | word_counter = nlp.data.count_tokens(word_counter) 48 | glove_emb = nlp.embedding.create('glove', source='glove.6B.100d') 49 | vocab = nlp.Vocab(word_counter) 50 | vocab.set_embedding(glove_emb) 51 | 52 | emb_mat = vocab.embedding.idx_to_vec.asnumpy() 53 | np.save(dataset + '/glove_word_emb_100d', emb_mat) 54 | 55 | with open(dataset + '/glove_vocab.txt', 'w') as g: 56 | g.write('\n'.join(vocab.idx_to_token)) 57 | 58 | assert False 59 | 60 | rel_word_ids = np.ones((len(rel_words) + 1, max_num_words), dtype=int) # leave the first 1 for padding relation 61 | rel_emb_mat = [] 62 | for rel_idx, words in enumerate(rel_words): 63 | for i, word in enumerate(words): 64 | rel_word_ids[rel_idx + 1, i] = vocab.token_to_idx[word] 65 | 66 | np.save(dataset + '/rel_word_idx', rel_word_ids) 67 | 68 | all_relations = ['pad_rel'] + all_relations 69 | with open(rel_path, 'w') as g: 70 | g.write('\n'.join(all_relations)) 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import nltk 3 | import numpy as np 4 | import random 5 | import torch 6 | 7 | from collections import defaultdict 8 | from tqdm import tqdm 9 | from util import get_config 10 | from util import load_dict 11 | from util import load_documents 12 | 13 | class DataLoader(): 14 | def __init__(self, config, documents, mode='train'): 15 | self.mode = mode 16 | self.use_doc = config['use_doc'] 17 | self.use_inverse_relation = config['use_inverse_relation'] 18 | self.max_query_word = config['max_query_word'] 19 | self.max_document_word = config['max_document_word'] 20 | self.max_char = config['max_char'] 21 | self.documents = documents 22 | self.data_file = config['data_folder'] + config['{}_data'.format(mode)] 23 | self.batch_size = config['batch_size'] if mode == 'train' else config['batch_size'] 24 | self.max_rel_words = config['max_rel_words'] 25 | self.type_rels = config['type_rels'] 26 | self.fact_drop = config['fact_drop'] 27 | 28 | # read all data 29 | self.data = [] 30 | with open(self.data_file) as f: 31 | for line in tqdm(list(f)): 32 | self.data.append(json.loads(line)) 33 | 34 | # word and kb vocab 35 | self.word2id = load_dict(config['data_folder'] + config['word2id']) 36 | self.relation2id = load_dict(config['data_folder'] + config['relation2id']) 37 | self.entity2id = load_dict(config['data_folder'] + config['entity2id']) 38 | self.id2entity = {i:entity for entity, i in self.entity2id.items()} 39 | 40 | self.rel_word_idx = np.load(config['data_folder'] + 'rel_word_idx.npy') 41 | 42 | # for batching 43 | self.max_local_entity = 0 # max num of candidates 44 | self.max_relevant_docs = 0 # max num of retired documents 45 | self.max_kb_neighbors = config['max_num_neighbors'] # max num of neighbors for entity 46 | self.max_kb_neighbors_ = config['max_num_neighbors'] # kb relations are directed 47 | self.max_linked_entities = 0 # max num of linked entities for each doc 48 | self.max_linked_documents = 50 # max num of linked documents for each entity 49 | 50 | self.num_kb_relation = 2 * len(self.relation2id) if self.use_inverse_relation else len(self.relation2id) 51 | 52 | # get the batching parameters 53 | self.get_stats() 54 | 55 | def get_stats(self): 56 | if self.use_doc: 57 | # max_linked_entities 58 | self.useful_docs = {} # filter out documents with out linked entities 59 | for docid, doc in self.documents.items(): 60 | linked_entities = 0 61 | if 'title' in doc: 62 | linked_entities += len(doc['title']['entities']) 63 | offset = len(nltk.word_tokenize(doc['title']['text'])) 64 | else: 65 | offset = 0 66 | for ent in doc['document']['entities']: 67 | if ent['start'] + offset >= self.max_document_word: 68 | continue 69 | else: 70 | linked_entities += 1 71 | if linked_entities > 1: 72 | self.useful_docs[docid] = doc 73 | self.max_linked_entities = max(self.max_linked_entities, linked_entities) 74 | print('max num of linked entities: ', self.max_linked_entities) 75 | 76 | # decide how many neighbors should we consider 77 | # num_neighbors = [] 78 | 79 | num_tuples = [] 80 | 81 | # max_linked_documents, max_relevant_docs, max_local_entity 82 | for line in tqdm(self.data): 83 | candidate_ents = set() 84 | rel_docs = 0 85 | 86 | # question entity 87 | for ent in line['entities']: 88 | candidate_ents.add(ent['text']) 89 | # kb entities 90 | for ent in line['subgraph']['entities']: 91 | candidate_ents.add(ent['text']) 92 | 93 | num_tuples.append(line['subgraph']['tuples']) 94 | 95 | if self.use_doc: 96 | # entities in doc 97 | for passage in line['passages']: 98 | if passage['document_id'] not in self.useful_docs: 99 | continue 100 | rel_docs += 1 101 | document = self.useful_docs[int(passage['document_id'])] 102 | for ent in document['document']['entities']: 103 | candidate_ents.add(ent['text']) 104 | if 'title' in document: 105 | for ent in document['title']['entities']: 106 | candidate_ents.add(ent['text']) 107 | 108 | neighbors = defaultdict(list) 109 | neighbors_ = defaultdict(list) 110 | 111 | for triple in line['subgraph']['tuples']: 112 | s, r, o = triple 113 | neighbors[s['text']].append((r['text'], o['text'])) 114 | neighbors_[o['text']].append((r['text'], s['text'])) 115 | 116 | self.max_relevant_docs = max(self.max_relevant_docs, rel_docs) 117 | self.max_local_entity = max(self.max_local_entity, len(candidate_ents)) 118 | 119 | # np.save('num_neighbors_', num_neighbors) 120 | 121 | print('mean num of triples: ', len(num_tuples)) 122 | 123 | print('max num of relevant docs: ', self.max_relevant_docs) 124 | print('max num of candidate entities: ', self.max_local_entity) 125 | print('max_num of neighbors: ', self.max_kb_neighbors) 126 | print('max_num of neighbors inverse: ', self.max_kb_neighbors_) 127 | 128 | def batcher(self, shuffle=False): 129 | if shuffle: 130 | random.shuffle(self.data) 131 | 132 | device = torch.device('cuda') 133 | 134 | for batch_id in tqdm(range(0, len(self.data), self.batch_size)): 135 | batch = self.data[batch_id:batch_id + self.batch_size] 136 | 137 | batch_size = len(batch) 138 | questions = np.full((batch_size, self.max_query_word), 1, dtype=int) 139 | documents = np.full((batch_size, self.max_relevant_docs, self.max_document_word), 1, dtype=int) 140 | entity_link_documents = np.zeros((batch_size, self.max_local_entity, self.max_linked_documents, self.max_document_word), dtype=int) 141 | entity_link_doc_norm = np.zeros((batch_size, self.max_local_entity, self.max_linked_documents, self.max_document_word), dtype=int) 142 | documents_ans_span = np.zeros((batch_size, self.max_relevant_docs, 2), dtype=int) 143 | entity_link_ents = np.full((batch_size, self.max_local_entity, self.max_kb_neighbors_), -1, dtype=int) # incoming edges 144 | entity_link_rels = np.zeros((batch_size, self.max_local_entity, self.max_kb_neighbors_), dtype=int) 145 | candidate_entities = np.full((batch_size, self.max_local_entity), len(self.entity2id), dtype=int) 146 | ent_degrees = np.zeros((batch_size, self.max_local_entity), dtype=int) 147 | true_answers = np.zeros((batch_size, self.max_local_entity), dtype=float) 148 | query_entities = np.zeros((batch_size, self.max_local_entity), dtype=float) 149 | answers_ = [] 150 | questions_ = [] 151 | 152 | for i, sample in enumerate(batch): 153 | doc_global2local = {} 154 | # answer set 155 | answers = set() 156 | for answer in sample['answers']: 157 | keyword = 'text' if type(answer['kb_id']) == int else 'kb_id' 158 | answers.add(self.entity2id[answer[keyword]]) 159 | 160 | if self.mode != 'train': 161 | answers_.append(list(answers)) 162 | questions_.append(sample['question']) 163 | 164 | # candidate entities, linked_documents 165 | candidates = set() 166 | query_entity = set() 167 | ent2linked_docId = defaultdict(list) 168 | for ent in sample['entities']: 169 | candidates.add(self.entity2id[ent['text']]) 170 | query_entity.add(self.entity2id[ent['text']]) 171 | for ent in sample['subgraph']['entities']: 172 | candidates.add(self.entity2id[ent['text']]) 173 | 174 | if self.use_doc: 175 | for local_id, passage in enumerate(sample['passages']): 176 | if passage['document_id'] not in self.useful_docs: 177 | continue 178 | doc_id = int(passage['document_id']) 179 | doc_global2local[doc_id] = local_id 180 | document = self.useful_docs[doc_id] 181 | for word_pos, word in enumerate([''] + document['tokens']): 182 | if word_pos < self.max_document_word: 183 | documents[i, local_id, word_pos] = self.word2id.get(word, self.word2id['']) 184 | for ent in document['document']['entities']: 185 | if self.entity2id[ent['text']] in answers: 186 | documents_ans_span[i, local_id, 0] = min(ent['start'] + 1, self.max_document_word-1) 187 | documents_ans_span[i, local_id, 1] = min(ent['end'] + 1, self.max_document_word-1) 188 | s, e = ent['start'] + 1, ent['end'] + 1 189 | ent2linked_docId[self.entity2id[ent['text']]].append((doc_id, s, e)) 190 | candidates.add(self.entity2id[ent['text']]) 191 | if 'title' in document: 192 | for ent in document['title']['entities']: 193 | candidates.add(self.entity2id(ent['text'])) 194 | 195 | # kb information 196 | connections = defaultdict(list) 197 | 198 | if self.fact_drop and self.mode == 'train': 199 | all_triples = sample['subgraph']['tuples'] 200 | random.shuffle(all_triples) 201 | num_triples = len(all_triples) 202 | keep_ratio = 1 - self.fact_drop 203 | all_triples = all_triples[:int(num_triples * keep_ratio)] 204 | 205 | else: 206 | all_triples = sample['subgraph']['tuples'] 207 | 208 | for tpl in all_triples: 209 | s,r,o = tpl 210 | 211 | 212 | # only consider one direction of information propagation 213 | connections[self.entity2id[o['text']]].append((self.relation2id[r['text']], self.entity2id[s['text']])) 214 | 215 | if r['text'] in self.type_rels: 216 | connections[self.entity2id[s['text']]].append((self.relation2id[r['text']], self.entity2id[o['text']])) 217 | 218 | 219 | # used for updating entity representations 220 | ent_global2local = {} 221 | candidates = list(candidates) 222 | 223 | # if len(candidates) == 0: 224 | # print('No entities????') 225 | # print(sample) 226 | 227 | for j, entid in enumerate(candidates): 228 | if entid in query_entity: 229 | query_entities[i, j] = 1.0 230 | candidate_entities[i, j] = entid 231 | ent_global2local[entid] = j 232 | if entid in answers: true_answers[i, j] = 1.0 233 | for linked_doc in ent2linked_docId[entid]: 234 | start, end = linked_doc[1], linked_doc[2] 235 | if end - start > 0: 236 | entity_link_documents[i, j, doc_global2local[linked_doc[0]], start:end] = 1.0 237 | entity_link_doc_norm[i, j, doc_global2local[linked_doc[0]], start:end] = 1.0 238 | 239 | for j, entid in enumerate(candidates): 240 | for count, neighbor in enumerate(connections[entid]): 241 | if count < self.max_kb_neighbors_: 242 | r_id, s_id = neighbor 243 | # convert the global ent id to subgraph id, for graph convolution 244 | s_id_local = ent_global2local[s_id] 245 | entity_link_rels[i, j, count] = r_id 246 | entity_link_ents[i, j, count] = s_id_local 247 | ent_degrees[i, s_id_local] += 1 248 | 249 | # questions 250 | for j, word in enumerate(sample['question'].split()): 251 | if j < self.max_query_word: 252 | if word in self.word2id: 253 | questions[i, j] = self.word2id[word] 254 | else: 255 | questions[i, j] = self.word2id[''] 256 | 257 | if self.use_doc: 258 | # exact match features for docs 259 | d_cat = documents.reshape((batch_size, -1)) 260 | em_d = np.array([np.isin(d_, q_) for d_, q_ in zip(d_cat, questions)], dtype=int) # exact match features 261 | em_d = em_d.reshape((batch_size, self.max_relevant_docs, -1)) 262 | 263 | batch_dict = { 264 | 'questions': questions, # (B, q_len) 265 | 'candidate_entities': candidate_entities, 266 | 'entity_link_ents': entity_link_ents, 267 | 'answers': true_answers, 268 | 'query_entities': query_entities, 269 | 'answers_': answers_, 270 | 'questions_': questions_, 271 | 'rel_word_ids': self.rel_word_idx, # (num_rel+1, word_lens) 272 | 'entity_link_rels': entity_link_rels, # (bsize, max_num_candidates, max_num_neighbors) 273 | 'ent_degrees': ent_degrees 274 | } 275 | 276 | if self.use_doc: 277 | batch_dict['documents'] = documents 278 | batch_dict['documents_em'] = em_d 279 | batch_dict['ent_link_doc_spans'] = entity_link_documents 280 | batch_dict['documents_ans_span'] = documents_ans_span 281 | batch_dict['ent_link_doc_norm_spans'] = entity_link_doc_norm 282 | 283 | for k, v in batch_dict.items(): 284 | if k.endswith('_'): 285 | batch_dict[k] = v 286 | continue 287 | if not self.use_doc and 'doc' in k: 288 | continue 289 | batch_dict[k] = torch.from_numpy(v).to(device) 290 | yield batch_dict 291 | 292 | 293 | if __name__ == '__main__': 294 | cfg = get_config() 295 | documents = load_documents(cfg['data_folder'] + cfg['{}_documents'.format(cfg['mode'])]) 296 | # cfg['batch_size'] = 2 297 | train_data = DataLoader(cfg, documents) 298 | # build_squad_like_data(cfg['data_folder'] + cfg['{}_data'.format(cfg['mode'])], cfg['data_folder'] + cfg['{}_documents'.format(cfg['mode'])]) 299 | for batch in train_data.batcher(): 300 | print(batch['documents_ans_span']) 301 | assert False -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from modules import AttnEncoder 7 | from modules import Packed 8 | from modules import SeqAttnMatch 9 | from modules import l_relu 10 | from modules import QueryReform 11 | from modules import ConditionGate 12 | from util import load_dict 13 | 14 | 15 | class KAReader(nn.Module): 16 | """docstring for ClassName""" 17 | def __init__(self, args): 18 | super(KAReader, self).__init__() 19 | 20 | self.entity2id = load_dict(args['data_folder'] + args['entity2id']) 21 | self.word2id = load_dict(args['data_folder'] + args['word2id']) 22 | self.relation2id = load_dict(args['data_folder'] + args['relation2id']) 23 | self.num_entity = len(self.entity2id) 24 | self.num_relation = len(self.relation2id) 25 | self.num_word = len(self.word2id) 26 | self.num_layer = args['num_layer'] 27 | self.use_doc = args['use_doc'] 28 | self.word_drop = args['word_drop'] 29 | self.hidden_drop = args['hidden_drop'] 30 | self.label_smooth = args['label_smooth'] 31 | 32 | for k, v in args.items(): 33 | if k.endswith('dim'): 34 | setattr(self, k, v) 35 | if k.endswith('emb_file'): 36 | setattr(self, k, args['data_folder'] + v) 37 | 38 | # pretrained entity embeddings 39 | self.entity_emb = nn.Embedding(self.num_entity + 1, self.entity_dim, padding_idx=self.num_entity) 40 | self.entity_emb.weight.data.copy_(torch.from_numpy(np.pad(np.load(self.entity_emb_file), ((0, 1), (0, 0)), 'constant'))) 41 | self.entity_emb.weight.requires_grad = False 42 | self.entity_linear = nn.Linear(self.entity_dim, self.entity_dim) 43 | 44 | # word embeddings 45 | self.word_emb = nn.Embedding(self.num_word, self.word_dim, padding_idx=1) 46 | self.word_emb.weight.data.copy_(torch.from_numpy(np.load(self.word_emb_file))) 47 | self.word_emb.weight.requires_grad = False 48 | 49 | self.word_emb_match = SeqAttnMatch(self.word_dim) 50 | 51 | self.hidden_dim = self.entity_dim 52 | # question and doc encoder 53 | self.question_encoder = Packed(nn.LSTM(self.word_dim, self.hidden_dim // 2, batch_first=True, bidirectional=True)) 54 | 55 | 56 | self.self_att_r = AttnEncoder(self.hidden_dim) 57 | self.self_att_q = AttnEncoder(self.hidden_dim) 58 | self.combine_q_rel = nn.Linear(self.hidden_dim*2, self.hidden_dim) 59 | # doc encoder 60 | 61 | self.ent_info_proj = nn.Linear(self.hidden_dim, self.hidden_dim) 62 | self.input_proj = nn.Linear(2*self.word_dim + 1, self.hidden_dim) 63 | self.doc_encoder = Packed(nn.LSTM(self.hidden_dim, self.hidden_dim // 2, batch_first=True, bidirectional=True)) 64 | self.doc_to_ent = nn.Linear(self.hidden_dim, self.hidden_dim) 65 | 66 | self.ent_info_gate = ConditionGate(self.hidden_dim) 67 | self.ent_info_gate_out = ConditionGate(self.hidden_dim) 68 | 69 | self.kg_prop = nn.Linear(self.hidden_dim + self.entity_dim, self.entity_dim) 70 | self.kg_gate = nn.Linear(self.hidden_dim + self.entity_dim, self.entity_dim) 71 | self.self_prop = nn.Linear(self.entity_dim, self.entity_dim) 72 | self.combine_q = nn.Linear(2*self.hidden_dim, self.hidden_dim) 73 | 74 | self.reader_gate = nn.Linear(2*self.hidden_dim, self.hidden_dim) 75 | self.query_update = QueryReform(self.hidden_dim) 76 | 77 | self.attn_match = nn.Linear(self.hidden_dim*3, self.hidden_dim*2) 78 | self.attn_match_q = nn.Linear(self.hidden_dim*2, self.hidden_dim) 79 | self.loss = nn.BCEWithLogitsLoss() 80 | 81 | self.word_drop = nn.Dropout(self.word_drop) 82 | self.hidden_drop = nn.Dropout(self.hidden_drop) 83 | 84 | def forward(self, feed): 85 | # encode questions 86 | question = feed['questions'] 87 | q_mask = (question != 1).float() 88 | q_len = q_mask.sum(-1) # (B, q_len) 89 | q_word_emb = self.word_drop(self.word_emb(question)) 90 | q_emb, _ = self.question_encoder(q_word_emb, q_len, max_length=question.size(1)) 91 | q_emb = self.hidden_drop(q_emb) 92 | 93 | B, max_q_len = question.size(0), question.size(1) 94 | 95 | # candidate ent embeddings 96 | ent_emb_ = self.entity_emb(feed['candidate_entities']) 97 | ent_emb = l_relu(self.entity_linear(ent_emb_)) 98 | 99 | # # keep a copy of the initial ent_emb 100 | # init_ent_emb = ent_emb 101 | ent_mask = (feed['candidate_entities'] != self.num_entity).float() 102 | 103 | # linked relations 104 | max_num_neighbors = feed['entity_link_ents'].size(2) 105 | max_num_candidates = feed['candidate_entities'].size(1) 106 | neighbor_mask = (feed['entity_link_ents'] != self.num_entity).float() # (B, |C|, |N|) 107 | 108 | # encode all relations with question encoder 109 | rel_word_ids = feed['rel_word_ids'] 110 | rel_word_mask = (rel_word_ids != 1).float() 111 | rel_word_lens = rel_word_mask.sum(-1) 112 | rel_word_lens[rel_word_lens == 0] = 1 113 | rel_encoded, _ = self.question_encoder(self.word_drop(self.word_emb(rel_word_ids)), rel_word_lens, max_length=rel_word_ids.size(1)) # (|R|, r_len, h_dim) 114 | # rel_encoded, _ = self.relation_encoder(self.word_drop(self.word_emb(rel_word_ids)), rel_word_lens, max_length=rel_word_ids.size(1)) # (|R|, r_len, h_dim) 115 | rel_encoded = self.hidden_drop(rel_encoded) 116 | rel_encoded = self.self_att_r(rel_encoded, rel_word_mask) 117 | 118 | neighbor_rel_ids = feed['entity_link_rels'].long().view(-1) 119 | neighbor_rel_emb = torch.index_select(rel_encoded, dim=0, index=neighbor_rel_ids).view(B*max_num_candidates, max_num_neighbors, self.hidden_dim) 120 | 121 | # for look up 122 | neighbor_ent_local_index = feed['entity_link_ents'].long() # (B * |C| * max_num_neighbors) 123 | neighbor_ent_local_index = neighbor_ent_local_index.view(B, -1) 124 | neighbor_ent_local_mask = (neighbor_ent_local_index != -1).long() 125 | fix_index = torch.arange(B).long() * max_num_candidates 126 | fix_index = fix_index.to(torch.device('cuda')) 127 | neighbor_ent_local_index = neighbor_ent_local_index + fix_index.view(-1,1) 128 | neighbor_ent_local_index = (neighbor_ent_local_index + 1) * neighbor_ent_local_mask 129 | neighbor_ent_local_index = neighbor_ent_local_index.view(-1) 130 | 131 | ent_seed_info = feed['query_entities'].float() # seed entity will have 1.0 score 132 | ent_is_seed = torch.cat([torch.zeros(1).to(torch.device('cuda')), ent_seed_info.view(-1)], dim=0) 133 | ent_seed_indicator = torch.index_select(ent_is_seed, dim=0, index=neighbor_ent_local_index).view(B*max_num_candidates, max_num_neighbors) 134 | 135 | # v0.0 more find-grained attention 136 | q_emb_expand = q_emb.unsqueeze(1).expand(B, max_num_candidates, max_q_len, -1).contiguous() 137 | q_emb_expand = q_emb_expand.view(B*max_num_candidates, max_q_len, -1) 138 | q_mask_expand = q_mask.unsqueeze(1).expand(B, max_num_candidates, -1).contiguous() 139 | q_mask_expand = q_mask_expand.view(B*max_num_candidates, -1) 140 | q_n_affinity = torch.bmm(q_emb_expand, neighbor_rel_emb.transpose(1, 2)) # (bsize*max_num_candidates, q_len, max_num_neighbors) 141 | q_n_affinity_mask_q = q_n_affinity - (1 - q_mask_expand.unsqueeze(2)) * 1e20 142 | q_n_affinity_mask_n = q_n_affinity - (1 - neighbor_mask.view(B*max_num_candidates, 1, max_num_neighbors)) 143 | normalize_over_q = F.softmax(q_n_affinity_mask_q, dim=1) 144 | normalize_over_n = F.softmax(q_n_affinity_mask_n, dim=2) 145 | retrieve_q = torch.bmm(normalize_over_q.transpose(1,2), q_emb_expand) 146 | q_rel_simi = torch.sum(neighbor_rel_emb * retrieve_q, dim=2) 147 | 148 | init_q_emb = self.self_att_r(q_emb, q_mask) 149 | 150 | retrieve_r = torch.bmm(normalize_over_n, neighbor_rel_emb) 151 | q_and_rel = torch.cat([q_emb_expand, retrieve_r], dim=2) 152 | rel_aware_q = self.combine_q_rel(q_and_rel).tanh().view(B, max_num_candidates, -1, self.hidden_dim) 153 | 154 | # pooling over the q_len dim 155 | q_node_emb = rel_aware_q.max(2)[0] 156 | 157 | ent_emb = l_relu(self.combine_q(torch.cat([ent_emb, q_node_emb], dim=2))) 158 | ent_emb_for_lookup = ent_emb.view(-1, self.entity_dim) 159 | ent_emb_for_lookup = torch.cat([torch.zeros(1, self.entity_dim).to(torch.device('cuda')), ent_emb_for_lookup], dim=0) 160 | neighbor_ent_emb = torch.index_select(ent_emb_for_lookup, dim=0, index=neighbor_ent_local_index) 161 | neighbor_ent_emb = neighbor_ent_emb.view(B*max_num_candidates, max_num_neighbors, -1) 162 | neighbor_vec = torch.cat([neighbor_rel_emb, neighbor_ent_emb], dim =-1).view(B*max_num_candidates, max_num_neighbors, -1) # for propagation 163 | neighbor_scores = q_rel_simi * ent_seed_indicator 164 | neighbor_scores = neighbor_scores - (1 - neighbor_mask.view(B*max_num_candidates, max_num_neighbors)) * 1e8 165 | attn_score = F.softmax(neighbor_scores, dim=1) 166 | aggregate = self.kg_prop(neighbor_vec) * attn_score.unsqueeze(2) 167 | aggregate = l_relu(aggregate.sum(1)).view(B, max_num_candidates, -1) 168 | self_prop_ = l_relu(self.self_prop(ent_emb)) 169 | gate_value = self.kg_gate(torch.cat([aggregate, ent_emb], dim = -1)).sigmoid() 170 | ent_emb = gate_value * self_prop_ + (1 - gate_value) * aggregate 171 | 172 | # read documents 173 | if self.use_doc: 174 | q_for_text = self.query_update(init_q_emb, ent_emb, ent_seed_info, ent_mask) 175 | # q_for_text = q_node_emb.mean(1) 176 | # q_for_text = init_q_emb 177 | 178 | q_node_emb = torch.cat([q_node_emb, q_for_text.unsqueeze(1).expand_as(q_node_emb).contiguous()], dim=-1) 179 | 180 | ent_linked_doc_spans = feed['ent_link_doc_spans'] 181 | doc = feed['documents'] # (B, |D|, d_len) 182 | max_num_doc = doc.size(1) 183 | max_d_len = doc.size(2) 184 | doc_mask = (doc != 1).float() 185 | doc_len = doc_mask.sum(-1) 186 | doc_len += (doc_len == 0).float() # padded documents have 0 words 187 | doc_len = doc_len.view(-1) 188 | d_word_emb = self.word_drop(self.word_emb(doc.view(-1, doc.size(-1)))) # (B*|D|, d_len, emb_dim) 189 | 190 | # input features for documents 191 | q_word_emb = q_word_emb.unsqueeze(1).expand(B, max_num_doc, max_q_len, self.word_dim).contiguous() 192 | q_word_emb = q_word_emb.view(B*max_num_doc, max_q_len, -1) 193 | q_mask_ = (question == 1).unsqueeze(1).expand(B, max_num_doc, max_q_len).contiguous() 194 | q_mask_ = q_mask_.view(B*max_num_doc, -1) 195 | q_weighted_emb = self.word_emb_match(d_word_emb, q_word_emb, q_mask_) 196 | doc_em = feed['documents_em'].float().view(B*max_num_doc, max_d_len, 1) 197 | doc_input = torch.cat([d_word_emb, q_weighted_emb, doc_em], dim=-1) # 2*word_dim + 1 198 | 199 | doc_input = self.input_proj(doc_input).tanh() 200 | word_entity_id = ent_linked_doc_spans.view(B, max_num_candidates, -1).transpose(1,2) 201 | word_ent_info_mask = (word_entity_id.sum(-1, keepdim=True) != 0).float() 202 | word_ent_info = torch.bmm(word_entity_id.float(), ent_emb) # (B, |D|*d_len, h_dim) 203 | word_ent_info = self.ent_info_proj(word_ent_info).tanh() 204 | doc_input = self.ent_info_gate(q_for_text.unsqueeze(1), word_ent_info, doc_input.view(B, max_num_doc*max_d_len, -1), word_ent_info_mask) 205 | 206 | d_emb, _ = self.doc_encoder(doc_input.view(B*max_num_doc, max_d_len, -1), doc_len, max_length=doc.size(2)) 207 | d_emb = self.hidden_drop(d_emb) 208 | 209 | d_emb = self.ent_info_gate_out(q_for_text.unsqueeze(1), word_ent_info, d_emb.view(B, max_num_doc*max_d_len, -1), word_ent_info_mask).view(B*max_num_doc, max_d_len, -1) 210 | 211 | q_for_text = q_for_text.unsqueeze(1).expand(B, max_num_doc, self.hidden_dim).contiguous() 212 | q_for_text = q_for_text.view(B*max_num_doc, -1) # (B*|D|, h_dim) 213 | d_emb = d_emb.view(B*max_num_doc, max_d_len, -1) # (B*|D|, d_len, h_dim) 214 | q_over_d = torch.bmm(q_for_text.unsqueeze(1), d_emb.transpose(1,2)).squeeze(1) # (B*|D|, d_len) 215 | q_over_d = F.softmax(q_over_d - (1 - doc_mask.view(B*max_num_doc, max_d_len))*1e8, dim=-1) 216 | q_retrieve_d = torch.bmm(q_over_d.unsqueeze(1), d_emb).view(B, max_num_doc, -1) # (B, |D|, h_dim) 217 | ent_linked_doc = (ent_linked_doc_spans.sum(-1) != 0).float() # (B, |C|, |D|) 218 | ent_emb_from_doc = torch.bmm(ent_linked_doc, q_retrieve_d) # (B, |C|, h_dim) 219 | # ent_emb_from_doc = F.dropout(ent_emb_from_doc, 0.5, self.training) 220 | 221 | # retrieve_span 222 | ent_emb_from_span = torch.bmm(feed['ent_link_doc_norm_spans'].float().view(B, max_num_candidates, -1), d_emb.view(B, max_num_doc*max_d_len, -1)) 223 | ent_emb_from_span = F.dropout(ent_emb_from_span, 0.2, self.training) 224 | 225 | 226 | # refine KB ent_emb 227 | # refined_ent_emb = self.refine_ent(ent_emb, ent_emb_from_doc) 228 | if self.use_doc: 229 | ent_emb = l_relu(self.attn_match(torch.cat([ent_emb, ent_emb_from_doc, ent_emb_from_span], dim=-1))) 230 | # q_node_emb = self.attn_match_q(q_node_emb) 231 | 232 | ent_scores = (q_node_emb * ent_emb).sum(2) 233 | 234 | answers = feed['answers'].float() 235 | if self.label_smooth: 236 | answers = ((1.0 - self.label_smooth)*answers) + (self.label_smooth/answers.size(1)) 237 | 238 | loss = self.loss(ent_scores, feed['answers'].float()) 239 | 240 | pred_dist = (ent_scores - (1-ent_mask) * 1e8).sigmoid() * ent_mask 241 | pred = torch.max(ent_scores, dim=1)[1] 242 | 243 | return loss, pred, pred_dist 244 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import math 6 | 7 | 8 | class Packed(nn.Module): 9 | 10 | def __init__(self, rnn): 11 | super().__init__() 12 | self.rnn = rnn 13 | 14 | @property 15 | def batch_first(self): 16 | return self.rnn.batch_first 17 | 18 | def forward(self, inputs, lengths, hidden=None, max_length=None): 19 | lens, indices = torch.sort(lengths, 0, True) 20 | inputs = inputs[indices] if self.batch_first else inputs[:, indices] 21 | outputs, (h, c) = self.rnn(nn.utils.rnn.pack_padded_sequence(inputs, lens.tolist(), batch_first=self.batch_first), hidden) 22 | outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=self.batch_first, total_length=max_length) 23 | _, _indices = torch.sort(indices, 0) 24 | outputs = outputs[_indices] if self.batch_first else outputs[:, _indices] 25 | h, c = h[:, _indices, :], c[:, _indices, :] 26 | return outputs, (h, c) 27 | 28 | def gelu(x): 29 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 30 | 31 | def l_relu(x, n_slope=0.01): 32 | return F.leaky_relu(x, n_slope) 33 | 34 | class ConditionGate(nn.Module): 35 | """docstring for ConditionGate""" 36 | def __init__(self, h_dim): 37 | super(ConditionGate, self).__init__() 38 | self.gate = nn.Linear(2*h_dim, h_dim, bias=False) 39 | # self.q_to_x = nn.Linear(h_dim, h_dim) 40 | # self.q_to_y = nn.Linear(h_dim, h_dim) 41 | 42 | def forward(self, q, x, y, gate_mask): 43 | q_x_sim = x*q 44 | q_y_sim = y*q 45 | gate_val = self.gate(torch.cat([q_x_sim, q_y_sim], dim=-1)).sigmoid() 46 | gate_val = gate_val * gate_mask 47 | return gate_val * x + (1 - gate_val) * y 48 | 49 | 50 | class Fusion(nn.Module): 51 | """docstring for Fusion""" 52 | def __init__(self, d_hid): 53 | super(Fusion, self).__init__() 54 | self.r = nn.Linear(d_hid*4, d_hid, bias=False) 55 | self.g = nn.Linear(d_hid*4, d_hid, bias=False) 56 | 57 | def forward(self, x, y): 58 | r_ = self.r(torch.cat([x,y,x-y,x*y], dim=-1)).tanh() 59 | g_ = torch.sigmoid(self.g(torch.cat([x,y,x-y,x*y], dim=-1))) 60 | return g_ * r_ + (1 - g_) * x 61 | 62 | class AttnEncoder(nn.Module): 63 | """docstring for ClassName""" 64 | def __init__(self, d_hid): 65 | super(AttnEncoder, self).__init__() 66 | self.attn_linear = nn.Linear(d_hid, 1, bias=False) 67 | 68 | def forward(self, x, x_mask): 69 | """ 70 | x: (B, len, d_hid) 71 | x_mask: (B, len) 72 | return: (B, d_hid) 73 | """ 74 | x_attn = self.attn_linear(x) 75 | x_attn = x_attn - (1 - x_mask.unsqueeze(2))*1e8 76 | x_attn = F.softmax(x_attn, dim=1) 77 | return (x*x_attn).sum(1) 78 | 79 | class BilinearSeqAttn(nn.Module): 80 | """A bilinear attention layer over a sequence X w.r.t y: 81 | * o_i = softmax(x_i'Wy) for x_i in X. 82 | Optionally don't normalize output weights. 83 | """ 84 | 85 | def __init__(self, x_size, y_size, identity=False, normalize=True): 86 | super(BilinearSeqAttn, self).__init__() 87 | self.normalize = normalize 88 | 89 | # If identity is true, we just use a dot product without transformation. 90 | if not identity: 91 | self.linear = nn.Linear(y_size, x_size) 92 | else: 93 | self.linear = None 94 | 95 | def forward(self, x, y, x_mask): 96 | """ 97 | Args: 98 | x: batch * len * hdim1 99 | y: batch * hdim2 100 | x_mask: batch * len (1 for padding, 0 for true) 101 | Output: 102 | alpha = batch * len 103 | """ 104 | Wy = self.linear(y) if self.linear is not None else y 105 | xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2) 106 | xWy.data.masked_fill_(x_mask.data, -float('inf')) 107 | if self.normalize: 108 | if self.training: 109 | # In training we output log-softmax for NLL 110 | alpha = F.log_softmax(xWy, dim=-1) 111 | else: 112 | # ...Otherwise 0-1 probabilities 113 | alpha = F.softmax(xWy, dim=-1) 114 | else: 115 | alpha = xWy.exp() 116 | return alpha 117 | 118 | class SeqAttnMatch(nn.Module): 119 | """Given sequences X and Y, match sequence Y to each element in X. 120 | * o_i = sum(alpha_j * y_j) for i in X 121 | * alpha_j = softmax(y_j * x_i) 122 | """ 123 | 124 | def __init__(self, input_size, identity=False): 125 | super(SeqAttnMatch, self).__init__() 126 | if not identity: 127 | self.linear = nn.Linear(input_size, input_size) 128 | else: 129 | self.linear = None 130 | 131 | def forward(self, x, y, y_mask): 132 | """ 133 | Args: 134 | x: batch * len1 * hdim 135 | y: batch * len2 * hdim 136 | y_mask: batch * len2 (1 for padding, 0 for true) 137 | Output: 138 | matched_seq: batch * len1 * hdim 139 | """ 140 | # Project vectors 141 | if self.linear: 142 | x_proj = self.linear(x.view(-1, x.size(2))).view(x.size()) 143 | x_proj = F.relu(x_proj) 144 | y_proj = self.linear(y.view(-1, y.size(2))).view(y.size()) 145 | y_proj = F.relu(y_proj) 146 | else: 147 | x_proj = x 148 | y_proj = y 149 | 150 | # Compute scores 151 | scores = x_proj.bmm(y_proj.transpose(2, 1)) 152 | 153 | # Mask padding 154 | y_mask = y_mask.unsqueeze(1).expand(scores.size()) 155 | scores.data.masked_fill_(y_mask.data, -float('inf')) 156 | 157 | # Normalize with softmax 158 | alpha_flat = F.softmax(scores.view(-1, y.size(1)), dim=-1) 159 | alpha = alpha_flat.view(-1, x.size(1), y.size(1)) 160 | 161 | # Take weighted average 162 | matched_seq = alpha.bmm(y) 163 | 164 | return matched_seq 165 | 166 | 167 | class QueryReform(nn.Module): 168 | """docstring for QueryReform""" 169 | def __init__(self, h_dim): 170 | super(QueryReform, self).__init__() 171 | # self.q_encoder = AttnEncoder(h_dim) 172 | self.fusion = Fusion(h_dim) 173 | self.q_ent_attn = nn.Linear(h_dim, h_dim) 174 | 175 | def forward(self, q_node, ent_emb, seed_info, ent_mask): 176 | ''' 177 | q: (B,q_len,h_dim) 178 | q_mask: (B,q_len) 179 | q_ent_span: (B,q_len) 180 | ent_emb: (B,C,h_dim) 181 | seed_info: (B, C) 182 | ent_mask: (B, C) 183 | ''' 184 | # q_node = self.q_encoder(q, q_mask) 185 | q_ent_attn = (self.q_ent_attn(q_node).unsqueeze(1) * ent_emb).sum(2, keepdim=True) 186 | q_ent_attn = F.softmax(q_ent_attn - (1 - ent_mask.unsqueeze(2)) * 1e8, dim=1) 187 | # attn_retrieve = (q_ent_attn * ent_emb).sum(1) 188 | 189 | seed_retrieve = torch.bmm(seed_info.unsqueeze(1), ent_emb).squeeze(1) # (B, 1, h_dim) 190 | # how to calculate the gate 191 | 192 | # return self.fusion(q_node, attn_retrieve) 193 | return self.fusion(q_node, seed_retrieve) 194 | 195 | 196 | # retrieved = self.transform(torch.cat([seed_retrieve, attn_retrieve], dim=-1)).relu() 197 | # gate_val = self.gate(torch.cat([q.squeeze(1), seed_retrieve, attn_retrieve], dim=-1)).sigmoid() 198 | # return self.fusion(q.squeeze(1), retrieved).unsqueeze(1) 199 | # return (gate_val * q.squeeze(1) + (1 - gate_val) * torch.tanh(self.transform(torch.cat([q.squeeze(1), seed_retrieve, attn_retrieve], dim=-1)))).unsqueeze(1) 200 | 201 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=$1 python train.py --model_id $2 --num_layer 1 --max_num_neighbors 50 --label_smooth 0.1 --data_folder datasets/webqsp/full/ -------------------------------------------------------------------------------- /run_with_doc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=$1 python train.py --model_id $2 --num_layer 1 --max_num_neighbors 100 --use_doc --data_folder datasets/webqsp/kb_05/ --eps 0.12 -------------------------------------------------------------------------------- /script.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | from tqdm import tqdm 4 | from collections import Counter 5 | from itertools import izip 6 | 7 | def combine_dist(dist1, dist2, w1): 8 | ensemble_dist = dist2.copy() 9 | for gid, prob in dist1.items(): 10 | if gid in ensemble_dist: 11 | ensemble_dist[gid] = (1 - w1) * ensemble_dist[gid] + w1 * prob 12 | else: 13 | ensemble_dist[gid] = prob 14 | return ensemble_dist 15 | 16 | def get_one_f1(entities, dist, eps, answers): 17 | best_entity = -1 18 | max_prob = 0.0 19 | preds = [] 20 | for entity in entities: 21 | if dist[entity] > max_prob: 22 | max_prob = dist[entity] 23 | best_entity = entity 24 | if dist[entity] > eps: 25 | preds.append(entity) 26 | 27 | return cal_eval_metric(best_entity, preds, answers) 28 | 29 | def cal_eval_metric(best_pred, preds, answers): 30 | correct, total = 0.0, 0.0 31 | for entity in preds: 32 | if entity in answers: 33 | correct += 1 34 | total += 1 35 | if len(answers) == 0: 36 | if total == 0: 37 | return 1.0, 1.0, 1.0, 1.0 # precision, recall, f1, hits 38 | else: 39 | return 0.0, 1.0, 0.0, 1.0 # precision, recall, f1, hits 40 | else: 41 | hits = float(best_pred in answers) 42 | if total == 0: 43 | return 1.0, 0.0, 0.0, hits # precision, recall, f1, hits 44 | else: 45 | precision, recall = correct / total, correct / len(answers) 46 | f1 = 2.0 / (1.0 / precision + 1.0 / recall) if precision != 0 and recall != 0 else 0.0 47 | return precision, recall, f1, hits 48 | 49 | def compare_pr(kb_pred_file, doc_pred_file, hybrid_pred_file, w_kb, eps_doc, eps_kb, eps_ensemble, eps_hybrid, eps_ensemble_all): 50 | doc_only_recall, doc_only_precision, doc_only_f1, doc_only_hits = [], [], [], [] 51 | kb_only_recall, kb_only_precision, kb_only_f1, kb_only_hits = [], [], [], [] 52 | ensemble_recall, ensemble_precision, ensemble_f1, ensemble_hits = [], [], [], [] 53 | hybrid_recall, hybrid_precision, hybrid_f1, hybrid_hits = [], [], [], [] 54 | ensemble_all_recall, ensemble_all_precision, ensemble_all_f1, ensemble_all_hits = [], [], [], [] 55 | 56 | # total_not_answerable = 0.0 57 | with open(kb_pred_file) as f_kb, open(doc_pred_file) as f_doc, open(hybrid_pred_file) as f_hybrid: 58 | line_id = 0 59 | for line_kb, line_doc, line_hybrid in tqdm(zip(f_kb, f_doc, f_hybrid)): 60 | line_id += 1 61 | line_kb = json.loads(line_kb) 62 | line_doc = json.loads(line_doc) 63 | line_hybrid = json.loads(line_hybrid) 64 | assert line_kb['answers'] == line_doc['answers'] == line_hybrid['answers'] 65 | answers = set([unicode(answer) for answer in line_kb['answers']]) 66 | # total_not_answerable += (len(answers) == 0) 67 | # assert len(answers) > 0 68 | 69 | dist_kb = line_kb['dist'] 70 | dist_doc = line_doc['dist'] 71 | dist_hybrid = line_hybrid['dist'] 72 | dist_ensemble = combine_dist(dist_kb, dist_doc, w_kb) 73 | dist_ensemble_all = combine_dist(dist_ensemble, dist_hybrid, w1=0.3) 74 | 75 | kb_entities = set(dist_kb.keys()) 76 | doc_entities = set(dist_doc.keys()) 77 | either_entities = kb_entities | doc_entities 78 | assert either_entities == set(dist_hybrid.keys()) 79 | 80 | p, r, f1, hits = get_one_f1(doc_entities, dist_doc, eps_doc, answers) 81 | doc_only_precision.append(p) 82 | doc_only_recall.append(r) 83 | doc_only_f1.append(f1) 84 | doc_only_hits.append(hits) 85 | 86 | p, r, f1, hits = get_one_f1(kb_entities, dist_kb, eps_kb, answers) 87 | kb_only_precision.append(p) 88 | kb_only_recall.append(r) 89 | kb_only_f1.append(f1) 90 | kb_only_hits.append(hits) 91 | 92 | p, r, f1, hits = get_one_f1(either_entities, dist_ensemble, eps_ensemble, answers) 93 | ensemble_precision.append(p) 94 | ensemble_recall.append(r) 95 | ensemble_f1.append(f1) 96 | ensemble_hits.append(hits) 97 | 98 | p, r, f1, hits = get_one_f1(either_entities, dist_hybrid, eps_hybrid, answers) 99 | hybrid_precision.append(p) 100 | hybrid_recall.append(r) 101 | hybrid_f1.append(f1) 102 | hybrid_hits.append(hits) 103 | 104 | p, r, f1, hits = get_one_f1(either_entities, dist_ensemble_all, eps_ensemble_all, answers) 105 | ensemble_all_precision.append(p) 106 | ensemble_all_recall.append(r) 107 | ensemble_all_f1.append(f1) 108 | ensemble_all_hits.append(hits) 109 | 110 | 111 | print('text only setting:') 112 | print('hits: ', sum(doc_only_hits) / len(doc_only_hits)) 113 | print('precision: ', sum(doc_only_precision) / len(doc_only_precision)) 114 | print('recall: ', sum(doc_only_recall) / len(doc_only_recall)) 115 | print('f1: ', sum(doc_only_f1) / len(doc_only_f1)) 116 | print('\n') 117 | 118 | print('kb only setting:') 119 | print('hits: ', sum(kb_only_hits) / len(kb_only_hits)) 120 | print('precision: ', sum(kb_only_precision) / len(kb_only_precision)) 121 | print('recall: ', sum(kb_only_recall) / len(kb_only_recall)) 122 | print('f1: ', sum(kb_only_f1) / len(kb_only_f1)) 123 | print('\n') 124 | 125 | print('late fusion:') 126 | print('hits: ', sum(ensemble_hits) / len(ensemble_hits)) 127 | print('precision: ', sum(ensemble_precision) / len(ensemble_precision)) 128 | print('recall: ', sum(ensemble_recall) / len(ensemble_recall)) 129 | print('f1: ', sum(ensemble_f1) / len(ensemble_f1)) 130 | print('\n') 131 | 132 | print('early fusion:') 133 | print('hits: ', sum(hybrid_hits) / len(hybrid_hits)) 134 | print('precision: ', sum(hybrid_precision) / len(hybrid_precision)) 135 | print('recall: ', sum(hybrid_recall) / len(hybrid_recall)) 136 | print('f1: ', sum(hybrid_f1) / len(hybrid_f1)) 137 | print('\n') 138 | 139 | print('early & late fusion:') 140 | print('hits: ', sum(ensemble_all_hits) / len(ensemble_all_hits)) 141 | print('precision: ', sum(ensemble_all_precision) / len(ensemble_all_precision)) 142 | print('recall: ', sum(ensemble_all_recall) / len(ensemble_all_recall)) 143 | print('f1: ', sum(ensemble_all_f1) / len(ensemble_all_f1)) 144 | print('\n') 145 | 146 | 147 | if __name__ == "__main__": 148 | dataset = sys.argv[1] 149 | pred_kb_file = sys.argv[2] 150 | pred_doc_file = sys.argv[3] 151 | pred_hybrid_file = sys.argv[4] 152 | if dataset == "wikimovie": 153 | w_kb = 0.9 154 | eps_doc, eps_kb, eps_ensemble, eps_hybrid, eps_ensemble_all = 0.5, 0.55, 0.6, 0.5, 0.55 155 | elif dataset == "webqsp": 156 | w_kb = 1.0 157 | eps_doc, eps_kb, eps_ensemble, eps_hybrid, eps_ensemble_all = 0.15, 0.2, 0.2, 0.2, 0.3 158 | else: 159 | assert False, "dataset not recognized" 160 | 161 | compare_pr(pred_kb_file, pred_doc_file, pred_hybrid_file, w_kb, eps_doc, eps_kb, eps_ensemble, eps_hybrid, eps_ensemble_all) 162 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=$1 python train.py --model_id $2 --num_layer 1 --max_num_neighbors 50 --mode test --eps 0.08 --data_folder datasets/webqsp/kb_03/ -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | from data_generator import DataLoader 6 | from model import KAReader 7 | from util import get_config, cal_accuracy, load_documents 8 | 9 | from tensorboardX import SummaryWriter 10 | 11 | def f1_and_hits(answers, candidate2prob, eps): 12 | retrieved = [] 13 | correct = 0 14 | best_ans, max_prob = -1, 0 15 | for c, prob in candidate2prob.items(): 16 | if prob > max_prob: 17 | max_prob = prob 18 | best_ans = c 19 | if prob > eps: 20 | retrieved.append(c) 21 | if c in answers: 22 | correct += 1 23 | if len(answers) == 0: 24 | if len(retrieved) == 0: 25 | return 1.0, 1.0 26 | else: 27 | return 0.0, 1.0 28 | else: 29 | hits = float(best_ans in answers) 30 | if len(retrieved) == 0: 31 | return 0.0, hits 32 | else: 33 | p, r = correct / len(retrieved), correct / len(answers) 34 | f1 = 2.0 / (1.0 / p + 1.0 / r) if p != 0 and r != 0 else 0.0 35 | return f1, hits 36 | 37 | def get_best_ans(candidate2prob): 38 | best_ans, max_prob = -1, 0 39 | for c, prob in candidate2prob.items(): 40 | if prob > max_prob: 41 | max_prob = prob 42 | best_ans = c 43 | return best_ans 44 | 45 | def train(cfg): 46 | tf_logger = SummaryWriter('tf_logs/' + cfg['model_id']) 47 | 48 | # train and test share the same set of documents 49 | documents = load_documents(cfg['data_folder'] + cfg['{}_documents'.format(cfg['mode'])]) 50 | 51 | # train data 52 | train_data = DataLoader(cfg, documents) 53 | valid_data = DataLoader(cfg, documents, mode='dev') 54 | 55 | model = KAReader(cfg) 56 | model = model.to(torch.device('cuda')) 57 | 58 | trainable = filter(lambda p: p.requires_grad, model.parameters()) 59 | optim = torch.optim.Adam(trainable, lr=cfg['learning_rate']) 60 | 61 | if cfg['lr_schedule']: 62 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, [30], gamma=0.5) 63 | 64 | model.train() 65 | best_val_f1 = 0 66 | best_val_hits = 0 67 | for epoch in range(cfg['num_epoch']): 68 | batcher = train_data.batcher(shuffle=True) 69 | train_loss = [] 70 | for feed in batcher: 71 | loss, pred, pred_dist = model(feed) 72 | train_loss.append(loss.item()) 73 | # acc, max_acc = cal_accuracy(pred, feed['answers'].cpu().numpy()) 74 | # train_acc.append(acc) 75 | # train_max_acc.append(max_acc) 76 | optim.zero_grad() 77 | loss.backward() 78 | if cfg['gradient_clip'] != 0: 79 | torch.nn.utils.clip_grad_norm_(trainable, cfg['gradient_clip']) 80 | optim.step() 81 | tf_logger.add_scalar('avg_batch_loss', np.mean(train_loss), epoch) 82 | 83 | val_f1, val_hits = test(model, valid_data, cfg['eps']) 84 | if cfg['lr_schedule']: 85 | scheduler.step() 86 | tf_logger.add_scalar('eval_f1', val_f1, epoch) 87 | tf_logger.add_scalar('eval_hits', val_hits, epoch) 88 | if val_f1 > best_val_f1: 89 | best_val_f1 = val_f1 90 | if val_hits > best_val_hits: 91 | best_val_hits = val_hits 92 | torch.save(model.state_dict(), 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id'])) 93 | print('evaluation best f1:{} current:{}'.format(best_val_f1, val_f1)) 94 | print('evaluation best hits:{} current:{}'.format(best_val_hits, val_hits)) 95 | 96 | print('save final model') 97 | torch.save(model.state_dict(), 'model/{}/{}_final.pt'.format(cfg['name'], cfg['model_id'])) 98 | 99 | 100 | # model_save_path = 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id']) 101 | # model.load_state_dict(torch.load(model_save_path)) 102 | 103 | 104 | print('\n..........Finished training, start testing.......') 105 | 106 | test_data = DataLoader(cfg, documents, mode='test') 107 | model.eval() 108 | print('finished training, testing final model...') 109 | test(model, test_data, cfg['eps']) 110 | 111 | # print('testing best model...') 112 | # model_save_path = 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id']) 113 | # model.load_state_dict(torch.load(model_save_path)) 114 | # model.eval() 115 | # test(model, test_data, cfg['eps']) 116 | 117 | 118 | def test(model, test_data, eps): 119 | 120 | model.eval() 121 | batcher = test_data.batcher() 122 | id2entity = test_data.id2entity 123 | f1s, hits = [], [] 124 | questions = [] 125 | pred_answers = [] 126 | for feed in batcher: 127 | _, pred, pred_dist = model(feed) 128 | acc, max_acc = cal_accuracy(pred, feed['answers'].cpu().numpy()) 129 | batch_size = pred_dist.size(0) 130 | batch_answers = feed['answers_'] 131 | questions += feed['questions_'] 132 | batch_candidates = feed['candidate_entities'] 133 | pad_ent_id = len(id2entity) 134 | for batch_id in range(batch_size): 135 | answers = batch_answers[batch_id] 136 | candidates = batch_candidates[batch_id,:].tolist() 137 | probs = pred_dist[batch_id, :].tolist() 138 | candidate2prob = {} 139 | for c, p in zip(candidates, probs): 140 | if c == pad_ent_id: 141 | continue 142 | else: 143 | candidate2prob[c] = p 144 | f1, hit = f1_and_hits(answers, candidate2prob, eps) 145 | best_ans = get_best_ans(candidate2prob) 146 | best_ans = id2entity.get(best_ans, '') 147 | 148 | pred_answers.append(best_ans) 149 | f1s.append(f1) 150 | hits.append(hit) 151 | print('evaluation.......') 152 | print('how many eval samples......', len(f1s)) 153 | print('avg_f1', np.mean(f1s)) 154 | print('avg_hits', np.mean(hits)) 155 | 156 | model.train() 157 | return np.mean(f1s), np.mean(hits) 158 | 159 | if __name__ == "__main__": 160 | # config_file = sys.argv[2] 161 | cfg = get_config() 162 | random.seed(cfg['seed']) 163 | np.random.seed(cfg['seed']) 164 | torch.manual_seed(cfg['seed']) 165 | torch.cuda.manual_seed_all(cfg['seed']) 166 | if cfg['mode'] == 'train': 167 | train(cfg) 168 | elif cfg['mode'] == 'test': 169 | documents = load_documents(cfg['data_folder'] + cfg['{}_documents'.format(cfg['mode'])]) 170 | test_data = DataLoader(cfg, documents, mode='test') 171 | model = KAReader(cfg) 172 | model = model.to(torch.device('cuda')) 173 | model_save_path = 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id']) 174 | model.load_state_dict(torch.load(model_save_path)) 175 | model.eval() 176 | test(model, test_data, cfg['eps']) 177 | else: 178 | assert False, "--train or --test?" 179 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import nltk 3 | import numpy as np 4 | import os 5 | import torch 6 | import yaml 7 | 8 | from collections import Counter 9 | 10 | from torch.autograd import Variable 11 | from tqdm import tqdm 12 | 13 | import argparse 14 | 15 | def get_config(config_path=None): 16 | if not config_path: 17 | parser = argparse.ArgumentParser() 18 | 19 | # datasets 20 | parser.add_argument('--name', default='webqsp', type=str) 21 | parser.add_argument('--data_folder', default='datasets/webqsp/kb_03/', type=str) 22 | parser.add_argument('--train_data', default='train.json', type=str) 23 | parser.add_argument('--train_documents', default='documents.json', type=str) 24 | parser.add_argument('--dev_data', default='dev.json', type=str) 25 | parser.add_argument('--dev_documents', default='documents.json', type=str) 26 | parser.add_argument('--test_data', default='test.json', type=str) 27 | parser.add_argument('--test_documents', default='documents.json', type=str) 28 | parser.add_argument('--max_query_word', default=10, type=int) 29 | parser.add_argument('--max_document_word', default=50, type=int) 30 | parser.add_argument('--max_char', default=25, type=int) 31 | parser.add_argument('--max_num_neighbors', default=100, type=int) 32 | parser.add_argument('--max_rel_words', default=8, type=int) 33 | 34 | # embeddings 35 | parser.add_argument('--word2id', default='glove_vocab.txt', type=str) 36 | parser.add_argument('--relation2id', default='relations.txt', type=str) 37 | parser.add_argument('--entity2id', default='entities.txt', type=str) 38 | parser.add_argument('--char2id', default='chars.txt', type=str) 39 | parser.add_argument('--word_emb_file', default='glove_word_emb.npy', type=str) 40 | parser.add_argument('--entity_emb_file', default='entity_emb_100d.npy', type=str) 41 | parser.add_argument('--rel_word_ids', default='rel_word_idx.npy', type=str) 42 | 43 | # dimensions, layers, dropout 44 | parser.add_argument('--num_layer', default=1, type=int) 45 | parser.add_argument('--entity_dim', default=100, type=int) 46 | parser.add_argument('--word_dim', default=300, type=int) 47 | parser.add_argument('--hidden_drop', default=0.2, type=float) 48 | parser.add_argument('--word_drop', default=0.2, type=float) 49 | 50 | # optimization 51 | parser.add_argument('--num_epoch', default=100, type=int) 52 | parser.add_argument('--batch_size', default=8, type=int) 53 | parser.add_argument('--gradient_clip', default=1.0, type=float) 54 | parser.add_argument('--learning_rate', default=0.001, type=float) 55 | parser.add_argument('--seed', default=19940715, type=int) 56 | parser.add_argument('--lr_schedule', action='store_true') 57 | parser.add_argument('--label_smooth', default=0.1, type=float) 58 | parser.add_argument('--fact_drop', default=0, type=float) 59 | 60 | # model options 61 | parser.add_argument('--use_doc', action='store_true') 62 | parser.add_argument('--use_inverse_relation', action='store_true') 63 | parser.add_argument('--model_id', default='debug', type=str) 64 | parser.add_argument('--load_model_file', default=None, type=str) 65 | parser.add_argument('--mode', default='train', type=str) 66 | parser.add_argument('--eps', default=0.05, type=float) # threshold for f1 67 | 68 | args = parser.parse_args() 69 | 70 | if args.name == 'webqsp': 71 | args.type_rels = ['', '', '', '', '', '', '', '', '', ''] 72 | else: 73 | args.type_rels = [] 74 | 75 | config = vars(args) 76 | config['to_save_model'] = True # always save model 77 | config['save_model_file'] = 'model/' + config['name'] + '/best_{}.pt'.format(config['model_id']) 78 | config['pred_file'] = 'results/' + config['name'] + '/best_{}.pred'.format(config['model_id']) 79 | else: 80 | with open(config_path, "r") as setting: 81 | config = yaml.load(setting) 82 | 83 | print('-'* 10 + 'Experiment Config' + '-' * 10) 84 | for k, v in config.items(): 85 | print(k + ': ', v) 86 | print('-'* 10 + 'Experiment Config' + '-' * 10 + '\n') 87 | 88 | return config 89 | 90 | def use_cuda(var): 91 | if torch.cuda.is_available(): 92 | return var.cuda() 93 | else: 94 | return var 95 | 96 | def save_model(the_model, path): 97 | if os.path.exists(path): 98 | path = path + '_copy' 99 | print("saving model to ...", path) 100 | torch.save(the_model, path) 101 | 102 | 103 | def load_model(path): 104 | if not os.path.exists(path): 105 | assert False, 'cannot find model: ' + path 106 | print("loading model from ...", path) 107 | return torch.load(path) 108 | 109 | def load_dict(filename): 110 | word2id = dict() 111 | with open(filename) as f_in: 112 | for line in f_in: 113 | word = line.strip() 114 | word2id[word] = len(word2id) 115 | return word2id 116 | 117 | def load_documents(document_file): 118 | print('loading document from', document_file) 119 | documents = dict() 120 | with open(document_file) as f_in: 121 | for line in tqdm(list(f_in)): 122 | passage = json.loads(line) 123 | # tokenize document 124 | document_token = nltk.word_tokenize(passage['document']['text']) 125 | if 'title' in passage: 126 | title_token = nltk.word_tokenize(passage['title']['text']) 127 | passage['tokens'] = title_token + ['|'] + document_token 128 | # passage['tokens'] = title_token 129 | else: 130 | passage['tokens'] = document_token 131 | documents[int(passage['documentId'])] = passage 132 | return documents 133 | 134 | def cal_accuracy(pred, answer_dist): 135 | """ 136 | pred: batch_size 137 | answer_dist: batch_size, max_local_entity 138 | """ 139 | num_correct = 0.0 140 | num_answerable = 0.0 141 | for i, l in enumerate(pred): 142 | num_correct += (answer_dist[i, l] != 0) 143 | for dist in answer_dist: 144 | if np.sum(dist) != 0: 145 | num_answerable += 1 146 | return num_correct / len(pred), num_answerable / len(pred) 147 | 148 | def char_vocab(word2id, data_path): 149 | # build char embeddings 150 | char_counter = Counter() 151 | max_char = 0 152 | with open(word2id) as f: 153 | for word in f: 154 | word = word.strip() 155 | max_char = max(max_char, len(word)) 156 | for char in word: 157 | char_counter[char] += 1 158 | 159 | char2id = {c: idx for idx, c in enumerate(char_counter.keys(), 1)} 160 | char2id['__unk__'] = 0 161 | 162 | id2char = {id_:c for c, id_ in char2id.items()} 163 | 164 | vocab_size = len(char2id) 165 | char_vocabs = [] 166 | for _ in range(vocab_size): 167 | char_vocabs.append(id2char[_]) 168 | 169 | with open(data_path, 'w') as g: 170 | g.write('\n'.join(char_vocabs)) 171 | 172 | print(max_char) 173 | 174 | class LeftMMFixed(torch.autograd.Function): 175 | """ 176 | Implementation of matrix multiplication of a Sparse Variable with a Dense Variable, returning a Dense one. 177 | This is added because there's no autograd for sparse yet. No gradient computed on the sparse weights. 178 | """ 179 | 180 | def __init__(self): 181 | super(LeftMMFixed, self).__init__() 182 | self.sparse_weights = None 183 | 184 | def forward(self, sparse_weights, x): 185 | if self.sparse_weights is None: 186 | self.sparse_weights = sparse_weights 187 | return torch.mm(self.sparse_weights, x) 188 | 189 | def backward(self, grad_output): 190 | sparse_weights = self.sparse_weights 191 | return None, torch.mm(sparse_weights.t(), grad_output) 192 | 193 | 194 | def sparse_bmm(X, Y): 195 | """Batch multiply X and Y where X is sparse, Y is dense. 196 | Args: 197 | X: Sparse tensor of size BxMxN. Consists of two tensors, 198 | I:3xZ indices, and V:1xZ values. 199 | Y: Dense tensor of size BxNxK. 200 | Returns: 201 | batched-matmul(X, Y): BxMxK 202 | """ 203 | I = X._indices() 204 | V = X._values() 205 | B, M, N = X.size() 206 | _, _, K = Y.size() 207 | Z = I.size()[1] 208 | lookup = Y[I[0, :], I[2, :], :] 209 | X_I = torch.stack((I[0, :] * M + I[1, :], use_cuda(torch.arange(Z).type(torch.LongTensor))), 0) 210 | S = use_cuda(Variable(torch.cuda.sparse.FloatTensor(X_I, V, torch.Size([B * M, Z])), requires_grad=False)) 211 | prod_op = LeftMMFixed() 212 | prod = prod_op(S, lookup) 213 | return prod.view(B, M, K) 214 | 215 | if __name__ == "__main__": 216 | # load_documents('datasets/wikimovie/full_doc/documents.json') 217 | char_vocab('datasets/webqsp/kb_05/vocab.txt', 'datasets/webqsp/kb_05/chars.txt') 218 | --------------------------------------------------------------------------------