├── .gitignore
├── README.md
├── assets
└── model.png
├── attention.py
├── build_emb.py
├── data_generator.py
├── model.py
├── modules.py
├── run.sh
├── run_with_doc.sh
├── script.py
├── test.sh
├── train.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.json
2 | datasets/webqsp/
3 | *.tar.gz
4 | *.pt
5 | __pycache__/
6 | tf_logs/*
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Code for the ACL 2019 paper:
2 |
3 | ## Improving Question Answering over Incomplete KBs with Knowledge-Aware Reader
4 |
5 | Paper link: [https://arxiv.org/abs/1905.07098](https://arxiv.org/abs/1905.07098)
6 |
7 | Model Overview:
8 | ![Model overview](assets/model.png)
9 |
10 | ### Requirements
11 | * ``PyTorch 1.0.1``
12 | * ``tensorboardX``
13 | * ``tqdm``
14 | * ``gluonnlp``
15 |
16 | ### Prepare data
17 | ```
18 | mkdir datasets && cd datasets && wget https://sites.cs.ucsb.edu/~xwhan/datasets/webqsp.tar.gz && tar -xzvf webqsp.tar.gz && cd ..
19 | ```
20 |
21 | ### Full KB setting
22 | ```
23 | CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_full_kb --max_num_neighbors 50 --label_smooth 0.1 --data_folder datasets/webqsp/full/
24 | ```
25 |
26 | ### Incomplete KB setting
27 | Note: Hits@1 should match or be slightly better than the numbers reported in the paper. Further tuning of the threshold should give you a better F1 score.
28 | #### 30% KB
29 | ```
30 | CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_03 --max_num_neighbors 50 --use_doc --data_folder datasets/webqsp/kb_03/ --eps 0.05
31 | ```
32 |
33 | #### 10% KB
34 | ```
35 | CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_01 --max_num_neighbors 50 --use_doc --data_folder datasets/webqsp/kb_01/ --eps 0.05
36 | ```
37 | #### 50% KB
38 | ```
39 | CUDA_VISIBLE_DEVICES=0 python train.py --model_id KAReader_kb_05 --num_layer 1 --max_num_neighbors 100 --use_doc --data_folder datasets/webqsp/kb_05/ --eps 0.05 --seed 3 --hidden_drop 0.05
40 | ```
41 |
42 | ### Citation
43 | ```
44 | @inproceedings{xiong-etal-2019-improving,
45 | title = "Improving Question Answering over Incomplete {KB}s with Knowledge-Aware Reader",
46 | author = "Xiong, Wenhan and
47 | Yu, Mo and
48 | Chang, Shiyu and
49 | Guo, Xiaoxiao and
50 | Wang, William Yang",
51 | booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
52 | month = jul,
53 | year = "2019",
54 | address = "Florence, Italy",
55 | publisher = "Association for Computational Linguistics",
56 | url = "https://www.aclweb.org/anthology/P19-1417",
57 | doi = "10.18653/v1/P19-1417",
58 | pages = "4258--4264",
59 | }
60 | ```
61 |
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xwhan/Knowledge-Aware-Reader/883bb0477563f50b391b3f2e0b71ec8a244a0cb8/assets/model.png
--------------------------------------------------------------------------------
/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | import copy
6 | import math
7 |
def attention(query, key, value, mask=None, dropout=None):
    """Compute scaled dot-product attention.

    The raw scores are ``query @ key^T`` divided by ``sqrt(d_k)``, where
    ``d_k`` is the size of the last dimension of ``query``. Wherever
    ``mask == 0`` the score is replaced by -1e9 before the softmax over the
    last dimension. When ``dropout`` is given it is applied to the attention
    weights.

    Returns:
        A tuple ``(weighted_values, attention_weights)``.
    """
    scale = math.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        # Masked positions get a very negative score so softmax sends them to ~0.
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    context = torch.matmul(weights, value)
    return context, weights
19 |
20 |
def clones(module, N):
    """Return an ``nn.ModuleList`` of N independent deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return nn.ModuleList(copies)
24 |
25 |
class MultiHeadedAttention(nn.Module):
    """Multi-head attention built on the module-level ``attention`` function."""

    def __init__(self, h, d_model, dropout=0.1):
        """Take in the number of heads ``h`` and the model size ``d_model``.

        ``d_model`` must be divisible by ``h``; each head works on
        ``d_model // h`` features (d_v is assumed to equal d_k).
        """
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Four same-sized projections: query, key, value, and the final output.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Attention weights of the most recent forward pass.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Project the inputs per head, attend, and merge the heads again."""
        if mask is not None:
            # Insert a head axis so the same mask is applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        def split_heads(projection, tensor):
            # d_model => h x d_k, with the head axis moved in front of the sequence axis.
            projected = projection(tensor)
            return projected.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)

        # 1) Linear projections; the fourth linear is reserved for the output.
        query = split_heads(self.linears[0], query)
        key = split_heads(self.linears[1], key)
        value = split_heads(self.linears[2], value)

        # 2) Attention over all heads in one batched call.
        heads, self.attn = attention(query, key, value, mask=mask,
                                     dropout=self.dropout)

        # 3) "Concat" the heads with a view and apply the final linear.
        merged = heads.transpose(1, 2).contiguous()
        merged = merged.view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](merged)
58 |
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: ``w_2(dropout(relu(w_1(x))))``."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """Create the ``d_model -> d_ff -> d_model`` projections and the dropout."""
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the expansion, ReLU, dropout and the projection back."""
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)
69 |
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    The table ``pe`` has shape ``(1, max_len, d_model)`` and is registered as a
    buffer. Even feature columns hold ``sin(position * div_term)`` and odd
    columns hold ``cos(position * div_term)``.

    Args:
        d_model: number of features per position (even or odd).
        dropout: dropout probability applied after adding the encodings.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even columns,
        # so trim div_term to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions to ``x``."""
        # A plain slice of the buffer is enough; the deprecated Variable
        # wrapper is not needed to keep it out of the autograd graph.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
90 |
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with learnable gain and bias."""

    def __init__(self, features, eps=1e-6):
        """Create a gain ``a_2`` (ones) and bias ``b_2`` (zeros) of size ``features``."""
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return ``a_2 * (x - mean) / (std + eps) + b_2`` along the last axis."""
        centered = x - x.mean(-1, keepdim=True)
        # eps is added to the standard deviation itself, not to the variance.
        spread = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2
103 |
class SublayerConnection(nn.Module):
    """Residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    The layer norm is applied to the sublayer's input rather than to the
    residual sum.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalised input and add the residual."""
        normed = self.norm(x)
        transformed = self.dropout(sublayer(normed))
        return x + transformed
117 |
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward sublayer."""

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # One residual wrapper per sublayer (attention, feed-forward).
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Run self-attention and then the feed-forward net, each with a residual."""
        def attend(normed):
            # Query, key and value are all the same normalised input.
            return self.self_attn(normed, normed, normed, mask)

        attended = self.sublayer[0](x, attend)
        return self.sublayer[1](attended, self.feed_forward)
131 |
class Encoder(nn.Module):
    """A stack of N copies of ``layer`` followed by a final layer norm."""

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Feed ``x`` (with ``mask``) through every layer, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)
144 |
class SimpleEncoder(nn.Module):
    """Positional encoding followed by a small self-attention encoder.

    Takes inputs of shape (batch_size, seq_len, embed_dim) together with a
    mask tensor; the feed-forward width is ``2 * embed_dim``.
    """

    def __init__(self, embed_dim, head=4, layer=1, dropout=0.1):
        super().__init__()
        self.position = PositionalEncoding(embed_dim, dropout)
        self_attention = MultiHeadedAttention(head, embed_dim)
        feed_forward = PositionwiseFeedForward(embed_dim, 2 * embed_dim)
        encoder_layer = EncoderLayer(embed_dim, self_attention, feed_forward, dropout)
        self.encoder = Encoder(encoder_layer, layer)

    def forward(self, x, mask):
        """Add position information to ``x`` and encode it under ``mask``."""
        # A singleton axis is inserted before the last mask dimension.
        expanded_mask = mask.unsqueeze(-2)
        positioned = self.position(x)
        return self.encoder(positioned, expanded_mask)
164 |
if __name__ == '__main__':
    # Smoke test: 1000 zero-valued sequences of length 50 with 350 features,
    # where only the first 10 positions of each sequence are unmasked.
    encoder = SimpleEncoder(350, 2, 1)
    inputs = torch.zeros(1000, 50, 350)
    lens = [10] * 1000
    # SimpleEncoder.forward calls mask.unsqueeze(-2), so it needs a tensor
    # mask of shape (batch, seq_len) rather than a Python list of lengths.
    positions = torch.arange(inputs.size(1)).unsqueeze(0)
    mask = (positions < torch.tensor(lens).unsqueeze(1)).float()
    encoder(inputs, mask)
--------------------------------------------------------------------------------
/build_emb.py:
--------------------------------------------------------------------------------
1 | import gluonnlp as nlp
2 | import numpy as np
3 | from tqdm import tqdm
4 |
# Dataset directory whose vocabulary and relation files are processed.
dataset = 'datasets/webqsp/kb_05'
rel_path = dataset + '/relations.txt'

# load original vocab: one token per line
with open(dataset + '/vocab.txt') as f:
    word_counter = [entry.strip() for entry in f.readlines()]

rel_words = []
max_num_words = 0
all_relations = []

# how to split the relation into words depends on the dataset
if 'webqsp' in dataset:
    with open(rel_path) as f:
        for line_no, line in enumerate(tqdm(f.readlines())):
            if line_no == 0:
                # the first line of the relation file is skipped
                continue
            line = line.strip()
            all_relations.append(line)
            # drop the first and last character, then split on '.'
            fields = line[1:-1].split('.')
            # only the last two dot-separated fields contribute words
            words = fields[-2].split('_') + fields[-1].split('_')
            max_num_words = max(len(words), max_num_words)
            rel_words.append(words)
            word_counter.extend(words)
elif 'wikimovie' in dataset:
    with open(rel_path) as f:
        for line in tqdm(f.readlines()):
            line = line.strip()
            all_relations.append(line)
            words = line.split('_')
            max_num_words = max(len(words), max_num_words)
            rel_words.append(words)
            word_counter.extend(words)

print('max_num_words: ', max_num_words)

# build a vocabulary over all collected tokens and attach GloVe vectors
word_counter = nlp.data.count_tokens(word_counter)
glove_emb = nlp.embedding.create('glove', source='glove.6B.100d')
vocab = nlp.Vocab(word_counter)
vocab.set_embedding(glove_emb)

emb_mat = vocab.embedding.idx_to_vec.asnumpy()
np.save(dataset + '/glove_word_emb_100d', emb_mat)

with open(dataset + '/glove_vocab.txt', 'w') as g:
    g.write('\n'.join(vocab.idx_to_token))

# NOTE(review): the script stops here unconditionally, so the relation-index
# step below never runs (unless Python is started with -O, which strips
# asserts) — confirm whether that step is still wanted.
assert False

# row 0 is left filled with 1 for the padding relation
rel_word_ids = np.ones((len(rel_words) + 1, max_num_words), dtype=int)
rel_emb_mat = []
for rel_idx, words in enumerate(rel_words):
    for i, word in enumerate(words):
        rel_word_ids[rel_idx + 1, i] = vocab.token_to_idx[word]

np.save(dataset + '/rel_word_idx', rel_word_ids)

# rewrite the relation file with the padding relation in front
all_relations = ['pad_rel'] + all_relations
with open(rel_path, 'w') as g:
    g.write('\n'.join(all_relations))
71 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | import json
2 | import nltk
3 | import numpy as np
4 | import random
5 | import torch
6 |
7 | from collections import defaultdict
8 | from tqdm import tqdm
9 | from util import get_config
10 | from util import load_dict
11 | from util import load_documents
12 |
class DataLoader():
    """Read question samples from a JSON-lines file and yield padded batches.

    ``get_stats`` scans the data once to find the padding sizes, and
    ``batcher`` turns consecutive slices of ``self.data`` into dictionaries of
    tensors (plus a few plain Python lists whose keys end in ``_``).
    """

    def __init__(self, config, documents, mode='train'):
        """Load the data file and the vocabularies named in ``config``.

        Args:
            config: mapping providing the keys read below (``data_folder``,
                ``{mode}_data``, ``batch_size``, ``word2id``, ...).
            documents: mapping of document id to document record; only
                consulted when ``config['use_doc']`` is truthy.
            mode: selects the ``{mode}_data`` entry of ``config``. Answer ids
                and raw questions are only collected when it is not 'train'.
        """
        self.mode = mode
        self.use_doc = config['use_doc']
        self.use_inverse_relation = config['use_inverse_relation']
        self.max_query_word = config['max_query_word']
        self.max_document_word = config['max_document_word']
        self.max_char = config['max_char']
        self.documents = documents
        self.data_file = config['data_folder'] + config['{}_data'.format(mode)]
        # The same batch size is used for every mode.
        self.batch_size = config['batch_size']
        self.max_rel_words = config['max_rel_words']
        self.type_rels = config['type_rels']
        self.fact_drop = config['fact_drop']

        # read all data: one JSON object per line
        self.data = []
        with open(self.data_file) as f:
            for line in tqdm(list(f)):
                self.data.append(json.loads(line))

        # word and kb vocab
        self.word2id = load_dict(config['data_folder'] + config['word2id'])
        self.relation2id = load_dict(config['data_folder'] + config['relation2id'])
        self.entity2id = load_dict(config['data_folder'] + config['entity2id'])
        self.id2entity = {i: entity for entity, i in self.entity2id.items()}

        self.rel_word_idx = np.load(config['data_folder'] + 'rel_word_idx.npy')

        # for batching
        self.max_local_entity = 0  # max num of candidates
        self.max_relevant_docs = 0  # max num of retrieved documents
        self.max_kb_neighbors = config['max_num_neighbors']  # max num of neighbors for entity
        self.max_kb_neighbors_ = config['max_num_neighbors']  # kb relations are directed
        self.max_linked_entities = 0  # max num of linked entities for each doc
        self.max_linked_documents = 50  # max num of linked documents for each entity

        self.num_kb_relation = 2 * len(self.relation2id) if self.use_inverse_relation else len(self.relation2id)

        # get the batching parameters
        self.get_stats()

    def get_stats(self):
        """Scan documents and samples to set the padding sizes used by ``batcher``.

        Sets ``useful_docs`` and ``max_linked_entities`` (only when documents
        are used), ``max_relevant_docs`` and ``max_local_entity``, and prints a
        short summary.
        """
        if self.use_doc:
            # max_linked_entities
            self.useful_docs = {}  # filter out documents without linked entities
            for docid, doc in self.documents.items():
                linked_entities = 0
                if 'title' in doc:
                    linked_entities += len(doc['title']['entities'])
                    offset = len(nltk.word_tokenize(doc['title']['text']))
                else:
                    offset = 0
                for ent in doc['document']['entities']:
                    # entities starting beyond the truncated document are not counted
                    if ent['start'] + offset >= self.max_document_word:
                        continue
                    linked_entities += 1
                if linked_entities > 1:
                    self.useful_docs[docid] = doc
                self.max_linked_entities = max(self.max_linked_entities, linked_entities)
            print('max num of linked entities: ', self.max_linked_entities)

        # number of KB triples per sample, for the summary below
        num_tuples = []

        # max_linked_documents, max_relevant_docs, max_local_entity
        for line in tqdm(self.data):
            candidate_ents = set()
            rel_docs = 0

            # question entity
            for ent in line['entities']:
                candidate_ents.add(ent['text'])
            # kb entities
            for ent in line['subgraph']['entities']:
                candidate_ents.add(ent['text'])

            num_tuples.append(len(line['subgraph']['tuples']))

            if self.use_doc:
                # entities in doc
                for passage in line['passages']:
                    # NOTE(review): membership is tested with the raw id but the
                    # lookup uses int(id) — presumably both are ints; confirm.
                    if passage['document_id'] not in self.useful_docs:
                        continue
                    rel_docs += 1
                    document = self.useful_docs[int(passage['document_id'])]
                    for ent in document['document']['entities']:
                        candidate_ents.add(ent['text'])
                    if 'title' in document:
                        for ent in document['title']['entities']:
                            candidate_ents.add(ent['text'])

            self.max_relevant_docs = max(self.max_relevant_docs, rel_docs)
            self.max_local_entity = max(self.max_local_entity, len(candidate_ents))

        # Report the actual mean; guard against an empty data file.
        mean_tuples = float(np.mean(num_tuples)) if num_tuples else 0.0
        print('mean num of triples: ', mean_tuples)

        print('max num of relevant docs: ', self.max_relevant_docs)
        print('max num of candidate entities: ', self.max_local_entity)
        print('max_num of neighbors: ', self.max_kb_neighbors)
        print('max_num of neighbors inverse: ', self.max_kb_neighbors_)

    def batcher(self, shuffle=False):
        """Yield one dictionary per batch of ``self.batch_size`` samples.

        Args:
            shuffle: when true, ``self.data`` is shuffled in place first.

        Yields:
            A dict whose array-valued entries are converted to torch tensors on
            the CUDA device when available (CPU otherwise); entries whose key
            ends in ``_`` stay plain Python lists.
        """
        if shuffle:
            random.shuffle(self.data)

        # Fall back to the CPU so batching also works on machines without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        for batch_id in tqdm(range(0, len(self.data), self.batch_size)):
            batch = self.data[batch_id:batch_id + self.batch_size]

            batch_size = len(batch)
            # word-id arrays are padded with 1
            questions = np.full((batch_size, self.max_query_word), 1, dtype=int)
            documents = np.full((batch_size, self.max_relevant_docs, self.max_document_word), 1, dtype=int)
            entity_link_documents = np.zeros((batch_size, self.max_local_entity, self.max_linked_documents, self.max_document_word), dtype=int)
            entity_link_doc_norm = np.zeros((batch_size, self.max_local_entity, self.max_linked_documents, self.max_document_word), dtype=int)
            documents_ans_span = np.zeros((batch_size, self.max_relevant_docs, 2), dtype=int)
            entity_link_ents = np.full((batch_size, self.max_local_entity, self.max_kb_neighbors_), -1, dtype=int)  # incoming edges
            entity_link_rels = np.zeros((batch_size, self.max_local_entity, self.max_kb_neighbors_), dtype=int)
            # candidate slots are padded with len(entity2id)
            candidate_entities = np.full((batch_size, self.max_local_entity), len(self.entity2id), dtype=int)
            ent_degrees = np.zeros((batch_size, self.max_local_entity), dtype=int)
            true_answers = np.zeros((batch_size, self.max_local_entity), dtype=float)
            query_entities = np.zeros((batch_size, self.max_local_entity), dtype=float)
            answers_ = []
            questions_ = []

            for i, sample in enumerate(batch):
                doc_global2local = {}
                # answer set
                answers = set()
                for answer in sample['answers']:
                    # an integer kb_id means the entity is keyed by its text instead
                    keyword = 'text' if isinstance(answer['kb_id'], int) else 'kb_id'
                    answers.add(self.entity2id[answer[keyword]])

                if self.mode != 'train':
                    answers_.append(list(answers))
                    questions_.append(sample['question'])

                # candidate entities, linked_documents
                candidates = set()
                query_entity = set()
                ent2linked_docId = defaultdict(list)
                for ent in sample['entities']:
                    candidates.add(self.entity2id[ent['text']])
                    query_entity.add(self.entity2id[ent['text']])
                for ent in sample['subgraph']['entities']:
                    candidates.add(self.entity2id[ent['text']])

                if self.use_doc:
                    for local_id, passage in enumerate(sample['passages']):
                        if passage['document_id'] not in self.useful_docs:
                            continue
                        doc_id = int(passage['document_id'])
                        doc_global2local[doc_id] = local_id
                        document = self.useful_docs[doc_id]
                        # NOTE(review): '' is used both as the leading document token
                        # and as the out-of-vocabulary fallback — confirm word2id has it.
                        for word_pos, word in enumerate([''] + document['tokens']):
                            if word_pos < self.max_document_word:
                                documents[i, local_id, word_pos] = self.word2id.get(word, self.word2id[''])
                        for ent in document['document']['entities']:
                            if self.entity2id[ent['text']] in answers:
                                documents_ans_span[i, local_id, 0] = min(ent['start'] + 1, self.max_document_word-1)
                                documents_ans_span[i, local_id, 1] = min(ent['end'] + 1, self.max_document_word-1)
                            # spans are shifted by one for the token prepended above
                            s, e = ent['start'] + 1, ent['end'] + 1
                            ent2linked_docId[self.entity2id[ent['text']]].append((doc_id, s, e))
                            candidates.add(self.entity2id[ent['text']])
                        if 'title' in document:
                            for ent in document['title']['entities']:
                                # entity2id is a dict: index it rather than call it
                                candidates.add(self.entity2id[ent['text']])

                # kb information
                connections = defaultdict(list)

                if self.fact_drop and self.mode == 'train':
                    # randomly keep a (1 - fact_drop) fraction of the triples
                    all_triples = sample['subgraph']['tuples']
                    random.shuffle(all_triples)
                    num_triples = len(all_triples)
                    keep_ratio = 1 - self.fact_drop
                    all_triples = all_triples[:int(num_triples * keep_ratio)]
                else:
                    all_triples = sample['subgraph']['tuples']

                for tpl in all_triples:
                    s, r, o = tpl

                    # only consider one direction of information propagation
                    connections[self.entity2id[o['text']]].append((self.relation2id[r['text']], self.entity2id[s['text']]))

                    # relations listed in type_rels are also linked in the other direction
                    if r['text'] in self.type_rels:
                        connections[self.entity2id[s['text']]].append((self.relation2id[r['text']], self.entity2id[o['text']]))

                # used for updating entity representations
                ent_global2local = {}
                candidates = list(candidates)

                for j, entid in enumerate(candidates):
                    if entid in query_entity:
                        query_entities[i, j] = 1.0
                    candidate_entities[i, j] = entid
                    ent_global2local[entid] = j
                    if entid in answers:
                        true_answers[i, j] = 1.0
                    for linked_doc in ent2linked_docId[entid]:
                        start, end = linked_doc[1], linked_doc[2]
                        if end - start > 0:
                            entity_link_documents[i, j, doc_global2local[linked_doc[0]], start:end] = 1.0
                            entity_link_doc_norm[i, j, doc_global2local[linked_doc[0]], start:end] = 1.0

                for j, entid in enumerate(candidates):
                    for count, neighbor in enumerate(connections[entid]):
                        # neighbors beyond the configured maximum are dropped
                        if count < self.max_kb_neighbors_:
                            r_id, s_id = neighbor
                            # convert the global ent id to subgraph id, for graph convolution
                            s_id_local = ent_global2local[s_id]
                            entity_link_rels[i, j, count] = r_id
                            entity_link_ents[i, j, count] = s_id_local
                            ent_degrees[i, s_id_local] += 1

                # questions
                for j, word in enumerate(sample['question'].split()):
                    if j < self.max_query_word:
                        if word in self.word2id:
                            questions[i, j] = self.word2id[word]
                        else:
                            questions[i, j] = self.word2id['']

            if self.use_doc:
                # exact match features for docs: 1 where a document word id
                # also occurs in the question row of the same sample
                d_cat = documents.reshape((batch_size, -1))
                em_d = np.array([np.isin(d_, q_) for d_, q_ in zip(d_cat, questions)], dtype=int)
                em_d = em_d.reshape((batch_size, self.max_relevant_docs, -1))

            batch_dict = {
                'questions': questions,  # (B, q_len)
                'candidate_entities': candidate_entities,
                'entity_link_ents': entity_link_ents,
                'answers': true_answers,
                'query_entities': query_entities,
                'answers_': answers_,
                'questions_': questions_,
                'rel_word_ids': self.rel_word_idx,  # (num_rel+1, word_lens)
                'entity_link_rels': entity_link_rels,  # (bsize, max_num_candidates, max_num_neighbors)
                'ent_degrees': ent_degrees
            }

            if self.use_doc:
                batch_dict['documents'] = documents
                batch_dict['documents_em'] = em_d
                batch_dict['ent_link_doc_spans'] = entity_link_documents
                batch_dict['documents_ans_span'] = documents_ans_span
                batch_dict['ent_link_doc_norm_spans'] = entity_link_doc_norm

            for k, v in batch_dict.items():
                # keys ending in '_' hold plain Python lists and are left as they are
                if k.endswith('_'):
                    continue
                if not self.use_doc and 'doc' in k:
                    continue
                batch_dict[k] = torch.from_numpy(v).to(device)
            yield batch_dict
291 |
292 |
if __name__ == '__main__':
    # Debug entry point: build a loader from the parsed config, print the
    # answer spans of the first batch and stop.
    cfg = get_config()
    documents_key = '{}_documents'.format(cfg['mode'])
    documents = load_documents(cfg['data_folder'] + cfg[documents_key])
    train_data = DataLoader(cfg, documents)
    for batch in train_data.batcher():
        print(batch['documents_ans_span'])
        # deliberate hard stop after the first batch
        assert False
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from modules import AttnEncoder
7 | from modules import Packed
8 | from modules import SeqAttnMatch
9 | from modules import l_relu
10 | from modules import QueryReform
11 | from modules import ConditionGate
12 | from util import load_dict
13 |
14 |
15 | class KAReader(nn.Module):
16 | """docstring for ClassName"""
def __init__(self, args):
    """Build the reader's embeddings, encoders and projection layers.

    ``args`` is a mapping. The entity, word and relation dictionaries are
    loaded with ``load_dict`` from paths under ``args['data_folder']``. Every
    entry whose key ends in ``dim`` is copied onto the module, and every
    entry whose key ends in ``emb_file`` is stored with the data folder
    prepended. Entity and word embeddings are copied from the ``.npy`` files
    and frozen.
    """
    super().__init__()

    data_folder = args['data_folder']
    self.entity2id = load_dict(data_folder + args['entity2id'])
    self.word2id = load_dict(data_folder + args['word2id'])
    self.relation2id = load_dict(data_folder + args['relation2id'])
    self.num_entity = len(self.entity2id)
    self.num_relation = len(self.relation2id)
    self.num_word = len(self.word2id)
    self.num_layer = args['num_layer']
    self.use_doc = args['use_doc']
    self.label_smooth = args['label_smooth']
    # Dropout rates; the dropout modules themselves are created last.
    word_drop_rate = args['word_drop']
    hidden_drop_rate = args['hidden_drop']

    for key, value in args.items():
        if key.endswith('dim'):
            setattr(self, key, value)
        if key.endswith('emb_file'):
            setattr(self, key, data_folder + value)

    # pretrained entity embeddings; index num_entity is the padding entity
    self.entity_emb = nn.Embedding(self.num_entity + 1, self.entity_dim, padding_idx=self.num_entity)
    pretrained_entities = np.load(self.entity_emb_file)
    # append one all-zero row for the padding entity
    padded_entities = np.pad(pretrained_entities, ((0, 1), (0, 0)), 'constant')
    self.entity_emb.weight.data.copy_(torch.from_numpy(padded_entities))
    self.entity_emb.weight.requires_grad = False
    self.entity_linear = nn.Linear(self.entity_dim, self.entity_dim)

    # word embeddings; word id 1 is the padding index
    self.word_emb = nn.Embedding(self.num_word, self.word_dim, padding_idx=1)
    pretrained_words = np.load(self.word_emb_file)
    self.word_emb.weight.data.copy_(torch.from_numpy(pretrained_words))
    self.word_emb.weight.requires_grad = False

    self.word_emb_match = SeqAttnMatch(self.word_dim)

    # the hidden size is tied to the entity embedding size
    self.hidden_dim = self.entity_dim
    half_hidden = self.hidden_dim // 2

    # question encoder (bidirectional, so each direction gets half the size)
    question_lstm = nn.LSTM(self.word_dim, half_hidden, batch_first=True, bidirectional=True)
    self.question_encoder = Packed(question_lstm)

    self.self_att_r = AttnEncoder(self.hidden_dim)
    self.self_att_q = AttnEncoder(self.hidden_dim)
    self.combine_q_rel = nn.Linear(self.hidden_dim*2, self.hidden_dim)

    # doc encoder
    self.ent_info_proj = nn.Linear(self.hidden_dim, self.hidden_dim)
    self.input_proj = nn.Linear(2*self.word_dim + 1, self.hidden_dim)
    doc_lstm = nn.LSTM(self.hidden_dim, half_hidden, batch_first=True, bidirectional=True)
    self.doc_encoder = Packed(doc_lstm)
    self.doc_to_ent = nn.Linear(self.hidden_dim, self.hidden_dim)

    self.ent_info_gate = ConditionGate(self.hidden_dim)
    self.ent_info_gate_out = ConditionGate(self.hidden_dim)

    # knowledge-graph propagation layers
    kg_input_dim = self.hidden_dim + self.entity_dim
    self.kg_prop = nn.Linear(kg_input_dim, self.entity_dim)
    self.kg_gate = nn.Linear(kg_input_dim, self.entity_dim)
    self.self_prop = nn.Linear(self.entity_dim, self.entity_dim)
    self.combine_q = nn.Linear(2*self.hidden_dim, self.hidden_dim)

    self.reader_gate = nn.Linear(2*self.hidden_dim, self.hidden_dim)
    self.query_update = QueryReform(self.hidden_dim)

    self.attn_match = nn.Linear(self.hidden_dim*3, self.hidden_dim*2)
    self.attn_match_q = nn.Linear(self.hidden_dim*2, self.hidden_dim)
    self.loss = nn.BCEWithLogitsLoss()

    self.word_drop = nn.Dropout(word_drop_rate)
    self.hidden_drop = nn.Dropout(hidden_drop_rate)
83 |
84 | def forward(self, feed):
85 | # encode questions
86 | question = feed['questions']
87 | q_mask = (question != 1).float()
88 | q_len = q_mask.sum(-1) # (B, q_len)
89 | q_word_emb = self.word_drop(self.word_emb(question))
90 | q_emb, _ = self.question_encoder(q_word_emb, q_len, max_length=question.size(1))
91 | q_emb = self.hidden_drop(q_emb)
92 |
93 | B, max_q_len = question.size(0), question.size(1)
94 |
95 | # candidate ent embeddings
96 | ent_emb_ = self.entity_emb(feed['candidate_entities'])
97 | ent_emb = l_relu(self.entity_linear(ent_emb_))
98 |
99 | # # keep a copy of the initial ent_emb
100 | # init_ent_emb = ent_emb
101 | ent_mask = (feed['candidate_entities'] != self.num_entity).float()
102 |
103 | # linked relations
104 | max_num_neighbors = feed['entity_link_ents'].size(2)
105 | max_num_candidates = feed['candidate_entities'].size(1)
106 | neighbor_mask = (feed['entity_link_ents'] != self.num_entity).float() # (B, |C|, |N|)
107 |
108 | # encode all relations with question encoder
109 | rel_word_ids = feed['rel_word_ids']
110 | rel_word_mask = (rel_word_ids != 1).float()
111 | rel_word_lens = rel_word_mask.sum(-1)
112 | rel_word_lens[rel_word_lens == 0] = 1
113 | rel_encoded, _ = self.question_encoder(self.word_drop(self.word_emb(rel_word_ids)), rel_word_lens, max_length=rel_word_ids.size(1)) # (|R|, r_len, h_dim)
114 | # rel_encoded, _ = self.relation_encoder(self.word_drop(self.word_emb(rel_word_ids)), rel_word_lens, max_length=rel_word_ids.size(1)) # (|R|, r_len, h_dim)
115 | rel_encoded = self.hidden_drop(rel_encoded)
116 | rel_encoded = self.self_att_r(rel_encoded, rel_word_mask)
117 |
118 | neighbor_rel_ids = feed['entity_link_rels'].long().view(-1)
119 | neighbor_rel_emb = torch.index_select(rel_encoded, dim=0, index=neighbor_rel_ids).view(B*max_num_candidates, max_num_neighbors, self.hidden_dim)
120 |
121 | # for look up
122 | neighbor_ent_local_index = feed['entity_link_ents'].long() # (B * |C| * max_num_neighbors)
123 | neighbor_ent_local_index = neighbor_ent_local_index.view(B, -1)
124 | neighbor_ent_local_mask = (neighbor_ent_local_index != -1).long()
125 | fix_index = torch.arange(B).long() * max_num_candidates
126 | fix_index = fix_index.to(torch.device('cuda'))
127 | neighbor_ent_local_index = neighbor_ent_local_index + fix_index.view(-1,1)
128 | neighbor_ent_local_index = (neighbor_ent_local_index + 1) * neighbor_ent_local_mask
129 | neighbor_ent_local_index = neighbor_ent_local_index.view(-1)
130 |
131 | ent_seed_info = feed['query_entities'].float() # seed entity will have 1.0 score
132 | ent_is_seed = torch.cat([torch.zeros(1).to(torch.device('cuda')), ent_seed_info.view(-1)], dim=0)
133 | ent_seed_indicator = torch.index_select(ent_is_seed, dim=0, index=neighbor_ent_local_index).view(B*max_num_candidates, max_num_neighbors)
134 |
135 | # v0.0 more find-grained attention
136 | q_emb_expand = q_emb.unsqueeze(1).expand(B, max_num_candidates, max_q_len, -1).contiguous()
137 | q_emb_expand = q_emb_expand.view(B*max_num_candidates, max_q_len, -1)
138 | q_mask_expand = q_mask.unsqueeze(1).expand(B, max_num_candidates, -1).contiguous()
139 | q_mask_expand = q_mask_expand.view(B*max_num_candidates, -1)
140 | q_n_affinity = torch.bmm(q_emb_expand, neighbor_rel_emb.transpose(1, 2)) # (bsize*max_num_candidates, q_len, max_num_neighbors)
141 | q_n_affinity_mask_q = q_n_affinity - (1 - q_mask_expand.unsqueeze(2)) * 1e20
142 | q_n_affinity_mask_n = q_n_affinity - (1 - neighbor_mask.view(B*max_num_candidates, 1, max_num_neighbors))
143 | normalize_over_q = F.softmax(q_n_affinity_mask_q, dim=1)
144 | normalize_over_n = F.softmax(q_n_affinity_mask_n, dim=2)
145 | retrieve_q = torch.bmm(normalize_over_q.transpose(1,2), q_emb_expand)
146 | q_rel_simi = torch.sum(neighbor_rel_emb * retrieve_q, dim=2)
147 |
148 | init_q_emb = self.self_att_r(q_emb, q_mask)
149 |
150 | retrieve_r = torch.bmm(normalize_over_n, neighbor_rel_emb)
151 | q_and_rel = torch.cat([q_emb_expand, retrieve_r], dim=2)
152 | rel_aware_q = self.combine_q_rel(q_and_rel).tanh().view(B, max_num_candidates, -1, self.hidden_dim)
153 |
154 | # pooling over the q_len dim
155 | q_node_emb = rel_aware_q.max(2)[0]
156 |
157 | ent_emb = l_relu(self.combine_q(torch.cat([ent_emb, q_node_emb], dim=2)))
158 | ent_emb_for_lookup = ent_emb.view(-1, self.entity_dim)
159 | ent_emb_for_lookup = torch.cat([torch.zeros(1, self.entity_dim).to(torch.device('cuda')), ent_emb_for_lookup], dim=0)
160 | neighbor_ent_emb = torch.index_select(ent_emb_for_lookup, dim=0, index=neighbor_ent_local_index)
161 | neighbor_ent_emb = neighbor_ent_emb.view(B*max_num_candidates, max_num_neighbors, -1)
162 | neighbor_vec = torch.cat([neighbor_rel_emb, neighbor_ent_emb], dim =-1).view(B*max_num_candidates, max_num_neighbors, -1) # for propagation
163 | neighbor_scores = q_rel_simi * ent_seed_indicator
164 | neighbor_scores = neighbor_scores - (1 - neighbor_mask.view(B*max_num_candidates, max_num_neighbors)) * 1e8
165 | attn_score = F.softmax(neighbor_scores, dim=1)
166 | aggregate = self.kg_prop(neighbor_vec) * attn_score.unsqueeze(2)
167 | aggregate = l_relu(aggregate.sum(1)).view(B, max_num_candidates, -1)
168 | self_prop_ = l_relu(self.self_prop(ent_emb))
169 | gate_value = self.kg_gate(torch.cat([aggregate, ent_emb], dim = -1)).sigmoid()
170 | ent_emb = gate_value * self_prop_ + (1 - gate_value) * aggregate
171 |
172 | # read documents
173 | if self.use_doc:
174 | q_for_text = self.query_update(init_q_emb, ent_emb, ent_seed_info, ent_mask)
175 | # q_for_text = q_node_emb.mean(1)
176 | # q_for_text = init_q_emb
177 |
178 | q_node_emb = torch.cat([q_node_emb, q_for_text.unsqueeze(1).expand_as(q_node_emb).contiguous()], dim=-1)
179 |
180 | ent_linked_doc_spans = feed['ent_link_doc_spans']
181 | doc = feed['documents'] # (B, |D|, d_len)
182 | max_num_doc = doc.size(1)
183 | max_d_len = doc.size(2)
184 | doc_mask = (doc != 1).float()
185 | doc_len = doc_mask.sum(-1)
186 | doc_len += (doc_len == 0).float() # padded documents have 0 words
187 | doc_len = doc_len.view(-1)
188 | d_word_emb = self.word_drop(self.word_emb(doc.view(-1, doc.size(-1)))) # (B*|D|, d_len, emb_dim)
189 |
190 | # input features for documents
191 | q_word_emb = q_word_emb.unsqueeze(1).expand(B, max_num_doc, max_q_len, self.word_dim).contiguous()
192 | q_word_emb = q_word_emb.view(B*max_num_doc, max_q_len, -1)
193 | q_mask_ = (question == 1).unsqueeze(1).expand(B, max_num_doc, max_q_len).contiguous()
194 | q_mask_ = q_mask_.view(B*max_num_doc, -1)
195 | q_weighted_emb = self.word_emb_match(d_word_emb, q_word_emb, q_mask_)
196 | doc_em = feed['documents_em'].float().view(B*max_num_doc, max_d_len, 1)
197 | doc_input = torch.cat([d_word_emb, q_weighted_emb, doc_em], dim=-1) # 2*word_dim + 1
198 |
199 | doc_input = self.input_proj(doc_input).tanh()
200 | word_entity_id = ent_linked_doc_spans.view(B, max_num_candidates, -1).transpose(1,2)
201 | word_ent_info_mask = (word_entity_id.sum(-1, keepdim=True) != 0).float()
202 | word_ent_info = torch.bmm(word_entity_id.float(), ent_emb) # (B, |D|*d_len, h_dim)
203 | word_ent_info = self.ent_info_proj(word_ent_info).tanh()
204 | doc_input = self.ent_info_gate(q_for_text.unsqueeze(1), word_ent_info, doc_input.view(B, max_num_doc*max_d_len, -1), word_ent_info_mask)
205 |
206 | d_emb, _ = self.doc_encoder(doc_input.view(B*max_num_doc, max_d_len, -1), doc_len, max_length=doc.size(2))
207 | d_emb = self.hidden_drop(d_emb)
208 |
209 | d_emb = self.ent_info_gate_out(q_for_text.unsqueeze(1), word_ent_info, d_emb.view(B, max_num_doc*max_d_len, -1), word_ent_info_mask).view(B*max_num_doc, max_d_len, -1)
210 |
211 | q_for_text = q_for_text.unsqueeze(1).expand(B, max_num_doc, self.hidden_dim).contiguous()
212 | q_for_text = q_for_text.view(B*max_num_doc, -1) # (B*|D|, h_dim)
213 | d_emb = d_emb.view(B*max_num_doc, max_d_len, -1) # (B*|D|, d_len, h_dim)
214 | q_over_d = torch.bmm(q_for_text.unsqueeze(1), d_emb.transpose(1,2)).squeeze(1) # (B*|D|, d_len)
215 | q_over_d = F.softmax(q_over_d - (1 - doc_mask.view(B*max_num_doc, max_d_len))*1e8, dim=-1)
216 | q_retrieve_d = torch.bmm(q_over_d.unsqueeze(1), d_emb).view(B, max_num_doc, -1) # (B, |D|, h_dim)
217 | ent_linked_doc = (ent_linked_doc_spans.sum(-1) != 0).float() # (B, |C|, |D|)
218 | ent_emb_from_doc = torch.bmm(ent_linked_doc, q_retrieve_d) # (B, |C|, h_dim)
219 | # ent_emb_from_doc = F.dropout(ent_emb_from_doc, 0.5, self.training)
220 |
221 | # retrieve_span
222 | ent_emb_from_span = torch.bmm(feed['ent_link_doc_norm_spans'].float().view(B, max_num_candidates, -1), d_emb.view(B, max_num_doc*max_d_len, -1))
223 | ent_emb_from_span = F.dropout(ent_emb_from_span, 0.2, self.training)
224 |
225 |
226 | # refine KB ent_emb
227 | # refined_ent_emb = self.refine_ent(ent_emb, ent_emb_from_doc)
228 | if self.use_doc:
229 | ent_emb = l_relu(self.attn_match(torch.cat([ent_emb, ent_emb_from_doc, ent_emb_from_span], dim=-1)))
230 | # q_node_emb = self.attn_match_q(q_node_emb)
231 |
232 | ent_scores = (q_node_emb * ent_emb).sum(2)
233 |
234 | answers = feed['answers'].float()
235 | if self.label_smooth:
236 | answers = ((1.0 - self.label_smooth)*answers) + (self.label_smooth/answers.size(1))
237 |
238 | loss = self.loss(ent_scores, feed['answers'].float())
239 |
240 | pred_dist = (ent_scores - (1-ent_mask) * 1e8).sigmoid() * ent_mask
241 | pred = torch.max(ent_scores, dim=1)[1]
242 |
243 | return loss, pred, pred_dist
244 |
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import math
6 |
7 |
class Packed(nn.Module):
    """Run a wrapped RNN over a padded batch whose sequences are in any length order.

    The batch is sorted by decreasing length, packed, fed to the RNN, padded
    again and finally put back into the caller's original order. The wrapped
    RNN must return ``(outputs, (h, c))`` (i.e. behave like an LSTM).
    """

    def __init__(self, rnn):
        super().__init__()
        self.rnn = rnn

    @property
    def batch_first(self):
        """Mirror the ``batch_first`` setting of the wrapped RNN."""
        return self.rnn.batch_first

    def forward(self, inputs, lengths, hidden=None, max_length=None):
        """Encode ``inputs`` and return ``(outputs, (h, c))`` in the input batch order.

        Args:
            inputs: padded batch; the batch axis is 0 when ``batch_first`` else 1.
            lengths: 1-D tensor with the true length of every sequence.
            hidden: optional initial state forwarded to the RNN unchanged.
            max_length: forwarded as ``total_length`` when re-padding the output.
        """
        batch_first = self.batch_first

        # pack_padded_sequence is given the sequences sorted by decreasing length.
        sorted_lens, order = torch.sort(lengths, 0, True)
        sorted_inputs = inputs[order] if batch_first else inputs[:, order]
        packed_in = nn.utils.rnn.pack_padded_sequence(
            sorted_inputs, sorted_lens.tolist(), batch_first=batch_first)

        packed_out, (h, c) = self.rnn(packed_in, hidden)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=batch_first, total_length=max_length)

        # Sorting the permutation yields its inverse, which undoes the reordering.
        _, restore = torch.sort(order, 0)
        outputs = outputs[restore] if batch_first else outputs[:, restore]
        return outputs, (h[:, restore, :], c[:, restore, :])
27 |
def gelu(x):
    """Exact (erf-based) Gaussian Error Linear Unit applied element-wise to ``x``."""
    half_x = x * 0.5
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))
30 |
def l_relu(x, n_slope=0.01):
    """Leaky ReLU of ``x`` whose negative part is scaled by ``n_slope``."""
    return F.leaky_relu(x, negative_slope=n_slope)
33 |
class ConditionGate(nn.Module):
    """Query-conditioned gate that mixes two representations ``x`` and ``y``."""

    def __init__(self, h_dim):
        super(ConditionGate, self).__init__()
        # Maps the two query-interaction vectors to one gate value per feature.
        self.gate = nn.Linear(2*h_dim, h_dim, bias=False)

    def forward(self, q, x, y, gate_mask):
        """Return ``g * x + (1 - g) * y`` with ``g`` computed from ``q``.

        ``q`` is multiplied element-wise with ``x`` and ``y`` (so it must
        broadcast against them); ``gate_mask`` multiplies the gate, so positions
        where it is 0 return ``y`` unchanged.
        """
        interactions = torch.cat([x * q, y * q], dim=-1)
        keep_x = self.gate(interactions).sigmoid() * gate_mask
        return keep_x * x + (1 - keep_x) * y
48 |
49 |
class Fusion(nn.Module):
    """Gated fusion of ``x`` with ``y`` that can fall back to ``x`` itself."""

    def __init__(self, d_hid):
        super(Fusion, self).__init__()
        # Both projections read the same 4*d_hid interaction features.
        self.r = nn.Linear(d_hid*4, d_hid, bias=False)
        self.g = nn.Linear(d_hid*4, d_hid, bias=False)

    def forward(self, x, y):
        """Return ``g * tanh(r(f)) + (1 - g) * x`` where ``f = [x, y, x - y, x * y]``."""
        features = torch.cat([x, y, x - y, x * y], dim=-1)
        candidate = self.r(features).tanh()
        gate = torch.sigmoid(self.g(features))
        return gate * candidate + (1 - gate) * x
61 |
class AttnEncoder(nn.Module):
    """Pool a sequence into one vector with learned, masked self-attention."""

    def __init__(self, d_hid):
        super(AttnEncoder, self).__init__()
        # One scalar attention logit per position.
        self.attn_linear = nn.Linear(d_hid, 1, bias=False)

    def forward(self, x, x_mask):
        """
        x: (B, len, d_hid)
        x_mask: (B, len), 1 for real positions and 0 for padding
        return: (B, d_hid)
        """
        logits = self.attn_linear(x)
        # A large negative offset removes padded positions from the softmax.
        padding_penalty = (1 - x_mask.unsqueeze(2)) * 1e8
        weights = F.softmax(logits - padding_penalty, dim=1)
        return (x * weights).sum(1)
78 |
class BilinearSeqAttn(nn.Module):
    """Bilinear attention of a vector ``y`` over a sequence ``X``.

    * o_i = softmax(x_i' W y) for x_i in X.

    With ``normalize=False`` the raw exponentiated scores are returned instead.
    """

    def __init__(self, x_size, y_size, identity=False, normalize=True):
        super(BilinearSeqAttn, self).__init__()
        self.normalize = normalize

        # identity=True means a plain dot product, without the learned map W.
        self.linear = None if identity else nn.Linear(y_size, x_size)

    def forward(self, x, y, x_mask):
        """
        Args:
            x: batch * len * hdim1
            y: batch * hdim2
            x_mask: batch * len (1 for padding, 0 for true)
        Output:
            alpha = batch * len
        """
        projected = y if self.linear is None else self.linear(y)
        scores = x.bmm(projected.unsqueeze(2)).squeeze(2)
        # Padded positions get -inf so they receive zero weight.
        scores.data.masked_fill_(x_mask.data, -float('inf'))

        if not self.normalize:
            return scores.exp()
        if self.training:
            # Log-probabilities while training (suited to an NLL loss).
            return F.log_softmax(scores, dim=-1)
        # Plain 0-1 probabilities otherwise.
        return F.softmax(scores, dim=-1)
117 |
class SeqAttnMatch(nn.Module):
    """Given sequences X and Y, match sequence Y to each element in X.

    * o_i = sum(alpha_j * y_j) for i in X
    * alpha_j = softmax(y_j * x_i)
    """

    def __init__(self, input_size, identity=False):
        super(SeqAttnMatch, self).__init__()
        # identity=True scores on the raw vectors, without a shared projection.
        self.linear = None if identity else nn.Linear(input_size, input_size)

    def _project(self, seq):
        """Apply the shared linear + ReLU projection to every vector of ``seq``."""
        flat = seq.view(-1, seq.size(2))
        return F.relu(self.linear(flat).view(seq.size()))

    def forward(self, x, y, y_mask):
        """
        Args:
            x: batch * len1 * hdim
            y: batch * len2 * hdim
            y_mask: batch * len2 (1 for padding, 0 for true)
        Output:
            matched_seq: batch * len1 * hdim
        """
        if self.linear is not None:
            x_proj, y_proj = self._project(x), self._project(y)
        else:
            x_proj, y_proj = x, y

        # Pairwise scores between every x position and every y position.
        scores = x_proj.bmm(y_proj.transpose(2, 1))

        # Padded y positions get -inf so they receive zero weight.
        pad = y_mask.unsqueeze(1).expand(scores.size())
        scores.data.masked_fill_(pad.data, -float('inf'))

        len_x, len_y = x.size(1), y.size(1)
        alpha = F.softmax(scores.view(-1, len_y), dim=-1).view(-1, len_x, len_y)

        # Attention-weighted average of the (unprojected) y vectors.
        return alpha.bmm(y)
165 |
166 |
class QueryReform(nn.Module):
    """Update the question vector with the embeddings of its seed entities."""

    def __init__(self, h_dim):
        super(QueryReform, self).__init__()
        self.fusion = Fusion(h_dim)
        # forward() does not use this layer; it stays registered so that the
        # parameter set (and therefore saved state_dict keys) is unchanged.
        self.q_ent_attn = nn.Linear(h_dim, h_dim)

    def forward(self, q_node, ent_emb, seed_info, ent_mask):
        '''
        q_node: (B, h_dim) pooled question vector
        ent_emb: (B, C, h_dim)
        seed_info: (B, C) weight of every candidate as a seed entity
        ent_mask: (B, C); accepted for interface compatibility, currently unused
        return: (B, h_dim)
        '''
        # The previous version also computed a masked attention distribution
        # over the entities here, but its result was never used; that dead
        # computation has been removed (the output is unchanged).

        # Seed-weighted sum of the entity embeddings.
        seed_retrieve = torch.bmm(seed_info.unsqueeze(1), ent_emb).squeeze(1)  # (B, h_dim)
        return self.fusion(q_node, seed_retrieve)
194 |
195 |
196 | # retrieved = self.transform(torch.cat([seed_retrieve, attn_retrieve], dim=-1)).relu()
197 | # gate_val = self.gate(torch.cat([q.squeeze(1), seed_retrieve, attn_retrieve], dim=-1)).sigmoid()
198 | # return self.fusion(q.squeeze(1), retrieved).unsqueeze(1)
199 | # return (gate_val * q.squeeze(1) + (1 - gate_val) * torch.tanh(self.transform(torch.cat([q.squeeze(1), seed_retrieve, attn_retrieve], dim=-1)))).unsqueeze(1)
200 |
201 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=$1 python train.py --model_id $2 --num_layer 1 --max_num_neighbors 50 --label_smooth 0.1 --data_folder datasets/webqsp/full/
--------------------------------------------------------------------------------
/run_with_doc.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=$1 python train.py --model_id $2 --num_layer 1 --max_num_neighbors 100 --use_doc --data_folder datasets/webqsp/kb_05/ --eps 0.12
--------------------------------------------------------------------------------
/script.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | from tqdm import tqdm
4 | from collections import Counter
5 | from itertools import izip
6 |
def combine_dist(dist1, dist2, w1):
    """Blend two ``{id: probability}`` dicts into a new dict.

    Ids present in both get ``(1 - w1) * dist2[id] + w1 * dist1[id]``; ids found
    in only one of the inputs keep that input's probability unweighted. Neither
    argument is modified.
    """
    merged = dist2.copy()
    for key in dist1:
        score = dist1[key]
        if key not in merged:
            merged[key] = score
        else:
            merged[key] = (1 - w1) * merged[key] + w1 * score
    return merged
15 |
def get_one_f1(entities, dist, eps, answers):
    """Score one question: threshold ``dist`` at ``eps`` and evaluate the result.

    Every entity in ``entities`` whose probability exceeds ``eps`` is predicted;
    the entity with the highest strictly positive probability is the top-1
    prediction (``-1`` if there is none). Returns whatever ``cal_eval_metric``
    returns for those predictions.
    """
    top_entity, top_prob = -1, 0.0
    selected = []
    for candidate in entities:
        prob = dist[candidate]
        if prob > top_prob:
            top_entity, top_prob = candidate, prob
        if prob > eps:
            selected.append(candidate)

    return cal_eval_metric(top_entity, selected, answers)
28 |
def cal_eval_metric(best_pred, preds, answers):
    """Return ``(precision, recall, f1, hits)`` for one question.

    ``preds`` are the predicted entities, ``best_pred`` the top-1 prediction and
    ``answers`` the gold answers. A question without gold answers scores 1.0 on
    everything when nothing was predicted, otherwise precision and f1 drop to 0.
    """
    matched = [entity in answers for entity in preds]
    total = float(len(matched))
    correct = float(sum(matched))

    if len(answers) == 0:
        # precision, recall, f1, hits
        return (1.0, 1.0, 1.0, 1.0) if total == 0 else (0.0, 1.0, 0.0, 1.0)

    hits = float(best_pred in answers)
    if total == 0:
        # Nothing predicted: precision is defined as 1.0, recall and f1 as 0.0.
        return 1.0, 0.0, 0.0, hits

    precision = correct / total
    recall = correct / len(answers)
    if precision == 0 or recall == 0:
        f1 = 0.0
    else:
        f1 = 2.0 / (1.0 / precision + 1.0 / recall)
    return precision, recall, f1, hits
48 |
def compare_pr(kb_pred_file, doc_pred_file, hybrid_pred_file, w_kb, eps_doc, eps_kb, eps_ensemble, eps_hybrid, eps_ensemble_all, w_ensemble_all=0.3):
    """Compare KB-only, text-only, hybrid and ensembled predictions and print the averages.

    The three files are read in lockstep, one JSON object per line. Each object
    must provide ``'answers'`` (gold answers, identical across the three files)
    and ``'dist'`` (a mapping from entity id to probability).

    Args:
        kb_pred_file, doc_pred_file, hybrid_pred_file: paths of the prediction files.
        w_kb: weight of the KB distribution in the KB + text ensemble ("late fusion").
        eps_doc, eps_kb, eps_ensemble, eps_hybrid, eps_ensemble_all: probability
            thresholds above which an entity counts as predicted in each setting.
        w_ensemble_all: weight of the late-fusion distribution when it is blended
            with the hybrid one ("early & late fusion"); defaults to the
            previously hard-coded 0.3.

    Raises:
        ValueError: if the files contain no lines (there is nothing to average).
    """
    # `unicode` only exists on Python 2; type(u'') is the text type on 2 and 3.
    text_type = type(u'')

    # (printed heading, per-question (precision, recall, f1, hits) tuples)
    results = [
        ('text only setting:', []),
        ('kb only setting:', []),
        ('late fusion:', []),
        ('early fusion:', []),
        ('early & late fusion:', []),
    ]

    with open(kb_pred_file) as f_kb, open(doc_pred_file) as f_doc, open(hybrid_pred_file) as f_hybrid:
        for line_kb, line_doc, line_hybrid in tqdm(zip(f_kb, f_doc, f_hybrid)):
            line_kb = json.loads(line_kb)
            line_doc = json.loads(line_doc)
            line_hybrid = json.loads(line_hybrid)
            assert line_kb['answers'] == line_doc['answers'] == line_hybrid['answers']
            answers = set([text_type(answer) for answer in line_kb['answers']])

            dist_kb = line_kb['dist']
            dist_doc = line_doc['dist']
            dist_hybrid = line_hybrid['dist']
            dist_ensemble = combine_dist(dist_kb, dist_doc, w_kb)
            dist_ensemble_all = combine_dist(dist_ensemble, dist_hybrid, w1=w_ensemble_all)

            kb_entities = set(dist_kb.keys())
            doc_entities = set(dist_doc.keys())
            either_entities = kb_entities | doc_entities
            assert either_entities == set(dist_hybrid.keys())

            # Same order as `results` above.
            evaluations = [
                (doc_entities, dist_doc, eps_doc),
                (kb_entities, dist_kb, eps_kb),
                (either_entities, dist_ensemble, eps_ensemble),
                (either_entities, dist_hybrid, eps_hybrid),
                (either_entities, dist_ensemble_all, eps_ensemble_all),
            ]
            for (_, scores), (entities, dist, eps) in zip(results, evaluations):
                scores.append(get_one_f1(entities, dist, eps, answers))

    if not results[0][1]:
        raise ValueError('no predictions found in the given files')

    for title, scores in results:
        count = len(scores)
        print(title)
        print('hits: ', sum(s[3] for s in scores) / count)
        print('precision: ', sum(s[0] for s in scores) / count)
        print('recall: ', sum(s[1] for s in scores) / count)
        print('f1: ', sum(s[2] for s in scores) / count)
        print('\n')
145 |
146 |
if __name__ == "__main__":
    # usage: python script.py <dataset> <kb_pred_file> <doc_pred_file> <hybrid_pred_file>
    if len(sys.argv) < 5:
        sys.exit("usage: python script.py <wikimovie|webqsp> <kb_pred_file> <doc_pred_file> <hybrid_pred_file>")
    dataset, pred_kb_file, pred_doc_file, pred_hybrid_file = sys.argv[1:5]

    # Per-dataset ensemble weight and decision thresholds.
    if dataset == "wikimovie":
        w_kb = 0.9
        eps_doc, eps_kb, eps_ensemble, eps_hybrid, eps_ensemble_all = 0.5, 0.55, 0.6, 0.5, 0.55
    elif dataset == "webqsp":
        w_kb = 1.0
        eps_doc, eps_kb, eps_ensemble, eps_hybrid, eps_ensemble_all = 0.15, 0.2, 0.2, 0.2, 0.3
    else:
        # An explicit exception instead of `assert False`, which `python -O` removes.
        raise ValueError("dataset not recognized")

    compare_pr(pred_kb_file, pred_doc_file, pred_hybrid_file, w_kb, eps_doc, eps_kb, eps_ensemble, eps_hybrid, eps_ensemble_all)
162 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CUDA_VISIBLE_DEVICES=$1 python train.py --model_id $2 --num_layer 1 --max_num_neighbors 50 --mode test --eps 0.08 --data_folder datasets/webqsp/kb_03/
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import random
4 |
5 | from data_generator import DataLoader
6 | from model import KAReader
7 | from util import get_config, cal_accuracy, load_documents
8 |
9 | from tensorboardX import SummaryWriter
10 |
def f1_and_hits(answers, candidate2prob, eps):
    """Return ``(f1, hits@1)`` for one question.

    Candidates whose probability exceeds ``eps`` are retrieved; the candidate
    with the highest strictly positive probability is the top-1 prediction.
    A question without gold answers gets hits 1.0, and f1 1.0 only when nothing
    was retrieved.
    """
    top_candidate, top_prob = -1, 0
    num_retrieved = 0
    num_correct = 0
    for cand, prob in candidate2prob.items():
        if prob > top_prob:
            top_candidate, top_prob = cand, prob
        if prob > eps:
            num_retrieved += 1
            if cand in answers:
                num_correct += 1

    if len(answers) == 0:
        return (1.0, 1.0) if num_retrieved == 0 else (0.0, 1.0)

    hits = float(top_candidate in answers)
    if num_retrieved == 0:
        return 0.0, hits

    precision = num_correct / num_retrieved
    recall = num_correct / len(answers)
    if precision == 0 or recall == 0:
        return 0.0, hits
    return 2.0 / (1.0 / precision + 1.0 / recall), hits
36 |
def get_best_ans(candidate2prob):
    """Return the candidate with the highest strictly positive probability.

    Ties keep the candidate seen first; ``-1`` is returned when no candidate
    has a probability above 0.
    """
    winner, top = -1, 0
    for cand in candidate2prob:
        score = candidate2prob[cand]
        if score > top:
            winner, top = cand, score
    return winner
44 |
def train(cfg):
    """Train a KAReader model, validate after every epoch, then test the final model.

    Reads these ``cfg`` keys: ``model_id``, ``data_folder``, ``mode``,
    ``<mode>_documents``, ``learning_rate``, ``lr_schedule``, ``num_epoch``,
    ``gradient_clip``, ``eps`` and ``name``. The checkpoint with the best
    validation hits is written to ``model/<name>/<model_id>_best.pt`` and the
    final weights to ``model/<name>/<model_id>_final.pt``. A CUDA device is required.
    """
    import os  # local import: this module's header does not import os

    tf_logger = SummaryWriter('tf_logs/' + cfg['model_id'])

    # train and test share the same set of documents
    documents = load_documents(cfg['data_folder'] + cfg['{}_documents'.format(cfg['mode'])])

    train_data = DataLoader(cfg, documents)
    valid_data = DataLoader(cfg, documents, mode='dev')

    model = KAReader(cfg)
    model = model.to(torch.device('cuda'))

    # Must be a list, not a `filter` iterator: the optimizer consumes its
    # argument, so a one-shot iterator would reach clip_grad_norm_ empty and
    # gradient clipping would silently do nothing.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.Adam(trainable, lr=cfg['learning_rate'])

    if cfg['lr_schedule']:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, [30], gamma=0.5)

    # torch.save does not create missing directories.
    os.makedirs('model/{}'.format(cfg['name']), exist_ok=True)

    best_val_f1 = 0
    best_val_hits = 0
    for epoch in range(cfg['num_epoch']):
        model.train()
        train_loss = []
        for feed in train_data.batcher(shuffle=True):
            loss, pred, pred_dist = model(feed)
            train_loss.append(loss.item())
            optim.zero_grad()
            loss.backward()
            if cfg['gradient_clip'] != 0:
                torch.nn.utils.clip_grad_norm_(trainable, cfg['gradient_clip'])
            optim.step()
        tf_logger.add_scalar('avg_batch_loss', np.mean(train_loss), epoch)

        val_f1, val_hits = test(model, valid_data, cfg['eps'])
        if cfg['lr_schedule']:
            scheduler.step()
        tf_logger.add_scalar('eval_f1', val_f1, epoch)
        tf_logger.add_scalar('eval_hits', val_hits, epoch)
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
        if val_hits > best_val_hits:
            # Model selection is driven by validation hits, not f1.
            best_val_hits = val_hits
            torch.save(model.state_dict(), 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id']))
        print('evaluation best f1:{} current:{}'.format(best_val_f1, val_f1))
        print('evaluation best hits:{} current:{}'.format(best_val_hits, val_hits))

    print('save final model')
    torch.save(model.state_dict(), 'model/{}/{}_final.pt'.format(cfg['name'], cfg['model_id']))
    tf_logger.close()

    print('\n..........Finished training, start testing.......')

    test_data = DataLoader(cfg, documents, mode='test')
    model.eval()
    print('finished training, testing final model...')
    test(model, test_data, cfg['eps'])
110 |
111 | # print('testing best model...')
112 | # model_save_path = 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id'])
113 | # model.load_state_dict(torch.load(model_save_path))
114 | # model.eval()
115 | # test(model, test_data, cfg['eps'])
116 |
117 |
def test(model, test_data, eps):
    """Evaluate ``model`` on ``test_data`` and return ``(mean f1, mean hits@1)``.

    Args:
        model: callable returning ``(loss, pred, pred_dist)`` for a batch.
        test_data: loader exposing ``batcher()`` and ``id2entity``.
        eps: probability threshold above which a candidate counts as an answer.

    The model is switched to eval mode for the evaluation and put back into the
    mode it was in when the function was called.
    """
    was_training = model.training
    model.eval()

    # Padded candidate slots carry an id one past the last real entity id.
    pad_ent_id = len(test_data.id2entity)
    f1s, hits = [], []

    # No gradients are needed here; skipping the autograd graph saves memory.
    with torch.no_grad():
        for feed in test_data.batcher():
            _, _, pred_dist = model(feed)
            batch_answers = feed['answers_']
            batch_candidates = feed['candidate_entities']
            for batch_id in range(pred_dist.size(0)):
                candidates = batch_candidates[batch_id, :].tolist()
                probs = pred_dist[batch_id, :].tolist()
                candidate2prob = {c: p for c, p in zip(candidates, probs) if c != pad_ent_id}
                f1, hit = f1_and_hits(batch_answers[batch_id], candidate2prob, eps)
                f1s.append(f1)
                hits.append(hit)

    print('evaluation.......')
    print('how many eval samples......', len(f1s))
    print('avg_f1', np.mean(f1s))
    print('avg_hits', np.mean(hits))

    model.train(was_training)
    return np.mean(f1s), np.mean(hits)
158 |
if __name__ == "__main__":
    cfg = get_config()

    # Seed every random number generator in use.
    random.seed(cfg['seed'])
    np.random.seed(cfg['seed'])
    torch.manual_seed(cfg['seed'])
    torch.cuda.manual_seed_all(cfg['seed'])

    if cfg['mode'] == 'train':
        train(cfg)
    elif cfg['mode'] == 'test':
        # Evaluate the best checkpoint saved by an earlier training run.
        documents = load_documents(cfg['data_folder'] + cfg['{}_documents'.format(cfg['mode'])])
        test_data = DataLoader(cfg, documents, mode='test')
        model = KAReader(cfg)
        model = model.to(torch.device('cuda'))
        model_save_path = 'model/{}/{}_best.pt'.format(cfg['name'], cfg['model_id'])
        model.load_state_dict(torch.load(model_save_path))
        model.eval()
        test(model, test_data, cfg['eps'])
    else:
        # An explicit exception instead of `assert False`, which `python -O` removes.
        raise ValueError("--train or --test?")
179 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import nltk
3 | import numpy as np
4 | import os
5 | import torch
6 | import yaml
7 |
8 | from collections import Counter
9 |
10 | from torch.autograd import Variable
11 | from tqdm import tqdm
12 |
13 | import argparse
14 |
def get_config(config_path=None):
    """Build the experiment configuration dict and print it.

    Without ``config_path`` the configuration comes from the command line
    (``argparse``), extended with ``type_rels``, ``to_save_model``,
    ``save_model_file`` and ``pred_file``. With ``config_path`` the dict is read
    from that YAML file instead.

    Returns:
        dict: option name -> value.
    """
    if not config_path:
        parser = argparse.ArgumentParser()

        # datasets
        parser.add_argument('--name', default='webqsp', type=str)
        parser.add_argument('--data_folder', default='datasets/webqsp/kb_03/', type=str)
        parser.add_argument('--train_data', default='train.json', type=str)
        parser.add_argument('--train_documents', default='documents.json', type=str)
        parser.add_argument('--dev_data', default='dev.json', type=str)
        parser.add_argument('--dev_documents', default='documents.json', type=str)
        parser.add_argument('--test_data', default='test.json', type=str)
        parser.add_argument('--test_documents', default='documents.json', type=str)
        parser.add_argument('--max_query_word', default=10, type=int)
        parser.add_argument('--max_document_word', default=50, type=int)
        parser.add_argument('--max_char', default=25, type=int)
        parser.add_argument('--max_num_neighbors', default=100, type=int)
        parser.add_argument('--max_rel_words', default=8, type=int)

        # embeddings
        parser.add_argument('--word2id', default='glove_vocab.txt', type=str)
        parser.add_argument('--relation2id', default='relations.txt', type=str)
        parser.add_argument('--entity2id', default='entities.txt', type=str)
        parser.add_argument('--char2id', default='chars.txt', type=str)
        parser.add_argument('--word_emb_file', default='glove_word_emb.npy', type=str)
        parser.add_argument('--entity_emb_file', default='entity_emb_100d.npy', type=str)
        parser.add_argument('--rel_word_ids', default='rel_word_idx.npy', type=str)

        # dimensions, layers, dropout
        parser.add_argument('--num_layer', default=1, type=int)
        parser.add_argument('--entity_dim', default=100, type=int)
        parser.add_argument('--word_dim', default=300, type=int)
        parser.add_argument('--hidden_drop', default=0.2, type=float)
        parser.add_argument('--word_drop', default=0.2, type=float)

        # optimization
        parser.add_argument('--num_epoch', default=100, type=int)
        parser.add_argument('--batch_size', default=8, type=int)
        parser.add_argument('--gradient_clip', default=1.0, type=float)
        parser.add_argument('--learning_rate', default=0.001, type=float)
        parser.add_argument('--seed', default=19940715, type=int)
        parser.add_argument('--lr_schedule', action='store_true')
        parser.add_argument('--label_smooth', default=0.1, type=float)
        parser.add_argument('--fact_drop', default=0, type=float)

        # model options
        parser.add_argument('--use_doc', action='store_true')
        parser.add_argument('--use_inverse_relation', action='store_true')
        parser.add_argument('--model_id', default='debug', type=str)
        parser.add_argument('--load_model_file', default=None, type=str)
        parser.add_argument('--mode', default='train', type=str)
        parser.add_argument('--eps', default=0.05, type=float) # threshold for f1

        args = parser.parse_args()

        if args.name == 'webqsp':
            args.type_rels = ['', '', '', '', '', '', '', '', '', '']
        else:
            args.type_rels = []

        config = vars(args)
        config['to_save_model'] = True # always save model
        config['save_model_file'] = 'model/' + config['name'] + '/best_{}.pt'.format(config['model_id'])
        config['pred_file'] = 'results/' + config['name'] + '/best_{}.pred'.format(config['model_id'])
    else:
        with open(config_path, "r") as setting:
            # safe_load builds plain Python data only; yaml.load without an
            # explicit Loader can construct arbitrary objects and is rejected
            # by current PyYAML releases.
            config = yaml.safe_load(setting)

    print('-'* 10 + 'Experiment Config' + '-' * 10)
    for k, v in config.items():
        print(k + ': ', v)
    print('-'* 10 + 'Experiment Config' + '-' * 10 + '\n')

    return config
89 |
def use_cuda(var):
    """Return ``var`` moved to the GPU when CUDA is available, otherwise ``var`` itself."""
    return var.cuda() if torch.cuda.is_available() else var
95 |
def save_model(the_model, path):
    """Save ``the_model`` with ``torch.save`` without overwriting ``path``.

    If ``path`` already exists the object is written to ``path + '_copy'`` instead.
    """
    target = path + '_copy' if os.path.exists(path) else path
    print("saving model to ...", target)
    torch.save(the_model, target)
101 |
102 |
def load_model(path):
    """Load and return the object stored at ``path`` with ``torch.load``.

    Fails an assertion with a "cannot find model" message when ``path`` does
    not exist.
    """
    assert os.path.exists(path), 'cannot find model: ' + path
    print("loading model from ...", path)
    return torch.load(path)
108 |
def load_dict(filename):
    """Read one entry per line from ``filename`` and return ``{entry: id}``.

    Lines are stripped of surrounding whitespace. Each entry is assigned the
    number of distinct entries stored at the moment it is read, so a repeated
    line overwrites the id of its earlier occurrence.
    """
    mapping = {}
    with open(filename) as handle:
        for raw in handle:
            mapping[raw.strip()] = len(mapping)
    return mapping
116 |
def load_documents(document_file):
    """Load and tokenize the passages stored one JSON object per line.

    Every passage gets a ``'tokens'`` entry: the NLTK word tokens of
    ``passage['document']['text']``, preceded by the title tokens and a ``'|'``
    separator when the passage has a ``'title'``.

    Returns:
        dict: ``int(passage['documentId'])`` -> passage dict.
    """
    print('loading document from', document_file)
    with open(document_file) as f_in:
        # Read everything up front so the progress bar knows the total.
        lines = list(f_in)

    documents = dict()
    for line in tqdm(lines):
        passage = json.loads(line)
        tokens = nltk.word_tokenize(passage['document']['text'])
        if 'title' in passage:
            tokens = nltk.word_tokenize(passage['title']['text']) + ['|'] + tokens
        passage['tokens'] = tokens
        documents[int(passage['documentId'])] = passage
    return documents
133 |
def cal_accuracy(pred, answer_dist):
    """
    pred: batch_size, index of the predicted entity for every example
    answer_dist: batch_size, max_local_entity; non-zero marks a correct entity
    return: (share of examples predicted correctly,
             share of examples with at least one correct entity)
    """
    batch_size = len(pred)
    num_correct = 0.0
    for row, col in enumerate(pred):
        num_correct += (answer_dist[row, col] != 0)
    num_answerable = float(sum(1 for dist in answer_dist if np.sum(dist) != 0))
    return num_correct / batch_size, num_answerable / batch_size
147 |
def char_vocab(word2id, data_path):
    """Write the character vocabulary of a word list to ``data_path``.

    ``word2id`` is the path of a file with one word per line. The output has one
    symbol per line: ``__unk__`` first (id 0), then every character in order of
    first appearance. The length of the longest word is printed.
    """
    seen = Counter()
    longest = 0
    with open(word2id) as vocab_file:
        for line in vocab_file:
            token = line.strip()
            longest = max(longest, len(token))
            seen.update(token)

    # Id 0 is reserved for the unknown symbol; characters follow from id 1.
    ordered = ['__unk__'] + list(seen.keys())

    with open(data_path, 'w') as out:
        out.write('\n'.join(ordered))

    print(longest)
173 |
class LeftMMFixed(torch.autograd.Function):
    """
    Matrix multiplication of a sparse matrix with a dense matrix, returning a dense one.

    The sparse operand is treated as a fixed weight: the backward pass returns
    ``None`` for it and only propagates a gradient to the dense operand.

    Implemented as a static-method autograd function (invoke it with
    ``LeftMMFixed.apply(sparse_weights, x)``). The earlier instance-based form
    with a non-static ``forward`` is rejected by recent PyTorch releases, and
    it also cached the first ``sparse_weights`` on the instance, so a reused
    instance would have multiplied with stale weights.
    """

    @staticmethod
    def forward(ctx, sparse_weights, x):
        """Return ``sparse_weights @ x`` and remember the weights for backward."""
        # Kept as a plain attribute on ctx: it is a non-differentiable input
        # that backward only needs to read.
        ctx.sparse_weights = sparse_weights
        return torch.mm(sparse_weights, x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(None, sparse_weights^T @ grad_output)``."""
        return None, torch.mm(ctx.sparse_weights.t(), grad_output)

    def __call__(self, sparse_weights, x):
        """Compatibility shim so the legacy ``LeftMMFixed()(S, x)`` form still works."""
        return type(self).apply(sparse_weights, x)
192 |
193 |
def sparse_bmm(X, Y):
    """Batch multiply X and Y where X is sparse, Y is dense.

    Args:
        X: Sparse tensor of size BxMxN. Consists of two tensors,
            I:3xZ indices, and V:1xZ values.
        Y: Dense tensor of size BxNxK.
    Returns:
        batched-matmul(X, Y): BxMxK

    The product is computed as a single 2-D sparse-dense multiplication: each
    non-zero entry of X selects its matching row of Y, and a helper sparse
    matrix of size (B*M)xZ scatters the weighted rows back to the flattened
    (batch, row) positions. The helper tensors are created on the same device
    as X's indices and values, so the function is not tied to CUDA.
    """
    I = X._indices()
    V = X._values()
    B, M, _ = X.size()
    _, _, K = Y.size()
    Z = I.size()[1]
    # One row of Y (length K) per non-zero entry of X: ZxK.
    lookup = Y[I[0, :], I[2, :], :]
    # Row index = flattened (batch, row) position; column index = entry number.
    flat_rows = I[0, :] * M + I[1, :]
    entry_ids = torch.arange(Z, dtype=torch.long, device=I.device)
    X_I = torch.stack((flat_rows, entry_ids), 0)
    S = torch.sparse_coo_tensor(X_I, V, torch.Size([B * M, Z]), device=V.device)
    prod_op = LeftMMFixed()
    prod = prod_op(S, lookup)
    return prod.view(B, M, K)
214 |
if __name__ == "__main__":
    # Script entry point: build the character vocabulary file from the
    # kb_05 WebQSP word vocabulary.
    # (Disabled alternative: load_documents('datasets/wikimovie/full_doc/documents.json'))
    char_vocab(
        'datasets/webqsp/kb_05/vocab.txt',
        'datasets/webqsp/kb_05/chars.txt',
    )
218 |
--------------------------------------------------------------------------------