724 | and the mask would be                      1   1   1   1   1   0   0
725 | and let the logits be                      l1  l2  l3  l4  l5  l6
726 | We actually need to compare:
727 | the sequence            w1  w2  w3  <E> <P> <P>
728 | with masks              1   1   1   1   0   0
729 | against                 l1  l2  l3  l4  l5  l6
730 | (where the input was)   <S> w1  w2  w3  <E> <P>
731 | """
732 | # shape: (batch_size, num_decoding_steps)
733 | relevant_targets = targets[:, 1:].contiguous()
734 |
735 | # shape: (batch_size, num_decoding_steps)
736 | relevant_mask = target_mask[:, 1:].contiguous()
737 |
738 | return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
739 |
740 | @overrides
741 | def get_metrics(self, reset: bool = False) -> Dict[str, float]:
742 | all_metrics: Dict[str, float] = {}
743 | if self._bleu and not self.training:
744 | all_metrics.update(self._bleu.get_metric(reset=reset))
745 | all_metrics.update({'acc': self._acc.get_metric(reset=reset)})
746 | all_metrics.update({'no_result': self._no_result.get_metric(reset=reset)})
747 |
748 | return all_metrics
749 |
--------------------------------------------------------------------------------
/GeoQA+/NGS_Aux_test.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple
2 |
3 | import numpy
4 | from overrides import overrides
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.nn as nn
8 | from torch.nn.modules.linear import Linear
9 | from torch.nn.modules.rnn import LSTMCell
10 | from torch.nn.modules.rnn import GRUCell
11 | from allennlp.common.checks import ConfigurationError
12 | from allennlp.common.util import START_SYMBOL, END_SYMBOL
13 | from allennlp.data.vocabulary import Vocabulary
14 | from allennlp.modules.attention import LegacyAttention
15 | from allennlp.modules import Attention, TextFieldEmbedder, Seq2SeqEncoder
16 | from allennlp.modules.similarity_functions import SimilarityFunction
17 | from allennlp.models.model import Model
18 | from allennlp.modules.token_embedders import Embedding
19 | from allennlp.nn import util
20 | from allennlp.nn.beam_search import BeamSearch
21 | from allennlp.training.metrics import BLEU
22 |
23 | from ManualProgram.eval_equ import Equations
24 | from transformers import AutoModel, AutoTokenizer
25 |
26 | import random
27 | import warnings
28 | import math
29 | warnings.filterwarnings("ignore")
30 | torch.cuda.set_device(0)
31 |
32 | no_result_id = []
33 | right_id = []
34 | wrong_manual = []
35 | noresult_manual = []
36 | from utils import *
37 |
38 | from mcan import *
39 | import json
40 | model_name = "data/pretrain/Roberta"
41 | @Model.register("MyEncoder")
42 | class Encoder(Model):
43 | def __init__(self,
44 | vocab: Vocabulary,
45 | input_dim: int,
46 | emb_dim: int,
47 | hid_dim: int,
48 | dropout: float):
49 | super(Encoder, self).__init__(vocab)
50 | self.input_dim = input_dim
51 | self.emb_dim = emb_dim
52 | self.hid_dim = hid_dim
53 | self.embedding = AutoModel.from_pretrained(model_name)
54 | self.trans = nn.Linear(emb_dim, hid_dim)
55 | self.norm = nn.LayerNorm(hid_dim)
56 | self.dropout = nn.Dropout(dropout)
57 | self.lstm_embedding = nn.Embedding(22128, hid_dim, padding_idx=0)
58 | self.lstm_dropout = nn.Dropout(0.5)
59 | self.lstm = torch.nn.LSTM(hid_dim, hid_dim, batch_first=True, bidirectional=True, num_layers=2, dropout=0.5)
60 | self.concat_trans = nn.Linear(hid_dim, hid_dim)
61 | self.concat_norm = nn.LayerNorm(hid_dim)
62 | self._encoder_output_dim = 512
63 | self._decoder_output_dim = 512
64 |
65 | @overrides
66 | def forward(self, src, source_mask):
67 |
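# Branch 1: contextual token embeddings from the pretrained RoBERTa encoder (source_mask marks the real tokens).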
68 | embedded = self.embedding(src, attention_mask=source_mask, return_dict=True, output_hidden_states=True)
69 | bert_output = embedded.last_hidden_state
70 | output = self.dropout(self.norm(torch.relu(self.trans(bert_output))))
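# Branch 2: an independent nn.Embedding fed through a 2-layer bidirectional LSTM over the same token ids.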
71 | lstm_embedding = self.lstm_dropout(self.lstm_embedding(src))
72 | input_length = torch.sum(source_mask, dim=1).long().view(-1,).cpu()
73 | packed = nn.utils.rnn.pack_padded_sequence(lstm_embedding, input_length, batch_first=True, enforce_sorted=False)
74 | lstm_output, _ = self.lstm(packed)
75 | lstm_output, _ = nn.utils.rnn.pad_packed_sequence(lstm_output, batch_first=True)
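# Sum the forward and backward LSTM directions so the output matches hid_dim, then fuse the two branches by addition.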
76 | lstm_output = lstm_output[:, :, :self.hid_dim] + lstm_output[:, :, self.hid_dim:]
77 | output = output + lstm_output
78 | # output = torch.cat((output,lstm_output), dim=-1)
79 | output = self.concat_norm(torch.relu(self.concat_trans(output)))
80 | return output
81 |
82 | def get_output_dim(self):
83 | return self._encoder_output_dim
84 |
85 | def is_bidirectional(self) -> bool:
86 | return True
87 |
88 | @Model.register("geo_s2s")
89 | class SimpleSeq2Seq(Model):
90 | """
91 | This ``SimpleSeq2Seq`` class is a :class:`Model` which takes a sequence, encodes it, and then
92 | uses the encoded representations to decode another sequence. You can use this as the basis for
93 | a neural machine translation system, an abstractive summarization system, or any other common
94 | seq2seq problem. The model here is simple, but should be a decent starting place for
95 | implementing recent models for these tasks.
96 |
97 | Parameters
98 | ----------
99 | vocab : ``Vocabulary``, required
100 | Vocabulary containing source and target vocabularies. They may be under the same namespace
101 | (`tokens`) or the target tokens can have a different namespace, in which case it needs to
102 | be specified as `target_namespace`.
103 | source_embedder : ``TextFieldEmbedder``, required
104 | Embedder for source side sequences
105 | encoder : ``Seq2SeqEncoder``, required
106 | The encoder of the "encoder/decoder" model
107 | max_decoding_steps : ``int``
108 | Maximum length of decoded sequences.
109 | target_namespace : ``str``, optional (default = 'tokens')
110 | If the target side vocabulary is different from the source side's, you need to specify the
111 | target's namespace here. If not, we'll assume it is "tokens", which is also the default
112 | choice for the source side, and this might cause them to share vocabularies.
113 | target_embedding_dim : ``int``, optional (default = source_embedding_dim)
114 | You can specify an embedding dimensionality for the target side. If not, we'll use the same
115 | value as the source embedder's.
116 | attention : ``Attention``, optional (default = None)
117 | If you want to use attention to get a dynamic summary of the encoder outputs at each step
118 | of decoding, this is the function used to compute similarity between the decoder hidden
119 | state and encoder outputs.
120 | attention_function: ``SimilarityFunction``, optional (default = None)
121 | This is if you want to use the legacy implementation of attention. This will be deprecated
122 | since it consumes more memory than the specialized attention modules.
123 | beam_size : ``int``, optional (default = None)
124 | Width of the beam for beam search. If not specified, greedy decoding is used.
125 | scheduled_sampling_ratio : ``float``, optional (default = 0.)
126 | At each timestep during training, we sample a random number between 0 and 1, and if it is
127 | not less than this value, we use the ground truth labels for the whole batch. Else, we use
128 | the predictions from the previous time step for the whole batch. If this value is 0.0
129 | (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not
130 | using target side ground truth labels. See the following paper for more information:
131 | `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al.,
132 | 2015 <https://arxiv.org/abs/1506.03099>`_.
133 | use_bleu : ``bool``, optional (default = True)
134 | If True, the BLEU metric will be calculated during validation.
135 | """
136 |
137 | def __init__(self,
138 | vocab: Vocabulary,
139 | source_embedder: TextFieldEmbedder,
140 | encoder: Encoder,
141 | max_decoding_steps: int,
142 | knowledge_points_ratio = 0,
143 | attention: Attention = True,
144 | attention_function: SimilarityFunction = None,
145 | beam_size: int = None,
146 | target_namespace: str = "tokens",
147 | target_embedding_dim: int = None,
148 | scheduled_sampling_ratio: float = 0.,
149 | resnet_pretrained = None,
150 | use_bleu: bool = True) -> None:
151 | super(SimpleSeq2Seq, self).__init__(vocab)
152 |
153 | resnet = build_model()
154 |
155 | if resnet_pretrained is not None:
156 | resnet.load_state_dict(torch.load(resnet_pretrained))
157 | print('##### Checkpoint Loaded! #####')
158 | else:
159 | print("No Diagram Pretrain !!!")
160 | self.resnet = resnet
161 |
162 | self.channel_transform = torch.nn.Linear(1024, 512)
163 |
164 | __C = Cfgs()
165 | self.mcan = MCA_ED(__C)
166 | self.attflat_img = AttFlat(__C)
168 | self.attflat_lang = AttFlat(__C)  # not used
168 |
169 | self.decode_transform = torch.nn.Linear(1024, 512)
170 |
171 | self._equ = Equations()
172 |
173 | self._target_namespace = target_namespace
174 | self._scheduled_sampling_ratio = scheduled_sampling_ratio
175 |
176 | # We need the start symbol to provide as the input at the first timestep of decoding, and
177 | # end symbol as a way to indicate the end of the decoded sequence.
178 | self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace)
179 | self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace)
180 |
181 | if use_bleu:
182 | pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access
183 | self._bleu = BLEU(ngram_weights=(1, 0, 0, 0), exclude_indices={pad_index, self._end_index, self._start_index})
184 | else:
185 | self._bleu = None
186 | self._acc = Average()
187 | self._no_result = Average()
188 |
189 | # remember to clear after evaluation
190 | self.new_acc = []
191 | self.angle = []
192 | self.length = []
193 | self.area = []
194 | self.other = []
195 | self.point_acc_list = []
196 | self.save_results = dict()
197 |
198 | # At prediction time, we use a beam search to find the most likely sequence of target tokens.
199 | beam_size = beam_size or 1
200 | self._max_decoding_steps = max_decoding_steps
201 | self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size)
202 |
203 | # Dense embedding of source vocab tokens.
204 | self._source_embedder = source_embedder
205 |
206 | # Encodes the sequence of source embeddings into a sequence of hidden states.
207 | self._encoder = encoder
208 |
209 | num_classes = self.vocab.get_vocab_size(self._target_namespace)
210 |
211 | # Attention mechanism applied to the encoder output for each step.
212 | # TODO: attention
213 | if attention:
214 | if attention_function:
215 | raise ConfigurationError("You can only specify an attention module or an "
216 | "attention function, but not both.")
217 | self._attention = LegacyAttention()
218 | elif attention_function:
219 | self._attention = LegacyAttention(attention_function)
220 | else:
221 | self._attention = None
222 | print("No Attention!")
223 | exit()
224 |
225 | # Dense embedding of vocab words in the target space.
226 | target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
227 | self._target_embedder = Embedding(num_classes, target_embedding_dim)
228 |
229 | # Decoder output dim needs to be the same as the encoder output dim since we initialize the
230 | # hidden state of the decoder with the final hidden state of the encoder.
231 | self._encoder_output_dim = self._encoder.get_output_dim()
232 | self._decoder_output_dim = self._encoder_output_dim
233 |
234 | if self._attention:
235 | # If using attention, a weighted average over encoder outputs will be concatenated
236 | # to the previous target embedding to form the input to the decoder at each
237 | # time step.
238 | self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim
239 | else:
240 | # Otherwise, the input to the decoder is just the previous target embedding.
241 | self._decoder_input_dim = target_embedding_dim
242 |
243 | # We'll use an LSTM cell as the recurrent cell that produces a hidden state
244 | # for the decoder at each time step.
245 | self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)
246 | #self._decoder_cell = GRUCell(self._decoder_input_dim, self._decoder_output_dim)
247 | # We project the hidden state from the decoder into the output vocabulary space
248 | # in order to get log probabilities of each target token, at each time step.
249 | self._output_projection_layer = Linear(self._decoder_output_dim, num_classes)
250 | # knowledge points
251 | self.point_ratio = knowledge_points_ratio
252 | if self.point_ratio != 0:
253 | self.points_norm = LayerNorm(__C.FLAT_OUT_SIZE)
254 | self.points_proj = nn.Linear(__C.FLAT_OUT_SIZE, 77)
255 | self.points_criterion = nn.BCELoss()
256 |
257 | def take_step(self,
258 | last_predictions: torch.Tensor,
259 | state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
260 | """
261 | Take a decoding step. This is called by the beam search class.
262 |
263 | Parameters
264 | ----------
265 | last_predictions : ``torch.Tensor``
266 | A tensor of shape ``(group_size,)``, which gives the indices of the predictions
267 | during the last time step.
268 | state : ``Dict[str, torch.Tensor]``
269 | A dictionary of tensors that contain the current state information
270 | needed to predict the next step, which includes the encoder outputs,
271 | the source mask, and the decoder hidden state and context. Each of these
272 | tensors has shape ``(group_size, *)``, where ``*`` can be any other number
273 | of dimensions.
274 |
275 | Returns
276 | -------
277 | Tuple[torch.Tensor, Dict[str, torch.Tensor]]
278 | A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities``
279 | is a tensor of shape ``(group_size, num_classes)`` containing the predicted
280 | log probability of each class for the next step, for each item in the group,
281 | while ``updated_state`` is a dictionary of tensors containing the encoder outputs,
282 | source mask, and updated decoder hidden state and context.
283 |
284 | Notes
285 | -----
286 | We treat the inputs as a batch, even though ``group_size`` is not necessarily
287 | equal to ``batch_size``, since the group may contain multiple states
288 | for each source sentence in the batch.
289 | """
290 | # shape: (group_size, num_classes)
291 | output_projections, state = self._prepare_output_projections(last_predictions, state)
292 |
293 | # shape: (group_size, num_classes)
294 | class_log_probabilities = F.log_softmax(output_projections, dim=-1)
295 |
296 | return class_log_probabilities, state
297 |
298 | @overrides
299 | def forward(self, # type: ignore
300 | image, source_nums, choice_nums, label, type, data_id, manual_program,
301 | source_tokens: Dict[str, torch.LongTensor],
302 | point_label = None,
303 | target_tokens: Dict[str, torch.LongTensor] = None, **kwargs) -> Dict[str, torch.Tensor]:
304 | # pylint: disable=arguments-differ
305 | """
306 | Make a forward pass with decoder logic for producing the entire target sequence.
307 |
308 | Parameters
309 | ----------
310 | source_tokens : ``Dict[str, torch.LongTensor]``
311 | The output of `TextField.as_array()` applied on the source `TextField`. This will be
312 | passed through a `TextFieldEmbedder` and then through an encoder.
313 | target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None)
314 | Output of `TextField.as_array()` applied on the target `TextField`. We assume that the
315 | target tokens are also represented as a `TextField`.
316 |
317 | Returns
318 | -------
319 | Dict[str, torch.Tensor]
320 | """
321 | bs = len(label)
322 | state = self._encode(source_tokens)
323 |
324 | with torch.no_grad():
325 | img_feats = self.resnet(image)
326 | # (N, C, 14, 14) -> (N, 196, C)
327 | img_feats = img_feats.reshape(img_feats.shape[0], img_feats.shape[1], -1).transpose(1, 2)
328 | img_mask = make_mask(img_feats)
329 | img_feats = self.channel_transform(img_feats)
330 |
331 | lang_feats = state['encoder_outputs']
332 | # Build the language mask directly from the token ids in source_tokens (already indexed), without embedding them.
333 | lang_mask = make_mask(source_tokens['tokens'].unsqueeze(2))
334 |
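# MCAN co-attention between the language and image features; the attended image features are
# kept, and the original language features are concatenated with them below.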
335 | _, img_feats = self.mcan(lang_feats, img_feats, lang_mask, img_mask)
336 |
337 | # (N, 308, 512)
338 | # Image features come first, then language features; concat_mask in _encode() uses the same order.
339 | state['encoder_outputs'] = torch.cat([img_feats, lang_feats], 1)
340 |
341 | # decode
342 | state = self._init_decoder_state(state, lang_feats, img_feats, img_mask)
343 | output_dict = self._forward_loop(state, target_tokens) # recurrent decoding for LSTM
344 |
345 | # knowledge points
346 | if self.point_ratio != 0:
347 | concat_feature = state["concat_feature"]
348 | point_feat = self.points_norm(concat_feature)
349 | point_feat = self.points_proj(point_feat)
350 | point_pred = torch.sigmoid(point_feat)
351 | point_loss = self.points_criterion(point_pred, point_label) * self.point_ratio
352 | output_dict["point_pred"] = point_pred
353 | output_dict["point_loss"] = point_loss
354 | output_dict["loss"] += point_loss
355 |
356 | # At test time, run beam search and evaluate the decoded programs.
357 | if not self.training:
358 | # state = self._init_decoder_state(state)
359 | state = self._init_decoder_state(state, lang_feats, img_feats, img_mask) # TODO
360 | predictions = self._forward_beam_search(state)
361 | output_dict.update(predictions)
362 |
363 | if target_tokens and self._bleu:
364 | # shape: (batch_size, beam_size, max_sequence_length)
365 | top_k_predictions = output_dict["predictions"]
366 |
367 | # Execute the decoded programs to compute answer accuracy.
368 | suc_knt, no_knt = 0, 0
369 |
370 | selected_programs = []
371 | wrong_id = []
372 | noresult_id = []
373 |
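# For each example, walk the beam from best to worst, execute each hypothesis program on the
# source numbers, and stop at the first one whose result matches a choice (within 1e-3).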
374 | for b in range(bs):
375 |
376 | hypo = None
377 | used_hypo = None
378 | choice = None
379 | for i in range(self._beam_search.beam_size):
380 | if choice is not None:
381 | break
382 | hypo = list(top_k_predictions[b][i])
383 | if self._end_index in list(hypo):
384 | hypo = hypo[:hypo.index(self._end_index)]
385 | hypo = [self.vocab.get_token_from_index(idx.item()) for idx in hypo]
386 | res = self._equ.excuate_equation(hypo, source_nums[b])
387 |
388 | if res is not None and len(res) > 0:
389 | for j in range(4):
390 | if choice_nums[b][j] is not None and math.fabs(res[-1] - choice_nums[b][j]) < 0.001:
391 | choice = j
392 | used_hypo = hypo
393 | selected_programs.append([hypo])
394 | if choice is None:
395 | no_knt += 1
396 | answer_state = 'no_result'
397 | #no_result_id.append(data_id[b])
398 |
399 | self.new_acc.append(0)
400 | elif choice == label[b]:
401 | suc_knt += 1
402 | answer_state = 'right'
403 | self.new_acc.append(1)
404 | right_id.append(data_id[b])
405 | else:
406 | answer_state = 'false'
407 | wrong_id.append(b)
408 | self.new_acc.append(0)
409 |
410 | self.save_results[data_id[b]] = dict(manual_program=manual_program[b],
411 | predict_program=hypo, predict_res=res,
412 | choice=choice_nums[b], right_answer=label[b],
413 | answer_state=answer_state)
414 |
415 | flag = 1 if choice == label[b] else 0
416 | if type[b] == 'angle':
417 | self.angle.append(flag)
418 | elif type[b] == 'length':
419 | self.length.append(flag)
420 | else:
421 | self.other.append(flag)
422 |
423 | # knowledge points
424 | # if self.point_ratio != 0:
425 | # point_acc = self.multi_label_evaluation(point_pred[b].unsqueeze(0), point_label[b].unsqueeze(0))
426 | # self.point_acc_list.append(point_acc)
427 |
428 | # with open('save/test.json', 'w') as f:
429 | # json.dump(self.save_results, f)
430 |
431 | if random.random() < 0.05:
432 | print('selected_programs', selected_programs)
433 | """
434 | for item in noresult_id:
435 | noresult_manual.append(selected_programs[item])
436 |
437 | for item in wrong_id:
438 | wrong_manual.append(selected_programs[item])
439 |
440 | print((wrong_manual),(noresult_manual))
441 | """
442 | # calculate BLEU
443 | best_predictions = top_k_predictions[:, 0, :]
444 | self._bleu(best_predictions, target_tokens["tokens"])
445 | self._acc(suc_knt / bs)
446 | self._no_result(no_knt / bs)
447 |
448 | print(right_id)
449 | print(len(right_id))
450 | return output_dict
451 |
452 | def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
453 | """
454 | Finalize predictions.
455 |
456 | This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test
457 | time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
458 | within the ``forward`` method.
459 |
460 | This method trims the output predictions to the first end symbol, replaces indices with
461 | corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``.
462 | """
463 | predicted_indices = output_dict["predictions"]
464 | if not isinstance(predicted_indices, numpy.ndarray):
465 | predicted_indices = predicted_indices.detach().cpu().numpy()
466 | all_predicted_tokens = []
467 | for indices in predicted_indices:
468 | # Beam search gives us the top k results for each source sentence in the batch
469 | # but we just want the single best.
470 | if len(indices.shape) > 1:
471 | indices = indices[0]
472 | indices = list(indices)
473 | # Collect indices till the first end_symbol
474 | if self._end_index in indices:
475 | indices = indices[:indices.index(self._end_index)]
476 | predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace)
477 | for x in indices]
478 | all_predicted_tokens.append(predicted_tokens)
479 | output_dict["predicted_tokens"] = all_predicted_tokens
480 | return output_dict
481 |
482 | def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
483 | # shape: (batch_size, max_input_sequence_length, encoder_input_dim)
484 | #embedded_input = self._source_embedder(source_tokens)
485 | # shape: (batch_size, max_input_sequence_length)
486 | source_mask = util.get_text_field_mask(source_tokens)
487 |
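# The diagram contributes a fixed 196 patch positions (14 x 14 feature map); they are never padded, so their mask is all ones.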
488 | img_mask = torch.ones(source_mask.shape[0], 196).long().cuda()
489 | concat_mask = torch.cat([img_mask, source_mask], 1)
490 |
491 | # shape: (batch_size, max_input_sequence_length, encoder_output_dim)
492 | #encoder_outputs = self._encoder(embedded_input, source_mask)
493 | encoder_outputs = self._encoder(source_tokens['tokens'], source_mask)
494 |
495 | return {
496 | "source_mask": source_mask, # source_mask,
497 | "concat_mask": concat_mask,
498 | "encoder_outputs": encoder_outputs,
499 | }
500 |
501 | def _init_decoder_state(self, state, lang_feats, img_feats, img_mask):
502 |
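# Initialize the decoder: h0 is a linear fusion (decode_transform) of the final language state
# and the attention-flattened image feature; c0 is all zeros.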
503 | batch_size = state["source_mask"].size(0)
504 | final_lang_feat = util.get_final_encoder_states(
505 | lang_feats,
506 | state["source_mask"],
507 | self._encoder.is_bidirectional())
508 | img_feat = self.attflat_img(img_feats, img_mask)
509 | feat = torch.cat([final_lang_feat, img_feat], 1)
510 | feat = self.decode_transform(feat)
511 | state["concat_feature"] = feat
512 |
513 | state["decoder_hidden"] = feat
514 | # C0 shape: (batch_size, decoder_output_dim)
515 | state["decoder_context"] = torch.zeros(batch_size, self._decoder_output_dim).cuda()
516 | # state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim)
517 | return state
518 |
519 | def _forward_loop(self,
520 | state: Dict[str, torch.Tensor],
521 | target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]:
522 | """
523 | Make forward pass during training or do greedy search during prediction.
524 |
525 | Notes
526 | -----
527 | We really only use the predictions from the method to test that beam search
528 | with a beam size of 1 gives the same results.
529 | """
530 | # shape: (batch_size, max_input_sequence_length)
531 | source_mask = state["source_mask"]
532 |
533 | batch_size = source_mask.size()[0]
534 |
535 | if target_tokens:
536 | # shape: (batch_size, max_target_sequence_length)
537 | targets = target_tokens["tokens"]
538 |
539 | _, target_sequence_length = targets.size()
540 |
541 | # The last input from the target is either padding or the end symbol.
542 | # Either way, we don't have to process it.
543 | num_decoding_steps = target_sequence_length - 1
544 | else:
545 | num_decoding_steps = self._max_decoding_steps
546 |
547 | # Initialize target predictions with the start index.
548 | # shape: (batch_size,)
549 | last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index)
550 |
551 | step_logits: List[torch.Tensor] = []
552 | step_predictions: List[torch.Tensor] = []
553 | for timestep in range(num_decoding_steps):
554 | if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio:
555 | # With probability _scheduled_sampling_ratio during training, feed back the
556 | # model's own predictions from the previous step instead of the gold tokens.
557 | # shape: (batch_size,)
558 | input_choices = last_predictions
559 | elif not target_tokens:
560 | # shape: (batch_size,)
561 | input_choices = last_predictions
562 | else:
563 | # shape: (batch_size,)
564 | input_choices = targets[:, timestep]
565 |
566 | # shape: (batch_size, num_classes)
567 | # recurrent decoding
568 | output_projections, state = self._prepare_output_projections(input_choices, state)
569 |
570 | # list of tensors, shape: (batch_size, 1, num_classes)
571 | step_logits.append(output_projections.unsqueeze(1))
572 |
573 | # shape: (batch_size, num_classes)
574 | class_probabilities = F.softmax(output_projections, dim=-1)
575 |
576 | # shape (predicted_classes): (batch_size,)
577 | _, predicted_classes = torch.max(class_probabilities, 1)
578 |
579 | # shape (predicted_classes): (batch_size,)
580 | last_predictions = predicted_classes
581 |
582 | step_predictions.append(last_predictions.unsqueeze(1))
583 |
584 | # shape: (batch_size, num_decoding_steps)
585 | predictions = torch.cat(step_predictions, 1)
586 |
587 | output_dict = {"predictions": predictions}
588 |
589 | if target_tokens:
590 | # shape: (batch_size, num_decoding_steps, num_classes)
591 | logits = torch.cat(step_logits, 1)
592 |
593 | # Compute loss.
594 | target_mask = util.get_text_field_mask(target_tokens)
595 | loss = self._get_loss(logits, targets, target_mask)
596 | output_dict["loss"] = loss
597 |
598 | return output_dict
599 |
600 | def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
601 | """Make forward pass during prediction using a beam search."""
602 | batch_size = state["source_mask"].size()[0]
603 | start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index)
604 |
605 | # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps)
606 | # shape (log_probabilities): (batch_size, beam_size)
607 | all_top_k_predictions, log_probabilities = self._beam_search.search(
608 | start_predictions, state, self.take_step)
609 |
610 | output_dict = {
611 | "class_log_probabilities": log_probabilities,
612 | "predictions": all_top_k_predictions,
613 | }
614 | return output_dict
615 |
616 | def _prepare_output_projections(self,
617 | last_predictions: torch.Tensor,
618 | state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # pylint: disable=line-too-long
619 | """
620 | Decode the current state and the last prediction to produce projections
621 | into the target space, which can then be used to get probabilities of
622 | each target token for the next step.
623 |
624 | Inputs are the same as for `take_step()`.
625 | """
626 | # shape: (group_size, max_input_sequence_length, encoder_output_dim)
627 | encoder_outputs = state["encoder_outputs"]
628 |
629 | # shape: (group_size, max_input_sequence_length)
630 | # source_mask = state["source_mask"]
631 | source_mask = state["concat_mask"]
632 |
633 | # decoder_hidden and decoder_context are initialized from the encoder outputs in _init_decoder_state()
634 | # shape: (group_size, decoder_output_dim)
635 | decoder_hidden = state["decoder_hidden"]
636 | # shape: (group_size, decoder_output_dim)
637 | decoder_context = state["decoder_context"]
638 |
639 | # shape: (group_size, target_embedding_dim)
640 | embedded_input = self._target_embedder(last_predictions)
641 |
642 | if self._attention:
643 | # shape: (group_size, encoder_output_dim)
644 | attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask)
645 |
646 | # shape: (group_size, decoder_output_dim + target_embedding_dim)
647 | decoder_input = torch.cat((attended_input, embedded_input), -1)
648 |
649 | else:
650 | # shape: (group_size, target_embedding_dim)
651 | decoder_input = embedded_input
652 |
653 | # shape (decoder_hidden): (batch_size, decoder_output_dim)
654 | # shape (decoder_context): (batch_size, decoder_output_dim)
655 |
656 | decoder_hidden, decoder_context = self._decoder_cell(
657 | decoder_input,
658 | (decoder_hidden, decoder_context))
659 |
660 | state["decoder_hidden"] = decoder_hidden
661 | state["decoder_context"] = decoder_context
662 |
663 | # shape: (group_size, num_classes)
664 | output_projections = self._output_projection_layer(decoder_hidden)
665 | """
666 | decoder_hidden = self._decoder_cell(
667 | decoder_input,
668 | (decoder_hidden))
669 |
670 | state["decoder_hidden"] = decoder_hidden
671 | state["decoder_context"] = decoder_hidden
672 |
673 | # shape: (group_size, num_classes)
674 | output_projections = self._output_projection_layer(decoder_hidden)
675 | """
676 | return output_projections, state
677 |
678 | def _prepare_attended_input(self,
679 | decoder_hidden_state: torch.LongTensor = None,
680 | encoder_outputs: torch.LongTensor = None,
681 | encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor:
682 | """Apply attention over encoder outputs and decoder state."""
684 | # Ensure the mask is also a FloatTensor; otherwise the multiplication inside
685 | # the attention module will complain.
685 | # shape: (batch_size, max_input_sequence_length)
686 | encoder_outputs_mask = encoder_outputs_mask.float()
687 |
688 | # shape: (batch_size, max_input_sequence_length)
689 | input_weights = self._attention(
690 | decoder_hidden_state, encoder_outputs, encoder_outputs_mask)
691 |
692 | # shape: (batch_size, encoder_output_dim)
693 | attended_input = util.weighted_sum(encoder_outputs, input_weights)
694 |
695 | return attended_input
696 |
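# Multi-label knowledge-point accuracy: threshold predictions at 0.5 and compute
# intersection-over-union against the binary targets, summed over the batch.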
697 | def multi_label_evaluation(self, input, target):
698 | one = torch.ones(target.shape).cuda()
699 | zero = torch.zeros(target.shape).cuda()
700 | res = torch.where(input > 0.5, one, zero)
701 |
702 | over = (res * target).sum(dim=1)
703 | union = res.sum(dim=1) + target.sum(dim=1) - over
704 | acc = over / union
705 |
706 | index = torch.isnan(acc)  # NaN appears when both prediction and target are all zeros, which counts as a correct prediction
707 | acc_fix = torch.where(index, torch.ones(acc.shape).cuda(), acc)
708 |
709 | acc_sum = acc_fix.sum().item()
710 |
711 | return acc_sum
712 |
713 | @staticmethod
714 | def _get_loss(logits: torch.LongTensor,
715 | targets: torch.LongTensor,
716 | target_mask: torch.LongTensor) -> torch.Tensor:
717 | """
718 | Compute loss.
719 |
720 | Takes logits (unnormalized outputs from the decoder) of size (batch_size,
721 | num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
722 | and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
723 | entropy loss while taking the mask into account.
724 |
725 | The length of ``targets`` is expected to be greater than that of ``logits`` because the
726 | decoder does not need to compute the output corresponding to the last timestep of
727 | ``targets``. This method aligns the inputs appropriately to compute the loss.
728 |
729 | During training, we want the logit corresponding to timestep i to be similar to the target
730 | token from timestep i + 1. That is, the targets should be shifted by one timestep for
731 | appropriate comparison. Consider a single example where the target has 3 words, and
732 | padding is to 7 tokens.
733 | The complete sequence would correspond to  <S> w1  w2  w3  <E> <P> <P>
734 | and the mask would be                      1   1   1   1   1   0   0
735 | and let the logits be                      l1  l2  l3  l4  l5  l6
736 | We actually need to compare:
737 | the sequence            w1  w2  w3  <E> <P> <P>
738 | with masks              1   1   1   1   0   0
739 | against                 l1  l2  l3  l4  l5  l6
740 | (where the input was)   <S> w1  w2  w3  <E> <P>
741 | """
742 | # shape: (batch_size, num_decoding_steps)
743 | relevant_targets = targets[:, 1:].contiguous()
744 |
745 | # shape: (batch_size, num_decoding_steps)
746 | relevant_mask = target_mask[:, 1:].contiguous()
747 |
748 | return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask)
749 |
750 | @overrides
751 | def get_metrics(self, reset: bool = False) -> Dict[str, float]:
752 | all_metrics: Dict[str, float] = {}
753 | if self._bleu and not self.training:
754 | all_metrics.update(self._bleu.get_metric(reset=reset))
755 | # all_metrics.update({'acc': self._acc.get_metric(reset=reset)})
756 | all_metrics.update({'acc': self._acc.get_metric(reset=reset)})
757 | if len(self.new_acc) != 0:
758 | all_metrics.update({'new_acc': sum(self.new_acc)/len(self.new_acc)})
759 | print('Num of total, angle, len, other', len(self.new_acc), len(self.angle), len(self.length), len(self.other))
760 | if len(self.angle) != 0:
761 | all_metrics.update({'angle_acc': sum(self.angle)/len(self.angle)})
762 | if len(self.length) != 0:
763 | all_metrics.update({'length_acc': sum(self.length)/len(self.length)})
764 | if len(self.other) != 0:
765 | all_metrics.update({'other_acc': sum(self.other)/len(self.other)})
766 | all_metrics.update({'no_result': self._no_result.get_metric(reset=reset)})
767 |
768 | # if len(self.point_acc_list) != 0:
769 | # all_metrics.update({'point_acc': sum(self.point_acc_list) / len(self.point_acc_list)})
770 |
771 | return all_metrics
772 |
--------------------------------------------------------------------------------