├── __init__.py
├── README.md
├── tf_util.py
├── query.py
└── qa_network.py

/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## General Information

This repository contains an implementation of Query-Answer Networks ([qa_network.py](./qa_network.py)). Interacting with the model
requires the construction of ContextQueries objects ([query.py](./query.py)). Additional information can be found in comments
within these files. A minimal example can be found at the end of [qa_network.py](./qa_network.py#L531).

If you use code from this repository, please make sure to cite:

**Separating Answers from Queries for Neural Reading Comprehension**. Dirk Weissenborn. [*arXiv:1607.03316*](http://arxiv.org/abs/1607.03316).

## Installation

* requires TensorFlow 0.9
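
## Usage

A minimal sketch of driving the model end-to-end (it mirrors `test_model` at the end of
[qa_network.py](./qa_network.py); all sizes and symbol ids are illustrative toy values):

```python
import tensorflow as tf
from query import ContextQueries, ClozeQuery
from qa_network import QAModel

context = [0, 1, 2]  # a context is a list of symbol ids
# one cloze query over the span [0, 1) with answer 0 and negative candidates [2, 1]
queries = [ContextQueries(context, [ClozeQuery(context, 0, 1, 0, 0, [2, 1])])]

model = QAModel(size=10, batch_size=1, vocab_size=5, answer_vocab_size=5, max_length=5)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    loss = model.step(sess, queries)[0]                        # one training update
    scores = model.run(sess, model.scores_with_negs, queries)  # score answer + candidates
```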
--------------------------------------------------------------------------------
/tf_util.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def batch_dot(t1, t2):
    """Row-wise dot product of two [batch_size, size] tensors; returns a [batch_size] tensor."""
    t1_e = tf.expand_dims(t1, 1)
    t2_e = tf.expand_dims(t2, 2)
    return tf.squeeze(tf.batch_matmul(t1_e, t2_e), [1, 2])


# computes all tensors participating in the forward pass from input_tensors to output_tensors
def get_tensors(output_tensors, input_tensors, include_out=True, current_path=None):
    res = set()
    for o in output_tensors:
        if o not in input_tensors:  # we do not want to add inputs
            current_new = set()
            if include_out:
                current_new.add(o)  # we do not add o directly to res
            if current_path:
                current_new = current_new.union(current_path)
            res = res.union(get_tensors(o.op.inputs, input_tensors, True, current_new))
        else:
            # only keep paths leading to inputs
            if current_path:
                res = res.union(current_path)
    return res
--------------------------------------------------------------------------------
/query.py:
--------------------------------------------------------------------------------
"""
This file contains abstractions for defining queries and their corresponding support in the form of
supporting query-answer pairs. Each context (list of symbol ids) can contain multiple
cloze-style queries.

A query can have additional support, which is a list of additional query-answer pairs in the same
format as normal queries. They can be used to help answer the original query.
"""


class ContextQueries:
    """
    Wraps all queries for a certain context into a single object. If used as support, context can be specified
    as None, which means that the context of the actual query serves as context for the support.
    """

    def __init__(self, context, queries, support=None, collaborative_support=False, source=None):
        """
        :param context: list of symbol ids that should be the same for all queries
        :param queries: ClozeQueries within this context
        :param support: None or list of ContextQueries; support defined here serves as support for all
            queries of this object
        :param collaborative_support: used for answering queries collaboratively (experimental, not used so far)
        :param source: optional, can be used to keep track of the context origin
        """
        assert all((isinstance(q, ClozeQuery) and q.context == context for q in queries)), \
            "Context queries must share same context."
        self.context = context
        self.queries = queries
        self.collaborative_support = collaborative_support
        # supporting evidence for all queries
        assert support is None or all((isinstance(q, ContextQueries) for q in support)), \
            "Support must be a list of ContextQueries"
        self.support = support
        self.source = source  # information about the source of this query (optional)


class ClozeQuery:
    """
    Cloze queries are fully defined by a span of interest (start and end) within a certain context.
    Note, the qa model embeds a query by its surrounding context (everything outside the span). If you
    want the model to also encode the span itself, simply use the negative span, i.e., swap start and end.
    If answer is None, only candidates will be scored in the QANetwork. If answer is defined, the model will
    score "[answer] + candidates".
    """

    def __init__(self, context, start, end, answer, answer_word, candidates, support=None):
        """
        :param context: list of symbol ids
        :param start: start of the span of the query
        :param end: end of the span of the query
        :param answer: id of the answer to be predicted for the given span; if None, only candidates are
            scored in the QA Network
        :param answer_word: symbol id of the respective answer that is used to refine the query between hops
            (note, the answer vocab can differ from the input vocab, hence the distinction between answer
            and answer_word)
        :param candidates: answer candidates of this query without the answer itself
        :param support: None or list of ContextQueries
        """
        self.context = context
        self.start = start
        self.end = end
        self.answer = answer
        self.answer_word = answer_word
        self.candidates = candidates
        assert support is None or all((isinstance(q, ContextQueries) for q in support)), \
            "Support must be a list of ContextQueries"
        self.support = support


def flatten_queries(context_queries_list):
    ret = []
    for qs in context_queries_list:
        ret.extend(qs.queries)
    return ret
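
# Usage sketch (toy symbol ids, mirroring test_model in qa_network.py): a query over
# context [0, 1, 2] whose span is [0, 1), answered by symbol 0, with one supporting
# query-answer pair from another context:
#
#   support = [ContextQueries([1, 2, 0], [ClozeQuery([1, 2, 0], 1, 2, 2, 2, [0, 1])])]
#   query = ClozeQuery([0, 1, 2], 0, 1, 0, 0, [2, 1], support=support)
#   batch = [ContextQueries([0, 1, 2], [query])]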
--------------------------------------------------------------------------------
/qa_network.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.rnn_cell import *
from tensorflow.python.ops.rnn import dynamic_rnn
from query import *
import tf_util


def default_init():
    return tf.random_normal_initializer(0.0, 0.1)


class QAModel:
    """
    A QAModel embeds ClozeQueries (see query.py) and their respective support to answer each query by selecting
    one of the specified answer candidates. This is an implementation of the system described in
    "Separating Answers from Queries for Neural Reading Comprehension" (arXiv:1607.03316).
    """

    def __init__(self, size, batch_size, vocab_size, answer_vocab_size, max_length, is_train=True, learning_rate=1e-2,
                 composition="GRU", max_hops=0, devices=None, keep_prob=1.0):
        """
        :param size: size of hidden states
        :param batch_size: initial batch_size (adapts automatically)
        :param vocab_size: size of the input vocabulary (vocabulary of contexts)
        :param answer_vocab_size: size of the answer (candidate) vocabulary
        :param max_length: maximum length of an individual context
        :param is_train: whether to build the training part of the graph (loss, optimizer, updates)
        :param learning_rate: initial learning rate of the Adam optimizer
        :param composition: "GRU", "LSTM", "BiGRU" are possible
        :param max_hops: maximum number of hops; can be set manually to something lower by assigning a different
            value to the variable (self.)num_hops, which is initialized with max_hops
        :param devices: defaults to ["/cpu:0"], but can be a list of up to 3 devices. The model is automatically
            partitioned onto the different devices.
        :param keep_prob: 1.0 - dropout rate, applied to the input embeddings
        """
        self._vocab_size = vocab_size
        self._max_length = max_length
        self._size = size
        self._batch_size = batch_size
        self._is_train = is_train
        self._composition = composition
        self._max_hops = max_hops
        self._device0 = devices[0] if devices is not None else "/cpu:0"
        self._device1 = devices[1 % len(devices)] if devices is not None else "/cpu:0"
        self._device2 = devices[2 % len(devices)] if devices is not None else "/cpu:0"

        self._init = tf.random_normal_initializer(0.0, 0.1)
        with tf.device(self._device0):
            with tf.variable_scope(self.name(), initializer=tf.contrib.layers.xavier_initializer()):
                self._init_inputs()
                self.keep_prob = tf.get_variable("keep_prob", [], initializer=tf.constant_initializer(keep_prob))
                with tf.device("/cpu:0"):
                    # embeddings
                    self.output_embedding = tf.get_variable("E_candidate", [answer_vocab_size, self._size],
                                                            initializer=self._init)
                    self.input_embedding = tf.get_variable("E_words", [vocab_size, self._size],
                                                           initializer=self._init)
                    answer, _ = tf.dynamic_partition(self._answer_input, self._query_partition, 2)
                    lookup_individual = tf.nn.embedding_lookup(self.output_embedding, answer)
                    cands, _ = tf.dynamic_partition(self._answer_candidates, self._query_partition, 2)
                    self.candidate_lookup = tf.nn.embedding_lookup(self.output_embedding, cands)

                self.num_hops = tf.Variable(self._max_hops, trainable=False, name="num_queries")
                self.query = self._comp_f()
                answer = self._retrieve_answer(self.query)
                self.score = tf_util.batch_dot(lookup_individual, answer)
                self.scores_with_negs = self._score_candidates(answer)

                if is_train:
                    self.learning_rate = tf.Variable(float(learning_rate), trainable=False, name="lr")
                    self.global_step = tf.Variable(0, trainable=False, name="step")

                    self.opt = tf.train.AdamOptimizer(self.learning_rate)

                    current_batch_size = tf.gather(tf.shape(self.scores_with_negs), [0])

                    loss = math_ops.reduce_sum(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(self.scores_with_negs,
                                                                       tf.tile(tf.constant([0], tf.int64),
                                                                               current_batch_size)))

                    train_params = tf.trainable_variables()
                    self.training_weight = tf.Variable(1.0, trainable=False, name="training_weight")

                    self._loss = loss / math_ops.cast(current_batch_size, tf.float32)
                    self._grads = tf.gradients(self._loss, train_params, self.training_weight,
                                               colocate_gradients_with_ops=True)

                    if len(train_params) > 0:
                        grads, _ = tf.clip_by_global_norm(self._grads, 5.0)
                        self._update = self.opt.apply_gradients(zip(grads, train_params),
                                                                global_step=self.global_step)
                    else:
                        self._update = tf.assign_add(self.global_step, 1)
        self.saver = tf.train.Saver(tf.all_variables(), max_to_keep=1)
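
    # Note on the loss above: it is softmax cross-entropy over scores_with_negs with
    # gold label index 0, which works because _add_example always puts the provided
    # answer first in each candidate list (dummy candidates are masked out via
    # _candidate_mask in _finish_adding_examples).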

    def _score_candidates(self, answer):
        return tf.squeeze(tf.batch_matmul(self.candidate_lookup, tf.expand_dims(answer, 2)), [2]) + \
               self._candidate_mask  # number of negative candidates can vary for each example

    def _composition_function(self, inputs, length, init_state=None):
        if self._composition == "GRU":
            cell = GRUCell(self._size)
            return dynamic_rnn(cell, inputs, sequence_length=length, time_major=True,
                               initial_state=init_state, dtype=tf.float32)[0]
        elif self._composition == "LSTM":
            cell = BasicLSTMCell(self._size)
            # LSTM state is [c, h]; use zeros for the cell state and init_state for the output state
            init_state = tf.concat(1, [tf.zeros_like(init_state, tf.float32), init_state]) \
                if init_state is not None else None
            outs = dynamic_rnn(cell, inputs, sequence_length=length, time_major=True,
                               initial_state=init_state, dtype=tf.float32)[0]
            return outs
        elif self._composition == "BiGRU":
            cell = GRUCell(self._size // 2, self._size)
            init_state_fw, init_state_bw = tf.split(1, 2, init_state) if init_state is not None else (None, None)
            with tf.variable_scope("forward"):
                fw_outs = dynamic_rnn(cell, inputs, sequence_length=length, time_major=True,
                                      initial_state=init_state_fw, dtype=tf.float32)[0]
            with tf.variable_scope("backward"):
                # run the backward pass on the reversed sequence and undo the reversal afterwards
                rev_inputs = tf.reverse_sequence(inputs, length, 0, 1)
                bw_outs = dynamic_rnn(cell, rev_inputs, sequence_length=length, time_major=True,
                                      initial_state=init_state_bw, dtype=tf.float32)[0]
                bw_outs = tf.reverse_sequence(bw_outs, length, 0, 1)
            # concatenate forward and backward states for each time step
            return tf.concat(2, [fw_outs, bw_outs])
        else:
            raise NotImplementedError("Other compositions not implemented yet.")

    def name(self):
        return self.__class__.__name__
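
    # Query encoding scheme (what _comp_f below computes): a forward RNN encodes each
    # context up to a query's span start, and a backward RNN encodes the reversed
    # context up to the span end; the two states are combined through a linear layer
    # plus a residual sum to form the query representation.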

    def _comp_f(self):
        """
        Encodes all queries (including supporting queries)
        :return: encoded queries
        """
        with tf.device("/cpu:0"):
            max_length = tf.cast(tf.reduce_max(self._length), tf.int32)
            context_t = tf.transpose(self._context)
            context_t = tf.slice(context_t, [0, 0], tf.pack([max_length, -1]))
            embedded = tf.nn.embedding_lookup(self.input_embedding, context_t)
            embedded = tf.nn.dropout(embedded, self.keep_prob)
            batch_size = tf.shape(self._context)[0]
            batch_size_32 = tf.reshape(batch_size, [1])
            batch_size_64 = tf.cast(batch_size, tf.int64)

        with tf.device(self._device1):
            # use the other device for the backward rnn
            with tf.variable_scope("backward"):
                min_end = tf.segment_min(self._ends, self._span_context)
                init_state = tf.get_variable("init_state", [self._size], initializer=self._init)
                init_state = tf.reshape(tf.tile(init_state, batch_size_32), [-1, self._size])
                rev_embedded = tf.reverse_sequence(embedded, self._length, 0, 1)
                # TIME-MAJOR: [T, B, S]
                outs_bw = self._composition_function(rev_embedded, self._length - min_end, init_state)
                # reshape to all possible queries for all sequences. Dim[0]=batch_size*(max_length+1).
                # "+1" because we include the initial state
                outs_bw = tf.reshape(tf.concat(0, [tf.expand_dims(init_state, 0), outs_bw]), [-1, self._size])
                # gather the respective queries at index (length - end) * batch_size + batch_idx
                # (because the sequence is reversed)
                lengths_aligned = tf.gather(self._length, self._span_context)
                out_bw = tf.gather(outs_bw, (lengths_aligned - self._ends) * batch_size_64 + self._span_context)

        with tf.device(self._device2):
            with tf.variable_scope("forward"):
                max_start = tf.segment_max(self._starts, self._span_context)
                init_state = tf.get_variable("init_state", [self._size], initializer=self._init)
                init_state = tf.reshape(tf.tile(init_state, batch_size_32), [-1, self._size])
                # TIME-MAJOR: [T, B, S]
                outs_fw = self._composition_function(embedded, max_start, init_state)
                # reshape to all possible queries for all sequences. Dim[0]=batch_size*(max_length+1).
                # "+1" because we include the initial state
                outs_fw = tf.reshape(tf.concat(0, [tf.expand_dims(init_state, 0), outs_fw]), [-1, self._size])
                # gather the respective queries at index start * batch_size + batch_idx
                out_fw = tf.gather(outs_fw, self._starts * batch_size_64 + self._span_context)
            # form the query from the forward and backward compositions
            query = tf.contrib.layers.fully_connected(tf.concat(1, [out_fw, out_bw]), self._size,
                                                      activation_fn=None, weights_initializer=None,
                                                      biases_initializer=None)
            query = tf.add_n([query, out_bw, out_fw])

        return query

    def set_train(self, sess):
        """
        enables dropout
        :param sess: TensorFlow session
        """
        sess.run(self.keep_prob.initializer)

    def set_eval(self, sess):
        """
        disables dropout
        :param sess: TensorFlow session
        """
        sess.run(self.keep_prob.assign(1.0))
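
    # Typical usage (sketch):
    #   model.set_train(sess); model.step(sess, train_queries)                       # dropout enabled
    #   model.set_eval(sess);  model.run(sess, model.scores_with_negs, dev_queries)  # dropout disabled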

    def _retrieve_answer(self, query):
        """
        Retrieves the answer based on the specified query. Implements consecutive updates to the query and answer.
        :return: answer; if num_hops is 0, this is simply the (gated) query itself
        """
        query, supp_queries = tf.dynamic_partition(query, self._query_partition, 2)
        with tf.variable_scope("support"):
            num_queries = tf.shape(query)[0]

            with tf.device("/cpu:0"):
                _, supp_answer_output_ids = tf.dynamic_partition(self._answer_input, self._query_partition, 2)
                _, supp_answer_input_ids = tf.dynamic_partition(self._answer_word_input, self._query_partition, 2)
                supp_answers = tf.nn.embedding_lookup(self.output_embedding, supp_answer_output_ids)
                aligned_supp_answers = tf.gather(supp_answers, self._support_ids)  # align support answers with queries

                if self._max_hops > 1:
                    # only used in the multihop setting
                    answer_words = tf.nn.embedding_lookup(self.input_embedding, supp_answer_input_ids)
                    aligned_answers_input = tf.gather(answer_words, self._support_ids)

            self.support_scores = []
            query_as_answer = tf.contrib.layers.fully_connected(query, self._size,
                                                                activation_fn=None, weights_initializer=None,
                                                                biases_initializer=None, scope="query_to_answer")
            query_as_answer = query_as_answer * tf.sigmoid(tf.get_variable("query_as_answer_gate", tuple(),
                                                                           initializer=tf.constant_initializer(0.0)))
            current_answer = query_as_answer
            current_query = query

            aligned_support = tf.gather(supp_queries, self._support_ids)  # align supp_queries with queries
            collab_support = tf.gather(query, self._collab_support_ids)  # collaborative support comes from the queries themselves
            aligned_support = tf.concat(0, [aligned_support, collab_support])

            query_ids = tf.concat(0, [self._query_ids, self._collab_query_ids])
            self.answer_weights = []

            for i in range(self._max_hops):
                if i > 0:
                    tf.get_variable_scope().reuse_variables()
                collab_queries = tf.gather(current_query, self._collab_query_ids)  # align collaborative queries
                aligned_queries = tf.gather(current_query, self._query_ids)  # align queries with support
                aligned_queries = tf.concat(0, [aligned_queries, collab_queries])

                with tf.variable_scope("support_scores"):
                    scores = tf_util.batch_dot(aligned_queries, aligned_support)
                    self.support_scores.append(scores)
                    score_max = tf.gather(tf.segment_max(scores, query_ids), query_ids)
                    e_scores = tf.exp(scores - score_max)
                    norm = tf.unsorted_segment_sum(e_scores, query_ids, num_queries) + 0.00001  # for zero norms
                    norm = tf.expand_dims(norm, 1)
                    e_scores = tf.expand_dims(e_scores, 1)

                with tf.variable_scope("support_answers"):
                    aligned_supp_answers_with_collab = tf.concat(0, [aligned_supp_answers, collab_queries])
                    weighted_supp_answers = tf.unsorted_segment_sum(e_scores * aligned_supp_answers_with_collab,
                                                                    query_ids, num_queries) / norm

                with tf.variable_scope("support_queries"):
                    weighted_supp_queries = tf.unsorted_segment_sum(e_scores * aligned_support,
                                                                    query_ids, num_queries) / norm

                with tf.variable_scope("answer_accumulation"):
                    answer_p_max = tf.reduce_max(tf.nn.softmax(self._score_candidates(weighted_supp_answers)),
                                                 [1], keep_dims=True)
                    answer_weight = tf.contrib.layers.fully_connected(tf.concat(1, [query_as_answer * weighted_supp_answers,
                                                                                    weighted_supp_queries * current_query,
                                                                                    answer_p_max]),
                                                                      1,
                                                                      activation_fn=tf.nn.sigmoid,
                                                                      weights_initializer=tf.constant_initializer(0.0),
                                                                      biases_initializer=tf.constant_initializer(0.0),
                                                                      scope="answer_weight")
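
                    # accumulate the retrieved answer: the learned gate answer_weight
                    # (in (0, 1)) controls how much of this hop's weighted support
                    # answer is added to the running answer representation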
                    new_answer = answer_weight * weighted_supp_answers + current_answer

                    # this condition allows for setting a varying number of hops at runtime
                    current_answer = tf.cond(tf.greater(self.num_hops, i),
                                             lambda: new_answer,
                                             lambda: current_answer)

                    self.answer_weights.append(answer_weight)

                if i < self._max_hops - 1:
                    with tf.variable_scope("query_update"):
                        # prepare the subsequent query
                        aligned_answers_input_with_collab = tf.concat(0, [aligned_answers_input, collab_queries])
                        weighted_answer_words = tf.unsorted_segment_sum(e_scores * aligned_answers_input_with_collab,
                                                                        query_ids, num_queries) / norm

                        c = tf.contrib.layers.fully_connected(tf.concat(1, [current_query, weighted_supp_queries,
                                                                            weighted_answer_words]),
                                                              self._size, activation_fn=tf.tanh,
                                                              scope="update_candidate",
                                                              weights_initializer=None, biases_initializer=None)

                        gate = tf.contrib.layers.fully_connected(tf.concat(1, [current_query, weighted_supp_queries]),
                                                                 self._size, activation_fn=tf.sigmoid,
                                                                 weights_initializer=None, scope="update_gate",
                                                                 biases_initializer=tf.constant_initializer(1))
                        current_query = gate * current_query + (1 - gate) * c

        return current_answer

    def _init_inputs(self):
        # General
        with tf.device("/cpu:0"):
            self._context = tf.placeholder(tf.int64, shape=[None, self._max_length], name="context")
            self._answer_candidates = tf.placeholder(tf.int64, shape=[None, None], name="candidates")
            self._answer_input = tf.placeholder(tf.int64, shape=[None], name="answer")
            # answer word ids (indices into E_words) might differ from answer ids (indices into E_candidate)
            self._answer_word_input = tf.placeholder(tf.int64, shape=[None], name="answer_word")
            self._starts = tf.placeholder(tf.int64, shape=[None], name="span_start")
            self._ends = tf.placeholder(tf.int64, shape=[None], name="span_end")
            # holds the batch idx of the context for the respective span
            self._span_context = tf.placeholder(tf.int64, shape=[None], name="answer_position_context")
            self._candidate_mask = tf.placeholder(tf.float32, shape=[None, None], name="candidate_mask")
            self._length = tf.placeholder(tf.int64, shape=[None], name="context_length")

        self._ctxt = np.zeros([self._batch_size, self._max_length], dtype=np.int64)
        self._len = np.zeros([self._batch_size], dtype=np.int64)

        # Supporting Evidence
        # partition of queries (class 0) and support (class 1)
        self._query_partition = tf.placeholder(tf.int32, [None], "query_partition")
        # support ids aligned with query ids for supporting evidence
        self._support_ids = tf.placeholder(tf.int64, shape=[None], name="support_for_query_ids")
        self._collab_support_ids = tf.placeholder(tf.int64, shape=[None], name="collab_supp_ids")
        self._query_ids = tf.placeholder(tf.int64, shape=[None], name="query_for_support_ids")
        self._collab_query_ids = tf.placeholder(tf.int64, shape=[None], name="collab_query_ids")

        self._feed_dict = {}

    def _change_batch_size(self, batch_size):
        new_ctxt_in = np.zeros([batch_size, self._max_length], dtype=np.int64)
        new_ctxt_in[:self._batch_size] = self._ctxt
        self._ctxt = new_ctxt_in

        new_length = np.zeros([batch_size], dtype=np.int64)
        new_length[:self._batch_size] = self._len
        self._len = new_length

        self._batch_size = batch_size
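
    # Batch construction protocol (used internally by run()): call
    # _start_adding_examples(), then _add_example(qs) once per ContextQueries,
    # then _finish_adding_examples(); the result of get_feed_dict() is what gets
    # fed to session.run.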

    def _start_adding_examples(self):
        self._batch_idx = 0
        self._query_idx = 0
        self._support_idx = 0
        self._answer_cands = []
        self._answer_in = []
        self._answer_word_in = []
        self._s = []
        self._e = []
        self._span_ctxt = []
        # supporting evidence
        self._query_part = []
        self.queries_for_support = []
        self.support_for_queries = []
        self._collab_queries = []
        self._collab_support = []

        self.supporting_qa = []

    def _add_example(self, context_queries, is_query=True):
        '''
        All queries and supporting queries are encoded the same way. However, we keep track of which are
        queries and which are support, and of how they belong to each other, via partition variables and
        aligned support-to-query and query-to-support ids.
        :param context_queries: contains all queries about a particular context, see ContextQueries in query.py
        :param is_query: True if this is a query, False if this is support
        '''
        assert is_query or context_queries.support is None, "Support cannot have support!"
        if self._batch_idx >= self._batch_size:
            self._change_batch_size(max(self._batch_size * 2, self._batch_idx))
        self._ctxt[self._batch_idx][:len(context_queries.context)] = context_queries.context
        self._len[self._batch_idx] = len(context_queries.context)

        batch_idx = self._batch_idx
        self._batch_idx += 1
        for i, q in enumerate(context_queries.queries):
            self._s.append(q.start)
            self._e.append(q.end)
            self._span_ctxt.append(batch_idx)
            self._answer_in.append(q.answer if q.answer is not None else q.candidates[0])
            self._answer_word_in.append(q.answer_word)
            cands = [q.answer] if q.answer is not None else []
            if q.candidates is not None:
                cands.extend(c for c in q.candidates if c != q.answer)
            self._answer_cands.append(cands)
            self._query_part.append(0 if is_query else 1)

        if is_query:
            if context_queries.collaborative_support:
                # save queries also as support, only with a different query_partition index (1 for support)
                for i in range(len(context_queries.queries)):
                    for j in range(len(context_queries.queries)):
                        if j != i:
                            self._collab_queries.append(self._query_idx + i)
                            self._collab_support.append(self._query_idx + j)

            ### add query specific supports ###
            for i, q in enumerate(context_queries.queries):
                if q.support is not None and self._max_hops > 0:
                    for qs in q.support:
                        if qs.context is None:
                            # the supporting context is the query context itself, only add the corresponding positions
                            for sq in qs.queries:
                                self._s.append(sq.start)
                                self._e.append(sq.end)
                                self._span_ctxt.append(batch_idx)
                                self._answer_in.append(sq.answer)
                                self._answer_word_in.append(sq.answer_word)
                                self._answer_cands.append([sq.answer])
                                self._query_part.append(1)
                                self.supporting_qa.append((sq.context, sq.start, sq.end, sq.answer))
                        else:
                            self._add_example(qs, is_query=False)
                        # align queries with support idxs
                        self.support_for_queries.extend(range(self._support_idx, self._support_idx + len(qs.queries)))
                        self._support_idx += len(qs.queries)
                        self.queries_for_support.extend([self._query_idx] * len(qs.queries))
                self._query_idx += 1

            ### add context specific support to all queries of this context ###
            if context_queries.support is not None and self._max_hops > 0:
                for qs in context_queries.support:
                    if qs.context is None:
                        # the supporting context is the query context itself, only add the corresponding positions
                        for sq in qs.queries:
                            self._s.append(sq.start)
                            self._e.append(sq.end)
                            self._span_ctxt.append(batch_idx)
                            self._answer_in.append(sq.answer)
                            self._answer_word_in.append(sq.answer_word)
                            self._answer_cands.append([sq.answer])
                            self._query_part.append(1)
                            self.supporting_qa.append((sq.context, sq.start, sq.end, sq.answer))
                    else:
                        self._add_example(qs, is_query=False)
                    # this evidence supports all queries in this context
                    for i, _ in enumerate(context_queries.queries):
                        # align queries with support idxs
                        self.support_for_queries.extend(range(self._support_idx, self._support_idx + len(qs.queries)))
                        self.queries_for_support.extend([self._query_idx - len(context_queries.queries) + i] * len(qs.queries))
                    self._support_idx += len(qs.queries)
        else:
            for i, q in enumerate(context_queries.queries):
                self.supporting_qa.append((q.context, q.start, q.end, q.answer))

    def _finish_adding_examples(self):
        max_cands = max((len(x) for x in self._answer_cands))
        # the mask is used to distinguish real candidates from dummies, since the number
        # of candidates can vary from query to query within a batch
        cand_mask = []
        for i in range(len(self._answer_cands)):
            l = len(self._answer_cands[i])
            if self._query_part[i] == 0:  # if this is a query (and not supporting evidence)
                mask = [0] * l
            for _ in range(max_cands - l):
                self._answer_cands[i].append(self._answer_cands[i][0])  # dummy
                if self._query_part[i] == 0:
                    mask.append(-1e6)  # added to the scores; serves as a bias mask to exclude dummy candidates
            if self._query_part[i] == 0:
                cand_mask.append(mask)

        if self._batch_idx < self._batch_size:
            self._feed_dict[self._context] = self._ctxt[:self._batch_idx]
            self._feed_dict[self._length] = self._len[:self._batch_idx]
        else:
            self._feed_dict[self._context] = self._ctxt
            self._feed_dict[self._length] = self._len
        self._feed_dict[self._starts] = self._s
        self._feed_dict[self._ends] = self._e
        self._feed_dict[self._span_context] = self._span_ctxt
        self._feed_dict[self._answer_input] = self._answer_in
        self._feed_dict[self._answer_word_input] = self._answer_word_in
        self._feed_dict[self._answer_candidates] = self._answer_cands
        self._feed_dict[self._candidate_mask] = cand_mask
        self._feed_dict[self._query_ids] = self.queries_for_support
        self._feed_dict[self._support_ids] = self.support_for_queries
        self._feed_dict[self._collab_query_ids] = self._collab_queries
        self._feed_dict[self._collab_support_ids] = self._collab_support
        self._feed_dict[self._query_partition] = self._query_part

    def get_feed_dict(self):
        return self._feed_dict
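
    # Inspection sketch: after any call to run()/step() one can fetch intermediate
    # targets, e.g.,
    #   hop_scores = model.run(sess, model.support_scores, queries)
    # and align hop_scores[h] with model.supporting_qa and model.queries_for_support.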

    def step(self, sess, queries, mode="update"):
        '''
        :param sess: TensorFlow session
        :param queries: list of ContextQueries
        :param mode: "loss" only computes the loss; any other value also performs a parameter update
        :return: the current loss
        '''
        assert self._is_train, "model has to be created in training mode!"
        if mode == "loss":
            return self.run(sess, self._loss, queries)
        else:
            return self.run(sess, [self._loss, self._update], queries)[0]

    def run(self, sess, to_run, queries):
        '''
        :param sess: TensorFlow session
        :param to_run: target(s) to run, e.g.:
            * self.num_hops,
            * self.query,
            * self.score (only the scores for the provided answers),
            * self.scores_with_negs (scores of all candidates, where column 0 holds the score of the answer),
            * self.input_embedding,
            * self.output_embedding,
            * self.support_scores (match scores of all support with the query for each hop; aligns with
              self.supporting_qa, which keeps track of all supporting QA pairs, and with
              self.queries_for_support, which defines the batch idx of the query for each supporting_qa),
            * self.answer_weights (weight used to accumulate the retrieved answer in each hop)
        :param queries: list of ContextQueries
        :return: the fetched values of to_run
        '''
        batch_size = len(queries)
        self._start_adding_examples()
        num_batch_queries = 0
        for batch_idx in range(batch_size):
            context_query = queries[batch_idx]
            num_batch_queries += len(context_query.queries)
            self._add_example(context_query)
        self._finish_adding_examples()

        return sess.run(to_run, feed_dict=self.get_feed_dict())


def test_model():
    model = QAModel(10, 4, 5, 5, 5, max_hops=2)
    # 3 contexts (of length 3) with queries at 2/1/2 (totaling 5) positions
    # and respective negative candidates for each position
    contexts = [[0, 1, 2], [1, 2, 0], [0, 2, 1]]  # 4 => placeholder for prediction position

    support = [ContextQueries(contexts[0], [ClozeQuery(contexts[0], 0, 1, 0, 0, [2, 1]),
                                            ClozeQuery(contexts[0], 2, 3, 2, 2, [0, 1])]),
               ContextQueries(contexts[1], [ClozeQuery(contexts[1], 1, 2, 2, 2, [0, 1])])]

    queries = [ContextQueries(contexts[0], [ClozeQuery(contexts[0], 0, 1, 0, 0, [2, 1], support=support),
                                            ClozeQuery(contexts[0], 2, 3, 2, 2, [0, 1], support=support)]),
               ContextQueries(contexts[1], [ClozeQuery(contexts[1], 1, 2, 2, 2, [0, 1], support=support)]),
               ContextQueries(contexts[2], [ClozeQuery(contexts[2], 1, 2, 1, 1, [0, 2], support=support),
                                            ClozeQuery(contexts[2], 2, 3, 2, 2, [0, 1], support=support)])]

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        sess.run(model.num_hops.assign(1))
        print("Test update ...")
        for i in range(10):
            print("Loss: %.3f" % model.step(sess, queries)[0])
        print("Test scoring ...")
        print(model.run(sess, model.scores_with_negs, queries))
        print("Done")


if __name__ == '__main__':
    test_model()


"""
Expected output:

Test update ...
Loss: 1.100
Loss: 1.085
Loss: 1.071
Loss: 1.054
Loss: 1.034
Loss: 1.010
Loss: 0.978
Loss: 0.941
Loss: 0.898
Loss: 0.852
Test scoring ...
[[-0.04599455  0.40225014 -0.24225023]
 [ 1.09410524 -0.54674953 -0.43304446]
 [ 0.82218146 -0.37236962 -0.33791351]
 [-0.31010905 -0.36932546  0.79180872]
 [ 1.00475466 -0.49434289 -0.38799521]]
Done
"""
--------------------------------------------------------------------------------