628 | and the mask would be 1 1 1 1 1 0 0
629 | and let the logits be l1 l2 l3 l4 l5 l6
630 | We actually need to compare:
631 | the sequence              w1  w2  w3  <E> <P> <P>
632 | with masks                1   1   1   1   0   0
633 | against                   l1  l2  l3  l4  l5  l6
634 | (where the input was)     <S> w1  w2  w3  <E> <P>
635 | """
636 | # shape: (batch_size, num_decoding_steps)
637 | relevant_targets = targets[:, 1:].contiguous()
638 |
639 | # shape: (batch_size, num_decoding_steps)
640 | relevant_mask = target_mask[:, 1:].contiguous()
641 |
642 | return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
643 |
644 | @overrides
645 | def get_metrics(self, reset: bool = False) -> Dict[str, float]:
646 | all_metrics: Dict[str, float] = {}
647 | if self._bleu and not self.training:
648 | all_metrics.update(self._bleu.get_metric(reset=reset))
649 | all_metrics.update({'acc': self._acc.get_metric(reset=reset)})
650 | all_metrics.update({'no_result': self._no_result.get_metric(reset=reset)})
651 |
652 | return all_metrics
--------------------------------------------------------------------------------
/NGS_Aux_test.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple
2 |
3 | import numpy
4 | from overrides import overrides
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | from torch.nn.modules.linear import Linear
9 | from torch.nn.modules.rnn import LSTMCell
10 |
11 | from allennlp.common.checks import ConfigurationError
12 | from allennlp.common.util import START_SYMBOL, END_SYMBOL
13 | from allennlp.data.vocabulary import Vocabulary
14 | from allennlp.modules.attention import LegacyAttention
15 | from allennlp.modules import Attention, TextFieldEmbedder, Seq2SeqEncoder
16 | from allennlp.modules.similarity_functions import SimilarityFunction
17 | from allennlp.models.model import Model
18 | from allennlp.modules.token_embedders import Embedding
19 | from allennlp.nn import util
20 | from allennlp.nn.beam_search import BeamSearch
21 | from allennlp.training.metrics import Average, BLEU
22 |
23 | from ManualProgram.eval_equ import Equations
24 |
25 | import random
26 | import warnings
27 | import math
28 | warnings.filterwarnings("ignore")
29 |
30 | from utils import *
31 |
32 | from mcan import *
33 |
34 | import json
35 |
36 |
37 | @Model.register("geo_s2s")
38 | class SimpleSeq2Seq(Model):
39 | """
40 | This ``SimpleSeq2Seq`` class is a :class:`Model` which takes a sequence, encodes it, and then
41 | uses the encoded representations to decode another sequence. You can use this as the basis for
42 | a neural machine translation system, an abstractive summarization system, or any other common
43 | seq2seq problem. The model here is simple, but should be a decent starting place for
44 | implementing recent models for these tasks.
45 |
46 | Parameters
47 | ----------
48 | vocab : ``Vocabulary``, required
49 | Vocabulary containing source and target vocabularies. They may be under the same namespace
50 | (`tokens`) or the target tokens can have a different namespace, in which case it needs to
51 | be specified as `target_namespace`.
52 | source_embedder : ``TextFieldEmbedder``, required
53 | Embedder for source side sequences
54 | encoder : ``Seq2SeqEncoder``, required
55 | The encoder of the "encoder/decoder" model
56 | max_decoding_steps : ``int``
57 | Maximum length of decoded sequences.
58 | target_namespace : ``str``, optional (default = 'tokens')
59 | If the target side vocabulary is different from the source side's, you need to specify the
60 | target's namespace here. If not, we'll assume it is "tokens", which is also the default
61 | choice for the source side, and this might cause them to share vocabularies.
62 | target_embedding_dim : ``int``, optional (default = source_embedding_dim)
63 | You can specify an embedding dimensionality for the target side. If not, we'll use the same
64 | value as the source embedder's.
65 | attention : ``Attention``, optional (default = None)
66 | If you want to use attention to get a dynamic summary of the encoder outputs at each step
67 | of decoding, this is the function used to compute similarity between the decoder hidden
68 | state and encoder outputs.
69 | attention_function: ``SimilarityFunction``, optional (default = None)
70 | This is if you want to use the legacy implementation of attention. This will be deprecated
71 | since it consumes more memory than the specialized attention modules.
72 | beam_size : ``int``, optional (default = None)
73 | Width of the beam for beam search. If not specified, greedy decoding is used.
74 | scheduled_sampling_ratio : ``float``, optional (default = 0.)
75 | At each timestep during training, we sample a random number between 0 and 1, and if it is
76 | not less than this value, we use the ground truth labels for the whole batch. Else, we use
77 | the predictions from the previous time step for the whole batch. If this value is 0.0
78 | (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
79 | using target side ground truth labels. See the following paper for more information:
80 | `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
81 | 2015 <https://arxiv.org/abs/1506.03099>`_.
82 | use_bleu : ``bool``, optional (default = True)
83 | If True, the BLEU metric will be calculated during validation.
84 | """
85 |
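    | # High-level dataflow (as implemented in ``forward`` below): the problem text is
    | # embedded and encoded by ``self._encoder``; the diagram is encoded by a ResNet
    | # backbone into a 14x14 feature map; the two modalities are fused with MCAN-style
    | # co-attention (``MCA_ED``); the fused image features are concatenated with the
    | # language features and attended over by an LSTM decoder that emits program tokens.
    | # At evaluation time, beam-search hypotheses are executed as programs and checked
    | # against the four answer choices.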
86 | def __init__(self,
87 | vocab: Vocabulary,
88 | source_embedder: TextFieldEmbedder,
89 | encoder: Seq2SeqEncoder,
90 | max_decoding_steps: int,
91 | knowledge_points_ratio = 0,
92 | attention: Attention = True,
93 | attention_function: SimilarityFunction = None,
94 | beam_size: int = None,
95 | target_namespace: str = "tokens",
96 | target_embedding_dim: int = None,
97 | scheduled_sampling_ratio: float = 0.,
98 | resnet_pretrained = None,
99 | use_bleu: bool = True) -> None:
100 | super(SimpleSeq2Seq, self).__init__(vocab)
101 |
102 | resnet = build_model()
103 |
104 | if resnet_pretrained is not None:
105 | resnet.load_state_dict(torch.load(resnet_pretrained))
106 | print('##### Checkpoint Loaded! #####')
107 | else:
108 | print("No Diagram Pretrain !!!")
109 | self.resnet = resnet
110 |
111 | self.channel_transform = torch.nn.Linear(1024, 512)
112 |
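    | # Cross-modal fusion components. ``Cfgs``, ``MCA_ED`` and ``AttFlat`` come from the
    | # ``mcan`` module (imported above via ``from mcan import *``); they appear to follow
    | # the MCAN ("Deep Modular Co-Attention Networks") design, where MCA_ED co-attends
    | # language and image features and AttFlat pools a feature sequence into one vector.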
113 | __C = Cfgs()
114 | self.mcan = MCA_ED(__C)
115 | self.attflat_img = AttFlat(__C)
116 | self.attflat_lang = AttFlat(__C)  # not used
117 |
118 | self.decode_transform = torch.nn.Linear(1024, 512)
119 |
120 | self._equ = Equations()
121 |
122 | self._target_namespace = target_namespace
123 | self._scheduled_sampling_ratio = scheduled_sampling_ratio
124 |
125 | # We need the start symbol to provide as the input at the first timestep of decoding, and
126 | # end symbol as a way to indicate the end of the decoded sequence.
127 | self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
128 | self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
129 |
130 | if use_bleu:
131 | pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access
132 | self._bleu = BLEU(ngram_weights=(1, 0, 0, 0), exclude_indices={pad_index, self._end_index, self._start_index})
133 | else:
134 | self._bleu = None
135 | self._acc = Average()
136 | self._no_result = Average()
137 |
138 | # remember to clear after evaluation
139 | self.new_acc = []
140 | self.angle = []
141 | self.length = []
142 | self.area = []
143 | self.other = []
144 | self.point_acc_list = []
145 | self.save_results = dict()
146 |
147 | # At prediction time, we use a beam search to find the most likely sequence of target tokens.
148 | beam_size = beam_size or 1
149 | self._max_decoding_steps = max_decoding_steps
150 | self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)
151 |
152 | # Dense embedding of source vocab tokens.
153 | self._source_embedder = source_embedder
154 |
155 | # Encodes the sequence of source embeddings into a sequence of hidden states.
156 | self._encoder = encoder
157 |
158 | num_classes = self.vocab.get_vocab_size(self._target_namespace)
159 |
160 | # Attention mechanism applied to the encoder output for each step.
161 | # TODO: attention
162 | if attention:
163 | if attention_function:
164 | raise ConfigurationError("You can only specify an attention module or an "
165 | "attention function, but not both.")
166 | self._attention = LegacyAttention()
167 | elif attention_function:
168 | self._attention = LegacyAttention(attention_function)
169 | else:
170 | self._attention = None
171 | print("No Attention!")
172 | exit()
173 |
174 | # Dense embedding of vocab words in the target space.
175 | target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
176 | self._target_embedder = Embedding(num_classes, target_embedding_dim)
177 |
178 | # Decoder output dim needs to be the same as the encoder output dim since we initialize the
179 | # hidden state of the decoder with the final hidden state of the encoder.
180 | self._encoder_output_dim = self._encoder.get_output_dim()
181 | self._decoder_output_dim = self._encoder_output_dim
182 |
183 | if self._attention:
184 | # If using attention, a weighted average over encoder outputs will be concatenated
185 | # to the previous target embedding to form the input to the decoder at each
186 | # time step.
187 | self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
188 | else:
189 | # Otherwise, the input to the decoder is just the previous target embedding.
190 | self._decoder_input_dim = target_embedding_dim
191 |
192 | # We'll use an LSTM cell as the recurrent cell that produces a hidden state
193 | # for the decoder at each time step.
194 | self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)
195 |
196 | # We project the hidden state from the decoder into the output vocabulary space
197 | # in order to get log probabilities of each target token, at each time step.
198 | self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
199 |
200 | # knowledge points
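    | # Auxiliary head: predicts a 50-dimensional multi-hot vector of knowledge points
    | # from the fused image/language feature with a sigmoid + BCE loss, scaled by
    | # ``knowledge_points_ratio`` and added to the decoding loss in ``forward``.
    | # The head is only built when the ratio is non-zero.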
201 | self.point_ratio = knowledge_points_ratio
202 | if self.point_ratio != 0:
203 | self.points_norm = LayerNorm(__C.FLAT_OUT_SIZE)
204 | self.points_proj = nn.Linear(__C.FLAT_OUT_SIZE, 50)
205 | self.points_criterion = nn.BCELoss()
206 |
207 | def take_step(self,
208 | last_predictions: torch.Tensor,
209 | state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
210 | """
211 | Take a decoding step. This is called by the beam search class.
212 |
213 | Parameters
214 | ----------
215 | last_predictions : ``torch.Tensor``
216 | A tensor of shape ``(group_size,)``, which gives the indices of the predictions
217 | during the last time step.
218 | state : ``Dict[str, torch.Tensor]``
219 | A dictionary of tensors that contain the current state information
220 | needed to predict the next step, which includes the encoder outputs,
221 | the source mask, and the decoder hidden state and context. Each of these
222 | tensors has shape ``(group_size, *)``, where ``*`` can be any other number
223 | of dimensions.
224 |
225 | Returns
226 | -------
227 | Tuple[torch.Tensor, Dict[str, torch.Tensor]]
228 | A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
229 | is a tensor of shape ``(group_size, num_classes)`` containing the predicted
230 | log probability of each class for the next step, for each item in the group,
231 | while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
232 | source mask, and updated decoder hidden state and context.
233 |
234 | Notes
235 | -----
236 | We treat the inputs as a batch, even though ``group_size`` is not necessarily
237 | equal to ``batch_size``, since the group may contain multiple states
238 | for each source sentence in the batch.
239 | """
240 | # shape: (group_size, num_classes)
241 | output_projections, state = self._prepare_output_projections(last_predictions, state)
242 |
243 | # shape: (group_size, num_classes)
244 | class_log_probabilities = F.log_softmax(output_projections, dim=-1)
245 |
246 | return class_log_probabilities, state
247 |
248 | @overrides
249 | def forward(self, # type: ignore
250 | image, source_nums, choice_nums, label, type, data_id, manual_program,
251 | source_tokens: Dict[str, torch.LongTensor],
252 | point_label = None,
253 | target_tokens: Dict[str, torch.LongTensor] = None, **kwargs) -> Dict[str, torch.Tensor]:
254 | # pylint: disable=arguments-differ
255 | """
256 | Make forward pass with decoder logic for producing the entire target sequence.
257 |
258 | Parameters
259 | ----------
260 | source_tokens : ``Dict[str, torch.LongTensor]``
261 | The output of `TextField.as_array()` applied on the source `TextField`. This will be
262 | passed through a `TextFieldEmbedder` and then through an encoder.
263 | target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
264 | Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
265 | target tokens are also represented as a `TextField`.
266 |
267 | Returns
268 | -------
269 | Dict[str, torch.Tensor]
270 | """
271 | bs = len(label)
272 | state = self._encode(source_tokens)
273 |
274 | with torch.no_grad():
275 | img_feats = self.resnet(image)
276 | # (N, C, 14, 14) -> (N, 196, C)
277 | img_feats = img_feats.reshape(img_feats.shape[0], img_feats.shape[1], -1).transpose(1, 2)
278 | img_mask = make_mask(img_feats)
279 | img_feats = self.channel_transform(img_feats)
280 |
281 | lang_feats = state['encoder_outputs']
282 | # build the language mask from the raw token indices in source_tokens (the question already converted to indices, i.e. before embedding)
283 | lang_mask = make_mask(source_tokens['tokens'].unsqueeze(2))
284 |
285 | _, img_feats = self.mcan(lang_feats, img_feats, lang_mask, img_mask)
286 |
287 | # shape: (N, 196 + text_len, 512), e.g. (N, 308, 512)
288 | # for decoder attention: image features first, then language features, matching the ordering of concat_mask
289 | state['encoder_outputs'] = torch.cat([img_feats, lang_feats], 1)
290 |
291 | # decode
292 | state = self._init_decoder_state(state, lang_feats, img_feats, img_mask)
293 | output_dict = self._forward_loop(state, target_tokens) # recurrent decoding for LSTM
294 |
295 | # knowledge points
296 | if self.point_ratio != 0:
297 | concat_feature = state["concat_feature"]
298 | point_feat = self.points_norm(concat_feature)
299 | point_feat = self.points_proj(point_feat)
300 | point_pred = torch.sigmoid(point_feat)
301 | point_loss = self.points_criterion(point_pred, point_label) * self.point_ratio
302 | output_dict["point_pred"] = point_pred
303 | output_dict["point_loss"] = point_loss
304 | output_dict["loss"] += point_loss
305 |
306 |
307 | # if testing, beam search and evaluation
308 | if not self.training:
309 | # state = self._init_decoder_state(state)
310 | state = self._init_decoder_state(state, lang_feats, img_feats, img_mask) # TODO
311 | predictions = self._forward_beam_search(state)
312 | output_dict.update(predictions)
313 |
314 | if target_tokens and self._bleu:
315 | # shape: (batch_size, beam_size, max_sequence_length)
316 | top_k_predictions = output_dict["predictions"]
317 |
318 | # execute the decoded programs to compute answer accuracy
319 | suc_knt, no_knt = 0, 0
320 |
321 | selected_programs = []
322 |
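    | # For each example, try the beam hypotheses in order: trim at the end symbol, map
    | # indices back to tokens, and execute the program with Equations.excuate_equation
    | # (sic, from ManualProgram.eval_equ) on the numbers parsed from the problem text.
    | # The first hypothesis whose result matches one of the four choices within 1e-3
    | # decides the predicted answer; accuracy is also tracked per problem type
    | # (angle / length / other).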
323 | for b in range(bs):
324 | hypo = None
325 | used_hypo = None
326 | choice = None
327 | for i in range(self._beam_search.beam_size):
328 | if choice is not None:
329 | break
330 | hypo = list(top_k_predictions[b][i])
331 | if self._end_index in list(hypo):
332 | hypo = hypo[:hypo.index(self._end_index)]
333 | hypo = [self.vocab.get_token_from_index(idx.item()) for idx in hypo]
334 | res = self._equ.excuate_equation(hypo, source_nums[b])
335 | if res is not None and len(res) > 0:
336 | for j in range(4):
337 | if choice_nums[b][j] is not None and math.fabs(res[-1] - choice_nums[b][j]) < 0.001:
338 | choice = j
339 | used_hypo = hypo
340 | selected_programs.append([hypo])
341 |
342 | if choice is None:
343 | no_knt += 1
344 | answer_state = 'no_result'
345 | self.new_acc.append(0)
346 | elif choice == label[b]:
347 | suc_knt += 1
348 | answer_state = 'right'
349 | self.new_acc.append(1)
350 | else:
351 | answer_state = 'false'
352 | self.new_acc.append(0)
353 |
354 | self.save_results[data_id[b]] = dict(manual_program=manual_program[b],
355 | predict_program=hypo, predict_res=res,
356 | choice=choice_nums[b], right_answer=label[b],
357 | answer_state=answer_state)
358 |
359 | flag = 1 if choice == label[b] else 0
360 | if type[b] == 'angle':
361 | self.angle.append(flag)
362 | elif type[b] == 'length':
363 | self.length.append(flag)
364 | else:
365 | self.other.append(flag)
366 |
367 | # knowledge points
368 | # if self.point_ratio != 0:
369 | # point_acc = self.multi_label_evaluation(point_pred[b].unsqueeze(0), point_label[b].unsqueeze(0))
370 | # self.point_acc_list.append(point_acc)
371 |
372 | # with open('save/test.json', 'w') as f:
373 | # json.dump(self.save_results, f)
374 |
375 | if random.random() < 0.05:
376 | print('selected_programs', selected_programs)
377 |
378 | # calculate BLEU
379 | best_predictions = top_k_predictions[:, 0, :]
380 | self._bleu(best_predictions, target_tokens["tokens"])
381 | self._acc(suc_knt / bs)
382 | self._no_result(no_knt / bs)
383 |
384 | return output_dict
385 |
386 | def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
387 | """
388 | Finalize predictions.
389 |
390 | This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
391 | time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
392 | within the ``forward`` method.
393 |
394 | This method trims the output predictions to the first end symbol, replaces indices with
395 | corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
396 | """
397 | predicted_indices = output_dict["predictions"]
398 | if not isinstance(predicted_indices, numpy.ndarray):
399 | predicted_indices = predicted_indices.detach().cpu().numpy()
400 | all_predicted_tokens = []
401 | for indices in predicted_indices:
402 | # Beam search gives us the top k results for each source sentence in the batch
403 | # but we just want the single best.
404 | if len(indices.shape) > 1:
405 | indices = indices[0]
406 | indices = list(indices)
407 | # Collect indices till the first end_symbol
408 | if self._end_index in indices:
409 | indices = indices[:indices.index(self._end_index)]
410 | predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
411 | for x in indices]
412 | all_predicted_tokens.append(predicted_tokens)
413 | output_dict["predicted_tokens"] = all_predicted_tokens
414 | return output_dict
415 |
416 | def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
417 | # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
418 | embedded_input = self._source_embedder(source_tokens)
419 | # shape: (batch_size, max_input_sequence_length)
420 | source_mask = util.get_text_field_mask(source_tokens)
421 |
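    | # The diagram contributes a fixed 14 x 14 = 196 positions from the ResNet feature
    | # map, none of which are padding, so its mask is all ones. (Note: this assumes a
    | # CUDA device, as does the rest of this file.)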
422 | img_mask = torch.ones(source_mask.shape[0], 196).long().cuda()
423 | concat_mask = torch.cat([img_mask, source_mask], 1)
424 |
425 | # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
426 | encoder_outputs = self._encoder(embedded_input, source_mask)
427 | return {
428 | "source_mask": source_mask, # source_mask,
429 | "concat_mask": concat_mask,
430 | "encoder_outputs": encoder_outputs,
431 | }
432 |
433 | def _init_decoder_state(self, state, lang_feats, img_feats, img_mask):
434 |
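    | # The decoder hidden state h0 is initialized from the concatenation of the final
    | # language encoder state and the attention-flattened image feature, projected down
    | # to the decoder dimension; the cell state c0 starts at zero.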
435 | batch_size = state["source_mask"].size(0)
436 | final_lang_feat = util.get_final_encoder_states(
437 | lang_feats,
438 | state["source_mask"],
439 | self._encoder.is_bidirectional())
440 | img_feat = self.attflat_img(img_feats, img_mask)
441 | feat = torch.cat([final_lang_feat, img_feat], 1)
442 | feat = self.decode_transform(feat)
443 | state["concat_feature"] = feat
444 |
445 | state["decoder_hidden"] = feat
446 | # C0 shape: (batch_size, decoder_output_dim)
447 | state["decoder_context"] = torch.zeros(batch_size, self._decoder_output_dim).cuda()
448 | # state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
449 | return state
450 |
451 | def _forward_loop(self,
452 | state: Dict[str, torch.Tensor],
453 | target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
454 | """
455 | Make forward pass during training or do greedy search during prediction.
456 |
457 | Notes
458 | -----
459 | We really only use the predictions from the method to test that beam search
460 | with a beam size of 1 gives the same results.
461 | """
462 | # shape: (batch_size, max_input_sequence_length)
463 | source_mask = state["source_mask"]
464 |
465 | batch_size = source_mask.size()[0]
466 |
467 | if target_tokens:
468 | # shape: (batch_size, max_target_sequence_length)
469 | targets = target_tokens["tokens"]
470 |
471 | _, target_sequence_length = targets.size()
472 |
473 | # The last input from the target is either padding or the end symbol.
474 | # Either way, we don't have to process it.
475 | num_decoding_steps = target_sequence_length - 1
476 | else:
477 | num_decoding_steps = self._max_decoding_steps
478 |
479 | # Initialize target predictions with the start index.
480 | # shape: (batch_size,)
481 | last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)
482 |
483 | step_logits: List[torch.Tensor] = []
484 | step_predictions: List[torch.Tensor] = []
485 | for timestep in range(num_decoding_steps):
486 | if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
487 | # During training, with probability `_scheduled_sampling_ratio` we feed the model's
488 | # own prediction from the previous step back in, instead of the gold token.
489 | # shape: (batch_size,)
490 | input_choices = last_predictions
491 | elif not target_tokens:
492 | # shape: (batch_size,)
493 | input_choices = last_predictions
494 | else:
495 | # shape: (batch_size,)
496 | input_choices = targets[:, timestep]
497 |
498 | # shape: (batch_size, num_classes)
499 | # recurrent decoding
500 | output_projections, state = self._prepare_output_projections(input_choices, state)
501 |
502 | # list of tensors, shape: (batch_size, 1, num_classes)
503 | step_logits.append(output_projections.unsqueeze(1))
504 |
505 | # shape: (batch_size, num_classes)
506 | class_probabilities = F.softmax(output_projections, dim=-1)
507 |
508 | # shape (predicted_classes): (batch_size,)
509 | _, predicted_classes = torch.max(class_probabilities, 1)
510 |
511 | # shape (predicted_classes): (batch_size,)
512 | last_predictions = predicted_classes
513 |
514 | step_predictions.append(last_predictions.unsqueeze(1))
515 |
516 | # shape: (batch_size, num_decoding_steps)
517 | predictions = torch.cat(step_predictions, 1)
518 |
519 | output_dict = {"predictions": predictions}
520 |
521 | if target_tokens:
522 | # shape: (batch_size, num_decoding_steps, num_classes)
523 | logits = torch.cat(step_logits, 1)
524 |
525 | # Compute loss.
526 | target_mask = util.get_text_field_mask(target_tokens)
527 | loss = self._get_loss(logits, targets, target_mask)
528 | output_dict["loss"] = loss
529 |
530 | return output_dict
531 |
532 | def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
533 | """Make forward pass during prediction using a beam search."""
534 | batch_size = state["source_mask"].size()[0]
535 | start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)
536 |
537 | # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
538 | # shape (log_probabilities): (batch_size, beam_size)
539 | all_top_k_predictions, log_probabilities = self._beam_search.search(
540 | start_predictions, state, self.take_step)
541 |
542 | output_dict = {
543 | "class_log_probabilities": log_probabilities,
544 | "predictions": all_top_k_predictions,
545 | }
546 | return output_dict
547 |
548 | def _prepare_output_projections(self,
549 | last_predictions: torch.Tensor,
550 | state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # pylint: disable=line-too-long
551 | """
552 | Decode current state and last prediction to produce projections
553 | into the target space, which can then be used to get probabilities of
554 | each target token for the next step.
555 |
556 | Inputs are the same as for `take_step()`.
557 | """
558 | # shape: (group_size, max_input_sequence_length, encoder_output_dim)
559 | encoder_outputs = state["encoder_outputs"]
560 |
561 | # shape: (group_size, max_input_sequence_length)
562 | # source_mask = state["source_mask"]
563 | source_mask = state["concat_mask"]
564 |
565 | # decoder_hidden and decoder_context are initialized from the encoder outputs in _init_decoder_state()
566 | # shape: (group_size, decoder_output_dim)
567 | decoder_hidden = state["decoder_hidden"]
568 | # shape: (group_size, decoder_output_dim)
569 | decoder_context = state["decoder_context"]
570 |
571 | # shape: (group_size, target_embedding_dim)
572 | embedded_input = self._target_embedder(last_predictions)
573 |
574 | if self._attention:
575 | # shape: (group_size, encoder_output_dim)
576 | attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask)
577 |
578 | # shape: (group_size, decoder_output_dim + target_embedding_dim)
579 | decoder_input = torch.cat((attended_input, embedded_input), -1)
580 |
581 | else:
582 | # shape: (group_size, target_embedding_dim)
583 | decoder_input = embedded_input
584 |
585 | # shape (decoder_hidden): (batch_size, decoder_output_dim)
586 | # shape (decoder_context): (batch_size, decoder_output_dim)
587 | decoder_hidden, decoder_context = self._decoder_cell(
588 | decoder_input,
589 | (decoder_hidden, decoder_context))
590 |
591 | state["decoder_hidden"] = decoder_hidden
592 | state["decoder_context"] = decoder_context
593 |
594 | # shape: (group_size, num_classes)
595 | output_projections = self._output_projection_layer(decoder_hidden)
596 |
597 | return output_projections, state
598 |
599 | def _prepare_attended_input(self,
600 | decoder_hidden_state: torch.LongTensor = None,
601 | encoder_outputs: torch.LongTensor = None,
602 | encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
603 | """Apply attention over encoder outputs and decoder state."""
604 | # Ensure mask is also a FloatTensor. Or else the multiplication within
605 | # attention will complain.
606 | # shape: (batch_size, max_input_sequence_length)
607 | encoder_outputs_mask = encoder_outputs_mask.float()
608 |
609 | # shape: (batch_size, max_input_sequence_length)
610 | input_weights = self._attention(
611 | decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
612 |
613 | # shape: (batch_size, encoder_output_dim)
614 | attended_input = util.weighted_sum(encoder_outputs, input_weights)
615 |
616 | return attended_input
617 |
618 | def multi_label_evaluation(self, input, target):
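    | # Thresholds the sigmoid outputs at 0.5 and computes, per example, the
    | # intersection-over-union between the predicted and gold knowledge-point sets;
    | # an all-zero prediction matching an all-zero target (0/0 -> NaN) is counted as
    | # correct. Returns the sum of per-example scores over the batch.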
619 | one = torch.ones(target.shape).cuda()
620 | zero = torch.zeros(target.shape).cuda()
621 | res = torch.where(input > 0.5, one, zero)
622 |
623 | over = (res * target).sum(dim=1)
624 | union = res.sum(dim=1) + target.sum(dim=1) - over
625 | acc = over / union
626 |
627 | index = torch.isnan(acc)  # NaN appears when both prediction and target are all zeros, which counts as a correct answer
628 | acc_fix = torch.where(index, torch.ones(acc.shape).cuda(), acc)
629 |
630 | acc_sum = acc_fix.sum().item()
631 |
632 | return acc_sum
633 |
634 | @staticmethod
635 | def _get_loss(logits: torch.LongTensor,
636 | targets: torch.LongTensor,
637 | target_mask: torch.LongTensor) -> torch.Tensor:
638 | """
639 | Compute loss.
640 |
641 | Takes logits (unnormalized outputs from the decoder) of size (batch_size,
642 | num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
643 | and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
644 | entropy loss while taking the mask into account.
645 |
646 | The length of ``targets`` is expected to be greater than that of ``logits`` because the
647 | decoder does not need to compute the output corresponding to the last timestep of
648 | ``targets``. This method aligns the inputs appropriately to compute the loss.
649 |
650 | During training, we want the logit corresponding to timestep i to be similar to the target
651 | token from timestep i + 1. That is, the targets should be shifted by one timestep for
652 | appropriate comparison. Consider a single example where the target has 3 words, and
653 | padding is to 7 tokens.
654 | The complete sequence would correspond to  <S> w1  w2  w3  <E> <P> <P>
655 | and the mask would be                      1   1   1   1   1   0   0
656 | and let the logits be                      l1  l2  l3  l4  l5  l6
657 | We actually need to compare:
658 | the sequence              w1  w2  w3  <E> <P> <P>
659 | with masks                1   1   1   1   0   0
660 | against                   l1  l2  l3  l4  l5  l6
661 | (where the input was)     <S> w1  w2  w3  <E> <P>
662 | """
663 | # shape: (batch_size, num_decoding_steps)
664 | relevant_targets = targets[:, 1:].contiguous()
665 |
666 | # shape: (batch_size, num_decoding_steps)
667 | relevant_mask = target_mask[:, 1:].contiguous()
668 |
669 | return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
670 |
671 | @overrides
672 | def get_metrics(self, reset: bool = False) -> Dict[str, float]:
673 | all_metrics: Dict[str, float] = {}
674 | if self._bleu and not self.training:
675 | all_metrics.update(self._bleu.get_metric(reset=reset))
677 | all_metrics.update({'acc': self._acc.get_metric(reset=reset)})
678 | if len(self.new_acc) != 0:
679 | all_metrics.update({'new_acc': sum(self.new_acc)/len(self.new_acc)})
680 | print('Num of total, angle, len, other', len(self.new_acc), len(self.angle), len(self.length), len(self.other))
681 | if len(self.angle) != 0:
682 | all_metrics.update({'angle_acc': sum(self.angle)/len(self.angle)})
683 | if len(self.length) != 0:
684 | all_metrics.update({'length_acc': sum(self.length)/len(self.length)})
685 | if len(self.other) != 0:
686 | all_metrics.update({'other_acc': sum(self.other)/len(self.other)})
687 | all_metrics.update({'no_result': self._no_result.get_metric(reset=reset)})
688 |
689 | # if len(self.point_acc_list) != 0:
690 | # all_metrics.update({'point_acc': sum(self.point_acc_list) / len(self.point_acc_list)})
691 |
692 | return all_metrics
693 |
--------------------------------------------------------------------------------