├── ManualProgram ├── eval_equ.py └── operators.py ├── NGS_Aux.py ├── NGS_Aux_test.py ├── README.md ├── config └── NGS_Aux.json ├── mcan.py ├── requirements.txt ├── resnet.py └── utils.py /ManualProgram/eval_equ.py: -------------------------------------------------------------------------------- 1 | from ManualProgram import operators 2 | from inspect import getmembers, isfunction 3 | import itertools 4 | import math 5 | 6 | 7 | constant = [30, 60, 90, 180, 360, math.pi, 0.618] 8 | op_dict = {0: 'g_equal', 1: 'g_double', 2: 'g_half', 3: 'g_add', 4: 'g_minus', 9 | 5: 'g_sin', 6: 'g_cos', 7: 'g_tan', 8: 'g_asin', 9: 'g_acos', 10 | 10: 'gougu_add', 11: 'gougu_minus', 12: 'g_bili', 11 | 13: 'g_mul', 14: 'g_divide', 15: 'cal_circle_area', 16: 'cal_circle_perimeter', 17: 'cal_cone'} 12 | op_list = [op_dict[key] for key in sorted(op_dict.keys())] 13 | 14 | 15 | class Equations: 16 | def __init__(self): 17 | 18 | self.op_list = op_list 19 | self.op_num = {} 20 | self.call_op = {} 21 | self.exp_info = None 22 | self.results = [] 23 | self.max_step = 3 24 | self.max_len = 7 25 | for op in self.op_list: 26 | self.call_op[op] = eval('operators.{}'.format(op)) 27 | # self.call_op[op] = eval(op) 28 | self.op_num[op] = self.call_op[op].__code__.co_argcount 29 | 30 | def str2exp(self, inputs): 31 | inputs = inputs.split(',') 32 | exp = inputs.copy() 33 | for i, s in enumerate(inputs): 34 | if 'n' in s or 'v' in s or 'c' in s: 35 | exp[i] = s.replace('n', 'N_').replace('v', 'V_').replace('c', 'C_') 36 | else: 37 | exp[i] = op_dict[int(s[2:])] 38 | exp[i] = exp[i].strip() 39 | 40 | self.exp = exp 41 | return exp 42 | 43 | def excuate_equation(self, exp, source_nums=None): 44 | 45 | if source_nums is None: 46 | source_nums = self.exp_info['nums'] 47 | vars = [] 48 | idx = 0 49 | while idx < len(exp): 50 | op = exp[idx] 51 | if op not in self.op_list: 52 | return None 53 | op_nums = self.op_num[op] 54 | if idx + op_nums >= len(exp): 55 | return None 56 | excuate_nums = [] 57 | for tmp in exp[idx + 1: idx + 1 + op_nums]: 58 | if tmp[0] == 'N' and int(tmp[-1]) < len(source_nums): 59 | excuate_nums.append(source_nums[int(tmp[-1])]) 60 | elif tmp[0] == 'V' and int(tmp[-1]) < len(vars): 61 | excuate_nums.append(vars[int(tmp[-1])]) 62 | elif tmp[0] == 'C' and int(tmp[-1]) < len(constant): 63 | excuate_nums.append(constant[int(tmp[-1])]) 64 | else: 65 | return None 66 | idx += op_nums + 1 67 | v = self.call_op[op](*excuate_nums) 68 | if v is None: 69 | return None 70 | vars.append(v) 71 | return vars 72 | 73 | 74 | if __name__ == '__main__': 75 | eq = Equations() 76 | 77 | -------------------------------------------------------------------------------- /ManualProgram/operators.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def g_equal(n1): # 0 5 | return n1 6 | 7 | 8 | def g_double(n1): # 1 9 | return n1*2 10 | 11 | 12 | def g_half(n1): # 2 13 | return n1/2 14 | 15 | 16 | def g_add(n1, n2): # 3 17 | return n1 + n2 18 | 19 | 20 | def g_minus(n1, n2): # 4 21 | return math.fabs(n1 - n2) 22 | 23 | 24 | def g_sin(n1): # 5 25 | if n1 % 15 == 0 and 0 <= n1 <= 180: 26 | return math.sin(n1/180*math.pi) 27 | return False 28 | 29 | 30 | def g_cos(n1): # 6 31 | if n1 % 15 == 0 and 0 <= n1 <= 180: 32 | return math.cos(n1/180*math.pi) 33 | return False 34 | 35 | 36 | def g_tan(n1): # 7 37 | if n1 % 15 == 0 and 5 <= n1 <= 85: 38 | return math.tan(n1/180*math.pi) 39 | return False 40 | 41 | 42 | def g_asin(n1): # 8 43 | if -1 < n1 < 1: 44 | n1 = math.asin(n1) 45 | n1 
= math.degrees(n1) 46 | return n1 47 | return False 48 | 49 | 50 | def g_acos(n1): # 9 51 | if -1 < n1 < 1: 52 | n1 = math.acos(n1) 53 | n1 = math.degrees(n1) 54 | return n1 55 | return False 56 | 57 | 58 | def gougu_add(n1, n2): # 13 59 | return math.sqrt(n1*n1+n2*n2) 60 | 61 | 62 | def gougu_minus(n1, n2): # 14 63 | if n1 != n2: 64 | return math.sqrt(math.fabs(n1*n1-n2*n2)) 65 | return False 66 | 67 | 68 | def g_bili(n1, n2, n3): # 16 69 | if n1 > 0 and n2 > 0 and n3 > 0: 70 | return n1/n2*n3 71 | else: 72 | return False 73 | 74 | 75 | def g_mul(n1, n2): # 17 76 | return n1*n2 77 | 78 | 79 | def g_divide(n1, n2): # 18 80 | if n1 > 0 and n2 > 0: 81 | return n1/n2 82 | return False 83 | 84 | 85 | def cal_circle_area(n1): # 19 86 | return n1*n1*math.pi 87 | 88 | 89 | def cal_circle_perimeter(n1): # 20 90 | return 2*math.pi*n1 91 | 92 | 93 | def cal_cone(n1, n2): # 21 94 | return n1*n2*math.pi 95 | 96 | -------------------------------------------------------------------------------- /NGS_Aux.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import numpy 4 | from overrides import overrides 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | from torch.nn.modules.linear import Linear 9 | from torch.nn.modules.rnn import LSTMCell 10 | 11 | from allennlp.common.checks import ConfigurationError 12 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 13 | from allennlp.data.vocabulary import Vocabulary 14 | from allennlp.modules.attention import LegacyAttention 15 | from allennlp.modules import Attention, TextFieldEmbedder, Seq2SeqEncoder 16 | from allennlp.modules.similarity_functions import SimilarityFunction 17 | from allennlp.models.model import Model 18 | from allennlp.modules.token_embedders import Embedding 19 | from allennlp.nn import util 20 | from allennlp.nn.beam_search import BeamSearch 21 | from allennlp.training.metrics import BLEU 22 | 23 | from ManualProgram.eval_equ import Equations 24 | 25 | import random 26 | import warnings 27 | import math 28 | warnings.filterwarnings("ignore") 29 | 30 | from utils import * 31 | 32 | from mcan import * 33 | 34 | 35 | @Model.register("geo_s2s") 36 | class SimpleSeq2Seq(Model): 37 | """ 38 | This ``SimpleSeq2Seq`` class is a :class:`Model` which takes a sequence, encodes it, and then 39 | uses the encoded representations to decode another sequence. You can use this as the basis for 40 | a neural machine translation system, an abstractive summarization system, or any other common 41 | seq2seq problem. The model here is simple, but should be a decent starting place for 42 | implementing recent models for these tasks. 43 | 44 | Parameters 45 | ---------- 46 | vocab : ``Vocabulary``, required 47 | Vocabulary containing source and target vocabularies. They may be under the same namespace 48 | (`tokens`) or the target tokens can have a different namespace, in which case it needs to 49 | be specified as `target_namespace`. 50 | source_embedder : ``TextFieldEmbedder``, required 51 | Embedder for source side sequences 52 | encoder : ``Seq2SeqEncoder``, required 53 | The encoder of the "encoder/decoder" model 54 | max_decoding_steps : ``int`` 55 | Maximum length of decoded sequences. 56 | target_namespace : ``str``, optional (default = 'tokens') 57 | If the target side vocabulary is different from the source side's, you need to specify the 58 | target's namespace here. 
If not, we'll assume it is "tokens", which is also the default 59 | choice for the source side, and this might cause them to share vocabularies. 60 | target_embedding_dim : ``int``, optional (default = source_embedding_dim) 61 | You can specify an embedding dimensionality for the target side. If not, we'll use the same 62 | value as the source embedder's. 63 | attention : ``Attention``, optional (default = None) 64 | If you want to use attention to get a dynamic summary of the encoder outputs at each step 65 | of decoding, this is the function used to compute similarity between the decoder hidden 66 | state and encoder outputs. 67 | attention_function: ``SimilarityFunction``, optional (default = None) 68 | This is if you want to use the legacy implementation of attention. This will be deprecated 69 | since it consumes more memory than the specialized attention modules. 70 | beam_size : ``int``, optional (default = None) 71 | Width of the beam for beam search. If not specified, greedy decoding is used. 72 | scheduled_sampling_ratio : ``float``, optional (default = 0.) 73 | At each timestep during training, we sample a random number between 0 and 1, and if it is 74 | not less than this value, we use the ground truth labels for the whole batch. Else, we use 75 | the predictions from the previous time step for the whole batch. If this value is 0.0 76 | (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not 77 | using target side ground truth labels. See the following paper for more information: 78 | `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al., 79 | 2015 `_. 80 | use_bleu : ``bool``, optional (default = True) 81 | If True, the BLEU metric will be calculated during validation. 82 | """ 83 | 84 | def __init__(self, 85 | vocab: Vocabulary, 86 | source_embedder: TextFieldEmbedder, 87 | encoder: Seq2SeqEncoder, 88 | max_decoding_steps: int, 89 | knowledge_points_ratio = 0, 90 | attention: Attention = True, 91 | attention_function: SimilarityFunction = None, 92 | beam_size: int = None, 93 | target_namespace: str = "tokens", 94 | target_embedding_dim: int = None, 95 | scheduled_sampling_ratio: float = 0., 96 | resnet_pretrained = None, 97 | use_bleu: bool = True) -> None: 98 | super(SimpleSeq2Seq, self).__init__(vocab) 99 | 100 | resnet = build_model() 101 | 102 | if resnet_pretrained is not None: 103 | resnet.load_state_dict(torch.load(resnet_pretrained)) 104 | print('##### Checkpoint Loaded! #####') 105 | else: 106 | print("No Diagram Pretrain !!!") 107 | self.resnet = resnet 108 | 109 | self.channel_transform = torch.nn.Linear(1024, 512) 110 | 111 | __C = Cfgs() 112 | self.mcan = MCA_ED(__C) 113 | self.attflat_img = AttFlat(__C) 114 | self.attflat_lang = AttFlat(__C) # not use 115 | 116 | self.decode_transform = torch.nn.Linear(1024, 512) 117 | 118 | self._equ = Equations() 119 | 120 | self._target_namespace = target_namespace 121 | self._scheduled_sampling_ratio = scheduled_sampling_ratio 122 | 123 | # We need the start symbol to provide as the input at the first timestep of decoding, and 124 | # end symbol as a way to indicate the end of the decoded sequence. 
125 | self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) 126 | self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) 127 | 128 | if use_bleu: 129 | pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access 130 | self._bleu = BLEU(ngram_weights=(1, 0, 0, 0), exclude_indices={pad_index, self._end_index, self._start_index}) 131 | else: 132 | self._bleu = None 133 | self._acc = Average() 134 | self._no_result = Average() 135 | 136 | # remember to clear after evaluation 137 | self.new_acc = [] 138 | self.angle = [] 139 | self.length = [] 140 | self.area = [] 141 | self.other = [] 142 | self.point_acc_list = [] 143 | 144 | # At prediction time, we use a beam search to find the most likely sequence of target tokens. 145 | beam_size = beam_size or 1 146 | self._max_decoding_steps = max_decoding_steps 147 | self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) 148 | 149 | # Dense embedding of source vocab tokens. 150 | self._source_embedder = source_embedder 151 | 152 | # Encodes the sequence of source embeddings into a sequence of hidden states. 153 | self._encoder = encoder 154 | 155 | num_classes = self.vocab.get_vocab_size(self._target_namespace) 156 | 157 | # Attention mechanism applied to the encoder output for each step. 158 | # TODO: attention 159 | if attention: 160 | if attention_function: 161 | raise ConfigurationError("You can only specify an attention module or an " 162 | "attention function, but not both.") 163 | self._attention = LegacyAttention() 164 | elif attention_function: 165 | self._attention = LegacyAttention(attention_function) 166 | else: 167 | self._attention = None 168 | print("No Attention!") 169 | exit() 170 | 171 | # Dense embedding of vocab words in the target space. 172 | target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim() 173 | self._target_embedder = Embedding(num_classes, target_embedding_dim) 174 | 175 | # Decoder output dim needs to be the same as the encoder output dim since we initialize the 176 | # hidden state of the decoder with the final hidden state of the encoder. 177 | self._encoder_output_dim = self._encoder.get_output_dim() 178 | self._decoder_output_dim = self._encoder_output_dim 179 | 180 | if self._attention: 181 | # If using attention, a weighted average over encoder outputs will be concatenated 182 | # to the previous target embedding to form the input to the decoder at each 183 | # time step. 184 | self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim 185 | else: 186 | # Otherwise, the input to the decoder is just the previous target embedding. 187 | self._decoder_input_dim = target_embedding_dim 188 | 189 | # We'll use an LSTM cell as the recurrent cell that produces a hidden state 190 | # for the decoder at each time step. 191 | self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) 192 | 193 | # We project the hidden state from the decoder into the output vocabulary space 194 | # in order to get log probabilities of each target token, at each time step. 
195 | self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) 196 | 197 | # knowledge points 198 | self.point_ratio = knowledge_points_ratio 199 | if self.point_ratio != 0: 200 | self.points_norm = LayerNorm(__C.FLAT_OUT_SIZE) 201 | self.points_proj = nn.Linear(__C.FLAT_OUT_SIZE, 50) 202 | self.points_criterion = nn.BCELoss() 203 | 204 | def take_step(self, 205 | last_predictions: torch.Tensor, 206 | state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 207 | """ 208 | Take a decoding step. This is called by the beam search class. 209 | 210 | Parameters 211 | ---------- 212 | last_predictions : ``torch.Tensor`` 213 | A tensor of shape ``(group_size,)``, which gives the indices of the predictions 214 | during the last time step. 215 | state : ``Dict[str, torch.Tensor]`` 216 | A dictionary of tensors that contain the current state information 217 | needed to predict the next step, which includes the encoder outputs, 218 | the source mask, and the decoder hidden state and context. Each of these 219 | tensors has shape ``(group_size, *)``, where ``*`` can be any other number 220 | of dimensions. 221 | 222 | Returns 223 | ------- 224 | Tuple[torch.Tensor, Dict[str, torch.Tensor]] 225 | A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` 226 | is a tensor of shape ``(group_size, num_classes)`` containing the predicted 227 | log probability of each class for the next step, for each item in the group, 228 | while ``updated_state`` is a dictionary of tensors containing the encoder outputs, 229 | source mask, and updated decoder hidden state and context. 230 | 231 | Notes 232 | ----- 233 | We treat the inputs as a batch, even though ``group_size`` is not necessarily 234 | equal to ``batch_size``, since the group may contain multiple states 235 | for each source sentence in the batch. 236 | """ 237 | # shape: (group_size, num_classes) 238 | output_projections, state = self._prepare_output_projections(last_predictions, state) 239 | 240 | # shape: (group_size, num_classes) 241 | class_log_probabilities = F.log_softmax(output_projections, dim=-1) 242 | 243 | return class_log_probabilities, state 244 | 245 | @overrides 246 | def forward(self, # type: ignore 247 | image, source_nums, choice_nums, label, type, 248 | source_tokens: Dict[str, torch.LongTensor], 249 | point_label = None, 250 | target_tokens: Dict[str, torch.LongTensor] = None, **kwargs) -> Dict[str, torch.Tensor]: 251 | # pylint: disable=arguments-differ 252 | """ 253 | Make foward pass with decoder logic for producing the entire target sequence. 254 | 255 | Parameters 256 | ---------- 257 | source_tokens : ``Dict[str, torch.LongTensor]`` 258 | The output of `TextField.as_array()` applied on the source `TextField`. This will be 259 | passed through a `TextFieldEmbedder` and then through an encoder. 260 | target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None) 261 | Output of `Textfield.as_array()` applied on target `TextField`. We assume that the 262 | target tokens are also represented as a `TextField`. 
263 | 264 | Returns 265 | ------- 266 | Dict[str, torch.Tensor] 267 | """ 268 | bs = len(label) 269 | state = self._encode(source_tokens) 270 | 271 | with torch.no_grad(): 272 | img_feats = self.resnet(image) 273 | # (N, C, 14, 14) -> (N, 196, C) 274 | img_feats = img_feats.reshape(img_feats.shape[0], img_feats.shape[1], -1).transpose(1, 2) 275 | img_mask = make_mask(img_feats) 276 | img_feats = self.channel_transform(img_feats) 277 | 278 | lang_feats = state['encoder_outputs'] 279 | # mask the digital encoding question without embedding, i.e. source_tokens(already index to number) 280 | lang_mask = make_mask(source_tokens['tokens'].unsqueeze(2)) 281 | 282 | _, img_feats = self.mcan(lang_feats, img_feats, lang_mask, img_mask) 283 | 284 | # (N, 308, 512) 285 | # for attention, image first and then lang, using mask 286 | state['encoder_outputs'] = torch.cat([img_feats, lang_feats], 1) 287 | 288 | # decode 289 | state = self._init_decoder_state(state, lang_feats, img_feats, img_mask) 290 | output_dict = self._forward_loop(state, target_tokens) # recurrent decoding for LSTM 291 | 292 | # knowledge points 293 | if self.point_ratio != 0: 294 | concat_feature = state["concat_feature"] 295 | point_feat = self.points_norm(concat_feature) 296 | point_feat = self.points_proj(point_feat) 297 | point_pred = torch.sigmoid(point_feat) 298 | point_loss = self.points_criterion(point_pred, point_label) * self.point_ratio 299 | output_dict["point_pred"] = point_pred 300 | output_dict["point_loss"] = point_loss 301 | output_dict["loss"] += point_loss 302 | 303 | 304 | # TODO: if testing, beam search and evaluation 305 | if not self.training: 306 | state = self._init_decoder_state(state, lang_feats, img_feats, img_mask) # TODO 307 | predictions = self._forward_beam_search(state) 308 | output_dict.update(predictions) 309 | 310 | if target_tokens and self._bleu: 311 | # shape: (batch_size, beam_size, max_sequence_length) 312 | top_k_predictions = output_dict["predictions"] 313 | 314 | # execute the decode programs to calculate the accuracy 315 | # suc_knt, no_knt = 0, 0 316 | suc_knt, no_knt, = 0, 0 317 | 318 | selected_programs = [] 319 | for b in range(bs): 320 | hypo = None 321 | used_hypo = None 322 | choice = None 323 | for i in range(self._beam_search.beam_size): 324 | if choice is not None: 325 | break 326 | hypo = list(top_k_predictions[b][i]) 327 | if self._end_index in list(hypo): 328 | hypo = hypo[:hypo.index(self._end_index)] 329 | hypo = [self.vocab.get_token_from_index(idx.item()) for idx in hypo] 330 | # print(hypo) 331 | res = self._equ.excuate_equation(hypo, source_nums[b]) 332 | # print(res, choice_nums[b]) 333 | if res is not None and len(res) > 0: 334 | for j in range(4): 335 | if choice_nums[b][j] is not None and math.fabs(res[-1] - choice_nums[b][j]) < 0.001: 336 | choice = j 337 | used_hypo = hypo 338 | 339 | selected_programs.append([hypo]) 340 | 341 | if choice is None: 342 | no_knt += 1 343 | if choice == label[b]: 344 | suc_knt += 1 345 | 346 | if random.random() < 0.05: 347 | print('selected_programs', selected_programs) 348 | 349 | # calculate BLEU 350 | best_predictions = top_k_predictions[:, 0, :] 351 | self._bleu(best_predictions, target_tokens["tokens"]) 352 | self._acc(suc_knt / bs) 353 | self._no_result(no_knt / bs) 354 | 355 | 356 | return output_dict 357 | 358 | def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 359 | """ 360 | Finalize predictions. 
361 | 362 | This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test 363 | time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives 364 | within the ``forward`` method. 365 | 366 | This method trims the output predictions to the first end symbol, replaces indices with 367 | corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. 368 | """ 369 | predicted_indices = output_dict["predictions"] 370 | if not isinstance(predicted_indices, numpy.ndarray): 371 | predicted_indices = predicted_indices.detach().cpu().numpy() 372 | all_predicted_tokens = [] 373 | for indices in predicted_indices: 374 | # Beam search gives us the top k results for each source sentence in the batch 375 | # but we just want the single best. 376 | if len(indices.shape) > 1: 377 | indices = indices[0] 378 | indices = list(indices) 379 | # Collect indices till the first end_symbol 380 | if self._end_index in indices: 381 | indices = indices[:indices.index(self._end_index)] 382 | predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace) 383 | for x in indices] 384 | all_predicted_tokens.append(predicted_tokens) 385 | output_dict["predicted_tokens"] = all_predicted_tokens 386 | return output_dict 387 | 388 | def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 389 | # shape: (batch_size, max_input_sequence_length, encoder_input_dim) 390 | embedded_input = self._source_embedder(source_tokens) 391 | # shape: (batch_size, max_input_sequence_length) 392 | source_mask = util.get_text_field_mask(source_tokens) 393 | 394 | # source mask are used in attention 395 | img_mask = torch.ones(source_mask.shape[0], 196).long().cuda() 396 | concat_mask = torch.cat([img_mask, source_mask], 1) 397 | 398 | # shape: (batch_size, max_input_sequence_length, encoder_output_dim) 399 | encoder_outputs = self._encoder(embedded_input, source_mask) 400 | return { 401 | "source_mask": source_mask, # source_mask, 402 | "concat_mask": concat_mask, 403 | "encoder_outputs": encoder_outputs, 404 | } 405 | 406 | def _init_decoder_state(self, state, lang_feats, img_feats, img_mask): 407 | 408 | batch_size = state["source_mask"].size(0) 409 | final_lang_feat = util.get_final_encoder_states( 410 | lang_feats, 411 | state["source_mask"], 412 | self._encoder.is_bidirectional()) 413 | img_feat = self.attflat_img(img_feats, img_mask) 414 | feat = torch.cat([final_lang_feat, img_feat], 1) 415 | feat = self.decode_transform(feat) 416 | state["concat_feature"] = feat 417 | 418 | state["decoder_hidden"] = feat 419 | # C0 shape: (batch_size, decoder_output_dim) 420 | state["decoder_context"] = torch.zeros(batch_size, self._decoder_output_dim).cuda() 421 | # state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim) 422 | return state 423 | 424 | def _forward_loop(self, 425 | state: Dict[str, torch.Tensor], 426 | target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]: 427 | """ 428 | Make forward pass during training or do greedy search during prediction. 429 | 430 | Notes 431 | ----- 432 | We really only use the predictions from the method to test that beam search 433 | with a beam size of 1 gives the same results. 
434 | """ 435 | # shape: (batch_size, max_input_sequence_length) 436 | source_mask = state["source_mask"] 437 | 438 | batch_size = source_mask.size()[0] 439 | 440 | if target_tokens: 441 | # shape: (batch_size, max_target_sequence_length) 442 | targets = target_tokens["tokens"] 443 | 444 | _, target_sequence_length = targets.size() 445 | 446 | # The last input from the target is either padding or the end symbol. 447 | # Either way, we don't have to process it. 448 | num_decoding_steps = target_sequence_length - 1 449 | else: 450 | num_decoding_steps = self._max_decoding_steps 451 | 452 | # Initialize target predictions with the start index. 453 | # shape: (batch_size,) 454 | last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index) 455 | 456 | step_logits: List[torch.Tensor] = [] 457 | step_predictions: List[torch.Tensor] = [] 458 | for timestep in range(num_decoding_steps): 459 | if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio: 460 | # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio 461 | # during training. 462 | # shape: (batch_size,) 463 | input_choices = last_predictions 464 | elif not target_tokens: 465 | # shape: (batch_size,) 466 | input_choices = last_predictions 467 | else: 468 | # shape: (batch_size,) 469 | input_choices = targets[:, timestep] 470 | 471 | # shape: (batch_size, num_classes) 472 | # recurrent decoding 473 | output_projections, state = self._prepare_output_projections(input_choices, state) 474 | 475 | # list of tensors, shape: (batch_size, 1, num_classes) 476 | step_logits.append(output_projections.unsqueeze(1)) 477 | 478 | # shape: (batch_size, num_classes) 479 | class_probabilities = F.softmax(output_projections, dim=-1) 480 | 481 | # shape (predicted_classes): (batch_size,) 482 | _, predicted_classes = torch.max(class_probabilities, 1) 483 | 484 | # shape (predicted_classes): (batch_size,) 485 | last_predictions = predicted_classes 486 | 487 | step_predictions.append(last_predictions.unsqueeze(1)) 488 | 489 | # shape: (batch_size, num_decoding_steps) 490 | predictions = torch.cat(step_predictions, 1) 491 | 492 | output_dict = {"predictions": predictions} 493 | 494 | if target_tokens: 495 | # shape: (batch_size, num_decoding_steps, num_classes) 496 | logits = torch.cat(step_logits, 1) 497 | 498 | # Compute loss. 
499 | target_mask = util.get_text_field_mask(target_tokens) 500 | loss = self._get_loss(logits, targets, target_mask) 501 | output_dict["loss"] = loss 502 | 503 | return output_dict 504 | 505 | def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 506 | """Make forward pass during prediction using a beam search.""" 507 | batch_size = state["source_mask"].size()[0] 508 | start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index) 509 | 510 | # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) 511 | # shape (log_probabilities): (batch_size, beam_size) 512 | all_top_k_predictions, log_probabilities = self._beam_search.search( 513 | start_predictions, state, self.take_step) 514 | 515 | output_dict = { 516 | "class_log_probabilities": log_probabilities, 517 | "predictions": all_top_k_predictions, 518 | } 519 | return output_dict 520 | 521 | def _prepare_output_projections(self, 522 | last_predictions: torch.Tensor, 523 | state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # pylint: disable=line-too-long 524 | """ 525 | Decode current state and last prediction to produce produce projections 526 | into the target space, which can then be used to get probabilities of 527 | each target token for the next step. 528 | 529 | Inputs are the same as for `take_step()`. 530 | """ 531 | # shape: (group_size, max_input_sequence_length, encoder_output_dim) 532 | encoder_outputs = state["encoder_outputs"] 533 | 534 | # shape: (group_size, max_input_sequence_length) 535 | # source_mask = state["source_mask"] 536 | source_mask = state["concat_mask"] 537 | 538 | # decoder_hidden and decoder_context are get from encoder_outputs in _init_decoder_state() 539 | # shape: (group_size, decoder_output_dim) 540 | decoder_hidden = state["decoder_hidden"] 541 | # shape: (group_size, decoder_output_dim) 542 | decoder_context = state["decoder_context"] 543 | 544 | # shape: (group_size, target_embedding_dim) 545 | embedded_input = self._target_embedder(last_predictions) 546 | 547 | if self._attention: 548 | # shape: (group_size, encoder_output_dim) 549 | attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask) 550 | 551 | # shape: (group_size, decoder_output_dim + target_embedding_dim) 552 | decoder_input = torch.cat((attended_input, embedded_input), -1) 553 | 554 | else: 555 | # shape: (group_size, target_embedding_dim) 556 | decoder_input = embedded_input 557 | 558 | # shape (decoder_hidden): (batch_size, decoder_output_dim) 559 | # shape (decoder_context): (batch_size, decoder_output_dim) 560 | decoder_hidden, decoder_context = self._decoder_cell( 561 | decoder_input, 562 | (decoder_hidden, decoder_context)) 563 | 564 | state["decoder_hidden"] = decoder_hidden 565 | state["decoder_context"] = decoder_context 566 | 567 | # shape: (group_size, num_classes) 568 | output_projections = self._output_projection_layer(decoder_hidden) 569 | 570 | return output_projections, state 571 | 572 | def _prepare_attended_input(self, 573 | decoder_hidden_state: torch.LongTensor = None, 574 | encoder_outputs: torch.LongTensor = None, 575 | encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor: 576 | """Apply attention over encoder outputs and decoder state.""" 577 | # Ensure mask is also a FloatTensor. Or else the multiplication within 578 | # attention will complain. 
579 | # shape: (batch_size, max_input_sequence_length) 580 | encoder_outputs_mask = encoder_outputs_mask.float() 581 | 582 | # shape: (batch_size, max_input_sequence_length) 583 | input_weights = self._attention( 584 | decoder_hidden_state, encoder_outputs, encoder_outputs_mask) 585 | 586 | # shape: (batch_size, encoder_output_dim) 587 | attended_input = util.weighted_sum(encoder_outputs, input_weights) 588 | 589 | return attended_input 590 | 591 | def multi_label_evaluation(self, input, target): 592 | one = torch.ones(target.shape).cuda() 593 | zero = torch.zeros(target.shape).cuda() 594 | res = torch.where(input > 0.5, one, zero) 595 | 596 | over = (res * target).sum(dim=1) 597 | union = res.sum(dim=1) + target.sum(dim=1) - over 598 | acc = over / union 599 | 600 | index = torch.isnan(acc) # nan appear when both pred and target are zeros, which means makes right answer 601 | acc_fix = torch.where(index, torch.ones(acc.shape).cuda(), acc) 602 | 603 | acc_sum = acc_fix.sum().item() 604 | 605 | return acc_sum 606 | 607 | @staticmethod 608 | def _get_loss(logits: torch.LongTensor, 609 | targets: torch.LongTensor, 610 | target_mask: torch.LongTensor) -> torch.Tensor: 611 | """ 612 | Compute loss. 613 | 614 | Takes logits (unnormalized outputs from the decoder) of size (batch_size, 615 | num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) 616 | and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross 617 | entropy loss while taking the mask into account. 618 | 619 | The length of ``targets`` is expected to be greater than that of ``logits`` because the 620 | decoder does not need to compute the output corresponding to the last timestep of 621 | ``targets``. This method aligns the inputs appropriately to compute the loss. 622 | 623 | During training, we want the logit corresponding to timestep i to be similar to the target 624 | token from timestep i + 1. That is, the targets should be shifted by one timestep for 625 | appropriate comparison. Consider a single example where the target has 3 words, and 626 | padding is to 7 tokens. 627 | The complete sequence would correspond to w1 w2 w3

<E> <P> <P> 628 | and the mask would be 1 1 1 1 1 0 0 629 | and let the logits be l1 l2 l3 l4 l5 l6 630 | We actually need to compare: 631 | the sequence w1 w2 w3 <E> <P> <P> 632 | with masks 1 1 1 1 0 0 633 | against l1 l2 l3 l4 l5 l6 634 | (where the input was) <S> w1 w2 w3 <E> <P>
635 | """ 636 | # shape: (batch_size, num_decoding_steps) 637 | relevant_targets = targets[:, 1:].contiguous() 638 | 639 | # shape: (batch_size, num_decoding_steps) 640 | relevant_mask = target_mask[:, 1:].contiguous() 641 | 642 | return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) 643 | 644 | @overrides 645 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 646 | all_metrics: Dict[str, float] = {} 647 | if self._bleu and not self.training: 648 | all_metrics.update(self._bleu.get_metric(reset=reset)) 649 | all_metrics.update({'acc': self._acc.get_metric(reset=reset)}) 650 | all_metrics.update({'no_result': self._no_result.get_metric(reset=reset)}) 651 | 652 | return all_metrics -------------------------------------------------------------------------------- /NGS_Aux_test.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import numpy 4 | from overrides import overrides 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.nn as nn 8 | from torch.nn.modules.linear import Linear 9 | from torch.nn.modules.rnn import LSTMCell 10 | 11 | from allennlp.common.checks import ConfigurationError 12 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 13 | from allennlp.data.vocabulary import Vocabulary 14 | from allennlp.modules.attention import LegacyAttention 15 | from allennlp.modules import Attention, TextFieldEmbedder, Seq2SeqEncoder 16 | from allennlp.modules.similarity_functions import SimilarityFunction 17 | from allennlp.models.model import Model 18 | from allennlp.modules.token_embedders import Embedding 19 | from allennlp.nn import util 20 | from allennlp.nn.beam_search import BeamSearch 21 | from allennlp.training.metrics import BLEU 22 | 23 | from ManualProgram.eval_equ import Equations 24 | 25 | import random 26 | import warnings 27 | import math 28 | warnings.filterwarnings("ignore") 29 | 30 | from utils import * 31 | 32 | from mcan import * 33 | 34 | import json 35 | 36 | 37 | @Model.register("geo_s2s") 38 | class SimpleSeq2Seq(Model): 39 | """ 40 | This ``SimpleSeq2Seq`` class is a :class:`Model` which takes a sequence, encodes it, and then 41 | uses the encoded representations to decode another sequence. You can use this as the basis for 42 | a neural machine translation system, an abstractive summarization system, or any other common 43 | seq2seq problem. The model here is simple, but should be a decent starting place for 44 | implementing recent models for these tasks. 45 | 46 | Parameters 47 | ---------- 48 | vocab : ``Vocabulary``, required 49 | Vocabulary containing source and target vocabularies. They may be under the same namespace 50 | (`tokens`) or the target tokens can have a different namespace, in which case it needs to 51 | be specified as `target_namespace`. 52 | source_embedder : ``TextFieldEmbedder``, required 53 | Embedder for source side sequences 54 | encoder : ``Seq2SeqEncoder``, required 55 | The encoder of the "encoder/decoder" model 56 | max_decoding_steps : ``int`` 57 | Maximum length of decoded sequences. 58 | target_namespace : ``str``, optional (default = 'tokens') 59 | If the target side vocabulary is different from the source side's, you need to specify the 60 | target's namespace here. If not, we'll assume it is "tokens", which is also the default 61 | choice for the source side, and this might cause them to share vocabularies. 
62 | target_embedding_dim : ``int``, optional (default = source_embedding_dim) 63 | You can specify an embedding dimensionality for the target side. If not, we'll use the same 64 | value as the source embedder's. 65 | attention : ``Attention``, optional (default = None) 66 | If you want to use attention to get a dynamic summary of the encoder outputs at each step 67 | of decoding, this is the function used to compute similarity between the decoder hidden 68 | state and encoder outputs. 69 | attention_function: ``SimilarityFunction``, optional (default = None) 70 | This is if you want to use the legacy implementation of attention. This will be deprecated 71 | since it consumes more memory than the specialized attention modules. 72 | beam_size : ``int``, optional (default = None) 73 | Width of the beam for beam search. If not specified, greedy decoding is used. 74 | scheduled_sampling_ratio : ``float``, optional (default = 0.) 75 | At each timestep during training, we sample a random number between 0 and 1, and if it is 76 | not less than this value, we use the ground truth labels for the whole batch. Else, we use 77 | the predictions from the previous time step for the whole batch. If this value is 0.0 78 | (default), this corresponds to teacher forcing, and if it is 1.0, it corresponds to not 79 | using target side ground truth labels. See the following paper for more information: 80 | `Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks. Bengio et al., 81 | 2015 `_. 82 | use_bleu : ``bool``, optional (default = True) 83 | If True, the BLEU metric will be calculated during validation. 84 | """ 85 | 86 | def __init__(self, 87 | vocab: Vocabulary, 88 | source_embedder: TextFieldEmbedder, 89 | encoder: Seq2SeqEncoder, 90 | max_decoding_steps: int, 91 | knowledge_points_ratio = 0, 92 | attention: Attention = True, 93 | attention_function: SimilarityFunction = None, 94 | beam_size: int = None, 95 | target_namespace: str = "tokens", 96 | target_embedding_dim: int = None, 97 | scheduled_sampling_ratio: float = 0., 98 | resnet_pretrained = None, 99 | use_bleu: bool = True) -> None: 100 | super(SimpleSeq2Seq, self).__init__(vocab) 101 | 102 | resnet = build_model() 103 | 104 | if resnet_pretrained is not None: 105 | resnet.load_state_dict(torch.load(resnet_pretrained)) 106 | print('##### Checkpoint Loaded! #####') 107 | else: 108 | print("No Diagram Pretrain !!!") 109 | self.resnet = resnet 110 | 111 | self.channel_transform = torch.nn.Linear(1024, 512) 112 | 113 | __C = Cfgs() 114 | self.mcan = MCA_ED(__C) 115 | self.attflat_img = AttFlat(__C) 116 | self.attflat_lang = AttFlat(__C) # not use 117 | 118 | self.decode_transform = torch.nn.Linear(1024, 512) 119 | 120 | self._equ = Equations() 121 | 122 | self._target_namespace = target_namespace 123 | self._scheduled_sampling_ratio = scheduled_sampling_ratio 124 | 125 | # We need the start symbol to provide as the input at the first timestep of decoding, and 126 | # end symbol as a way to indicate the end of the decoded sequence. 
127 | self._start_index = self.vocab.get_token_index(START_SYMBOL, self._target_namespace) 128 | self._end_index = self.vocab.get_token_index(END_SYMBOL, self._target_namespace) 129 | 130 | if use_bleu: 131 | pad_index = self.vocab.get_token_index(self.vocab._padding_token, self._target_namespace) # pylint: disable=protected-access 132 | self._bleu = BLEU(ngram_weights=(1, 0, 0, 0), exclude_indices={pad_index, self._end_index, self._start_index}) 133 | else: 134 | self._bleu = None 135 | self._acc = Average() 136 | self._no_result = Average() 137 | 138 | # remember to clear after evaluation 139 | self.new_acc = [] 140 | self.angle = [] 141 | self.length = [] 142 | self.area = [] 143 | self.other = [] 144 | self.point_acc_list = [] 145 | self.save_results = dict() 146 | 147 | # At prediction time, we use a beam search to find the most likely sequence of target tokens. 148 | beam_size = beam_size or 1 149 | self._max_decoding_steps = max_decoding_steps 150 | self._beam_search = BeamSearch(self._end_index, max_steps=max_decoding_steps, beam_size=beam_size) 151 | 152 | # Dense embedding of source vocab tokens. 153 | self._source_embedder = source_embedder 154 | 155 | # Encodes the sequence of source embeddings into a sequence of hidden states. 156 | self._encoder = encoder 157 | 158 | num_classes = self.vocab.get_vocab_size(self._target_namespace) 159 | 160 | # Attention mechanism applied to the encoder output for each step. 161 | # TODO: attention 162 | if attention: 163 | if attention_function: 164 | raise ConfigurationError("You can only specify an attention module or an " 165 | "attention function, but not both.") 166 | self._attention = LegacyAttention() 167 | elif attention_function: 168 | self._attention = LegacyAttention(attention_function) 169 | else: 170 | self._attention = None 171 | print("No Attention!") 172 | exit() 173 | 174 | # Dense embedding of vocab words in the target space. 175 | target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim() 176 | self._target_embedder = Embedding(num_classes, target_embedding_dim) 177 | 178 | # Decoder output dim needs to be the same as the encoder output dim since we initialize the 179 | # hidden state of the decoder with the final hidden state of the encoder. 180 | self._encoder_output_dim = self._encoder.get_output_dim() 181 | self._decoder_output_dim = self._encoder_output_dim 182 | 183 | if self._attention: 184 | # If using attention, a weighted average over encoder outputs will be concatenated 185 | # to the previous target embedding to form the input to the decoder at each 186 | # time step. 187 | self._decoder_input_dim = self._decoder_output_dim + target_embedding_dim 188 | else: 189 | # Otherwise, the input to the decoder is just the previous target embedding. 190 | self._decoder_input_dim = target_embedding_dim 191 | 192 | # We'll use an LSTM cell as the recurrent cell that produces a hidden state 193 | # for the decoder at each time step. 194 | self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim) 195 | 196 | # We project the hidden state from the decoder into the output vocabulary space 197 | # in order to get log probabilities of each target token, at each time step. 
198 | self._output_projection_layer = Linear(self._decoder_output_dim, num_classes) 199 | 200 | # knowledge points 201 | self.point_ratio = knowledge_points_ratio 202 | if self.point_ratio != 0: 203 | self.points_norm = LayerNorm(__C.FLAT_OUT_SIZE) 204 | self.points_proj = nn.Linear(__C.FLAT_OUT_SIZE, 50) 205 | self.points_criterion = nn.BCELoss() 206 | 207 | def take_step(self, 208 | last_predictions: torch.Tensor, 209 | state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 210 | """ 211 | Take a decoding step. This is called by the beam search class. 212 | 213 | Parameters 214 | ---------- 215 | last_predictions : ``torch.Tensor`` 216 | A tensor of shape ``(group_size,)``, which gives the indices of the predictions 217 | during the last time step. 218 | state : ``Dict[str, torch.Tensor]`` 219 | A dictionary of tensors that contain the current state information 220 | needed to predict the next step, which includes the encoder outputs, 221 | the source mask, and the decoder hidden state and context. Each of these 222 | tensors has shape ``(group_size, *)``, where ``*`` can be any other number 223 | of dimensions. 224 | 225 | Returns 226 | ------- 227 | Tuple[torch.Tensor, Dict[str, torch.Tensor]] 228 | A tuple of ``(log_probabilities, updated_state)``, where ``log_probabilities`` 229 | is a tensor of shape ``(group_size, num_classes)`` containing the predicted 230 | log probability of each class for the next step, for each item in the group, 231 | while ``updated_state`` is a dictionary of tensors containing the encoder outputs, 232 | source mask, and updated decoder hidden state and context. 233 | 234 | Notes 235 | ----- 236 | We treat the inputs as a batch, even though ``group_size`` is not necessarily 237 | equal to ``batch_size``, since the group may contain multiple states 238 | for each source sentence in the batch. 239 | """ 240 | # shape: (group_size, num_classes) 241 | output_projections, state = self._prepare_output_projections(last_predictions, state) 242 | 243 | # shape: (group_size, num_classes) 244 | class_log_probabilities = F.log_softmax(output_projections, dim=-1) 245 | 246 | return class_log_probabilities, state 247 | 248 | @overrides 249 | def forward(self, # type: ignore 250 | image, source_nums, choice_nums, label, type, data_id, manual_program, 251 | source_tokens: Dict[str, torch.LongTensor], 252 | point_label = None, 253 | target_tokens: Dict[str, torch.LongTensor] = None, **kwargs) -> Dict[str, torch.Tensor]: 254 | # pylint: disable=arguments-differ 255 | """ 256 | Make foward pass with decoder logic for producing the entire target sequence. 257 | 258 | Parameters 259 | ---------- 260 | source_tokens : ``Dict[str, torch.LongTensor]`` 261 | The output of `TextField.as_array()` applied on the source `TextField`. This will be 262 | passed through a `TextFieldEmbedder` and then through an encoder. 263 | target_tokens : ``Dict[str, torch.LongTensor]``, optional (default = None) 264 | Output of `Textfield.as_array()` applied on target `TextField`. We assume that the 265 | target tokens are also represented as a `TextField`. 
266 | 267 | Returns 268 | ------- 269 | Dict[str, torch.Tensor] 270 | """ 271 | bs = len(label) 272 | state = self._encode(source_tokens) 273 | 274 | with torch.no_grad(): 275 | img_feats = self.resnet(image) 276 | # (N, C, 14, 14) -> (N, 196, C) 277 | img_feats = img_feats.reshape(img_feats.shape[0], img_feats.shape[1], -1).transpose(1, 2) 278 | img_mask = make_mask(img_feats) 279 | img_feats = self.channel_transform(img_feats) 280 | 281 | lang_feats = state['encoder_outputs'] 282 | # mask the digital encoding question without embedding, i.e. source_tokens(already index to number) 283 | lang_mask = make_mask(source_tokens['tokens'].unsqueeze(2)) 284 | 285 | _, img_feats = self.mcan(lang_feats, img_feats, lang_mask, img_mask) 286 | 287 | # (N, 308, 512) 288 | # for attention, image first and then lang, using mask 289 | state['encoder_outputs'] = torch.cat([img_feats, lang_feats], 1) 290 | 291 | # decode 292 | state = self._init_decoder_state(state, lang_feats, img_feats, img_mask) 293 | output_dict = self._forward_loop(state, target_tokens) # recurrent decoding for LSTM 294 | 295 | # knowledge points 296 | if self.point_ratio != 0: 297 | concat_feature = state["concat_feature"] 298 | point_feat = self.points_norm(concat_feature) 299 | point_feat = self.points_proj(point_feat) 300 | point_pred = torch.sigmoid(point_feat) 301 | point_loss = self.points_criterion(point_pred, point_label) * self.point_ratio 302 | output_dict["point_pred"] = point_pred 303 | output_dict["point_loss"] = point_loss 304 | output_dict["loss"] += point_loss 305 | 306 | 307 | # if testing, beam search and evaluation 308 | if not self.training: 309 | # state = self._init_decoder_state(state) 310 | state = self._init_decoder_state(state, lang_feats, img_feats, img_mask) # TODO 311 | predictions = self._forward_beam_search(state) 312 | output_dict.update(predictions) 313 | 314 | if target_tokens and self._bleu: 315 | # shape: (batch_size, beam_size, max_sequence_length) 316 | top_k_predictions = output_dict["predictions"] 317 | 318 | # execute the decode programs to calculate the accuracy 319 | suc_knt, no_knt, = 0, 0 320 | 321 | selected_programs = [] 322 | 323 | for b in range(bs): 324 | hypo = None 325 | used_hypo = None 326 | choice = None 327 | for i in range(self._beam_search.beam_size): 328 | if choice is not None: 329 | break 330 | hypo = list(top_k_predictions[b][i]) 331 | if self._end_index in list(hypo): 332 | hypo = hypo[:hypo.index(self._end_index)] 333 | hypo = [self.vocab.get_token_from_index(idx.item()) for idx in hypo] 334 | res = self._equ.excuate_equation(hypo, source_nums[b]) 335 | if res is not None and len(res) > 0: 336 | for j in range(4): 337 | if choice_nums[b][j] is not None and math.fabs(res[-1] - choice_nums[b][j]) < 0.001: 338 | choice = j 339 | used_hypo = hypo 340 | selected_programs.append([hypo]) 341 | 342 | if choice is None: 343 | no_knt += 1 344 | answer_state = 'no_result' 345 | self.new_acc.append(0) 346 | elif choice == label[b]: 347 | suc_knt += 1 348 | answer_state = 'right' 349 | self.new_acc.append(1) 350 | else: 351 | answer_state = 'false' 352 | self.new_acc.append(0) 353 | 354 | self.save_results[data_id[b]] = dict(manual_program=manual_program[b], 355 | predict_program=hypo, predict_res=res, 356 | choice=choice_nums[b], right_answer=label[b], 357 | answer_state=answer_state) 358 | 359 | flag = 1 if choice == label[b] else 0 360 | if type[b] == 'angle': 361 | self.angle.append(flag) 362 | elif type[b] == 'length': 363 | self.length.append(flag) 364 | else: 365 | 
self.other.append(flag) 366 | 367 | # knowledge points 368 | # if self.point_ratio != 0: 369 | # point_acc = self.multi_label_evaluation(point_pred[b].unsqueeze(0), point_label[b].unsqueeze(0)) 370 | # self.point_acc_list.append(point_acc) 371 | 372 | # with open('save/test.json', 'w') as f: 373 | # json.dump(self.save_results, f) 374 | 375 | if random.random() < 0.05: 376 | print('selected_programs', selected_programs) 377 | 378 | # calculate BLEU 379 | best_predictions = top_k_predictions[:, 0, :] 380 | self._bleu(best_predictions, target_tokens["tokens"]) 381 | self._acc(suc_knt / bs) 382 | self._no_result(no_knt / bs) 383 | 384 | return output_dict 385 | 386 | def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 387 | """ 388 | Finalize predictions. 389 | 390 | This method overrides ``Model.decode``, which gets called after ``Model.forward``, at test 391 | time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives 392 | within the ``forward`` method. 393 | 394 | This method trims the output predictions to the first end symbol, replaces indices with 395 | corresponding tokens, and adds a field called ``predicted_tokens`` to the ``output_dict``. 396 | """ 397 | predicted_indices = output_dict["predictions"] 398 | if not isinstance(predicted_indices, numpy.ndarray): 399 | predicted_indices = predicted_indices.detach().cpu().numpy() 400 | all_predicted_tokens = [] 401 | for indices in predicted_indices: 402 | # Beam search gives us the top k results for each source sentence in the batch 403 | # but we just want the single best. 404 | if len(indices.shape) > 1: 405 | indices = indices[0] 406 | indices = list(indices) 407 | # Collect indices till the first end_symbol 408 | if self._end_index in indices: 409 | indices = indices[:indices.index(self._end_index)] 410 | predicted_tokens = [self.vocab.get_token_from_index(x, namespace=self._target_namespace) 411 | for x in indices] 412 | all_predicted_tokens.append(predicted_tokens) 413 | output_dict["predicted_tokens"] = all_predicted_tokens 414 | return output_dict 415 | 416 | def _encode(self, source_tokens: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 417 | # shape: (batch_size, max_input_sequence_length, encoder_input_dim) 418 | embedded_input = self._source_embedder(source_tokens) 419 | # shape: (batch_size, max_input_sequence_length) 420 | source_mask = util.get_text_field_mask(source_tokens) 421 | 422 | img_mask = torch.ones(source_mask.shape[0], 196).long().cuda() 423 | concat_mask = torch.cat([img_mask, source_mask], 1) 424 | 425 | # shape: (batch_size, max_input_sequence_length, encoder_output_dim) 426 | encoder_outputs = self._encoder(embedded_input, source_mask) 427 | return { 428 | "source_mask": source_mask, # source_mask, 429 | "concat_mask": concat_mask, 430 | "encoder_outputs": encoder_outputs, 431 | } 432 | 433 | def _init_decoder_state(self, state, lang_feats, img_feats, img_mask): 434 | 435 | batch_size = state["source_mask"].size(0) 436 | final_lang_feat = util.get_final_encoder_states( 437 | lang_feats, 438 | state["source_mask"], 439 | self._encoder.is_bidirectional()) 440 | img_feat = self.attflat_img(img_feats, img_mask) 441 | feat = torch.cat([final_lang_feat, img_feat], 1) 442 | feat = self.decode_transform(feat) 443 | state["concat_feature"] = feat 444 | 445 | state["decoder_hidden"] = feat 446 | # C0 shape: (batch_size, decoder_output_dim) 447 | state["decoder_context"] = torch.zeros(batch_size, self._decoder_output_dim).cuda() 448 | # 
state["decoder_context"] = state["encoder_outputs"].new_zeros(batch_size, self._decoder_output_dim) 449 | return state 450 | 451 | def _forward_loop(self, 452 | state: Dict[str, torch.Tensor], 453 | target_tokens: Dict[str, torch.LongTensor] = None) -> Dict[str, torch.Tensor]: 454 | """ 455 | Make forward pass during training or do greedy search during prediction. 456 | 457 | Notes 458 | ----- 459 | We really only use the predictions from the method to test that beam search 460 | with a beam size of 1 gives the same results. 461 | """ 462 | # shape: (batch_size, max_input_sequence_length) 463 | source_mask = state["source_mask"] 464 | 465 | batch_size = source_mask.size()[0] 466 | 467 | if target_tokens: 468 | # shape: (batch_size, max_target_sequence_length) 469 | targets = target_tokens["tokens"] 470 | 471 | _, target_sequence_length = targets.size() 472 | 473 | # The last input from the target is either padding or the end symbol. 474 | # Either way, we don't have to process it. 475 | num_decoding_steps = target_sequence_length - 1 476 | else: 477 | num_decoding_steps = self._max_decoding_steps 478 | 479 | # Initialize target predictions with the start index. 480 | # shape: (batch_size,) 481 | last_predictions = source_mask.new_full((batch_size,), fill_value=self._start_index) 482 | 483 | step_logits: List[torch.Tensor] = [] 484 | step_predictions: List[torch.Tensor] = [] 485 | for timestep in range(num_decoding_steps): 486 | if self.training and torch.rand(1).item() < self._scheduled_sampling_ratio: 487 | # Use gold tokens at test time and at a rate of 1 - _scheduled_sampling_ratio 488 | # during training. 489 | # shape: (batch_size,) 490 | input_choices = last_predictions 491 | elif not target_tokens: 492 | # shape: (batch_size,) 493 | input_choices = last_predictions 494 | else: 495 | # shape: (batch_size,) 496 | input_choices = targets[:, timestep] 497 | 498 | # shape: (batch_size, num_classes) 499 | # recurrent decoding 500 | output_projections, state = self._prepare_output_projections(input_choices, state) 501 | 502 | # list of tensors, shape: (batch_size, 1, num_classes) 503 | step_logits.append(output_projections.unsqueeze(1)) 504 | 505 | # shape: (batch_size, num_classes) 506 | class_probabilities = F.softmax(output_projections, dim=-1) 507 | 508 | # shape (predicted_classes): (batch_size,) 509 | _, predicted_classes = torch.max(class_probabilities, 1) 510 | 511 | # shape (predicted_classes): (batch_size,) 512 | last_predictions = predicted_classes 513 | 514 | step_predictions.append(last_predictions.unsqueeze(1)) 515 | 516 | # shape: (batch_size, num_decoding_steps) 517 | predictions = torch.cat(step_predictions, 1) 518 | 519 | output_dict = {"predictions": predictions} 520 | 521 | if target_tokens: 522 | # shape: (batch_size, num_decoding_steps, num_classes) 523 | logits = torch.cat(step_logits, 1) 524 | 525 | # Compute loss. 
526 | target_mask = util.get_text_field_mask(target_tokens) 527 | loss = self._get_loss(logits, targets, target_mask) 528 | output_dict["loss"] = loss 529 | 530 | return output_dict 531 | 532 | def _forward_beam_search(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 533 | """Make forward pass during prediction using a beam search.""" 534 | batch_size = state["source_mask"].size()[0] 535 | start_predictions = state["source_mask"].new_full((batch_size,), fill_value=self._start_index) 536 | 537 | # shape (all_top_k_predictions): (batch_size, beam_size, num_decoding_steps) 538 | # shape (log_probabilities): (batch_size, beam_size) 539 | all_top_k_predictions, log_probabilities = self._beam_search.search( 540 | start_predictions, state, self.take_step) 541 | 542 | output_dict = { 543 | "class_log_probabilities": log_probabilities, 544 | "predictions": all_top_k_predictions, 545 | } 546 | return output_dict 547 | 548 | def _prepare_output_projections(self, 549 | last_predictions: torch.Tensor, 550 | state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: # pylint: disable=line-too-long 551 | """ 552 | Decode current state and last prediction to produce produce projections 553 | into the target space, which can then be used to get probabilities of 554 | each target token for the next step. 555 | 556 | Inputs are the same as for `take_step()`. 557 | """ 558 | # shape: (group_size, max_input_sequence_length, encoder_output_dim) 559 | encoder_outputs = state["encoder_outputs"] 560 | 561 | # shape: (group_size, max_input_sequence_length) 562 | # source_mask = state["source_mask"] 563 | source_mask = state["concat_mask"] 564 | 565 | # decoder_hidden and decoder_context are get from encoder_outputs in _init_decoder_state() 566 | # shape: (group_size, decoder_output_dim) 567 | decoder_hidden = state["decoder_hidden"] 568 | # shape: (group_size, decoder_output_dim) 569 | decoder_context = state["decoder_context"] 570 | 571 | # shape: (group_size, target_embedding_dim) 572 | embedded_input = self._target_embedder(last_predictions) 573 | 574 | if self._attention: 575 | # shape: (group_size, encoder_output_dim) 576 | attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask) 577 | 578 | # shape: (group_size, decoder_output_dim + target_embedding_dim) 579 | decoder_input = torch.cat((attended_input, embedded_input), -1) 580 | 581 | else: 582 | # shape: (group_size, target_embedding_dim) 583 | decoder_input = embedded_input 584 | 585 | # shape (decoder_hidden): (batch_size, decoder_output_dim) 586 | # shape (decoder_context): (batch_size, decoder_output_dim) 587 | decoder_hidden, decoder_context = self._decoder_cell( 588 | decoder_input, 589 | (decoder_hidden, decoder_context)) 590 | 591 | state["decoder_hidden"] = decoder_hidden 592 | state["decoder_context"] = decoder_context 593 | 594 | # shape: (group_size, num_classes) 595 | output_projections = self._output_projection_layer(decoder_hidden) 596 | 597 | return output_projections, state 598 | 599 | def _prepare_attended_input(self, 600 | decoder_hidden_state: torch.LongTensor = None, 601 | encoder_outputs: torch.LongTensor = None, 602 | encoder_outputs_mask: torch.LongTensor = None) -> torch.Tensor: 603 | """Apply attention over encoder outputs and decoder state.""" 604 | # Ensure mask is also a FloatTensor. Or else the multiplication within 605 | # attention will complain. 
606 | # shape: (batch_size, max_input_sequence_length) 607 | encoder_outputs_mask = encoder_outputs_mask.float() 608 | 609 | # shape: (batch_size, max_input_sequence_length) 610 | input_weights = self._attention( 611 | decoder_hidden_state, encoder_outputs, encoder_outputs_mask) 612 | 613 | # shape: (batch_size, encoder_output_dim) 614 | attended_input = util.weighted_sum(encoder_outputs, input_weights) 615 | 616 | return attended_input 617 | 618 | def multi_label_evaluation(self, input, target): 619 | one = torch.ones(target.shape).cuda() 620 | zero = torch.zeros(target.shape).cuda() 621 | res = torch.where(input > 0.5, one, zero) 622 | 623 | over = (res * target).sum(dim=1) 624 | union = res.sum(dim=1) + target.sum(dim=1) - over 625 | acc = over / union 626 | 627 | index = torch.isnan(acc) # NaN appears when both the prediction and the target are all zeros, which counts as a correct answer 628 | acc_fix = torch.where(index, torch.ones(acc.shape).cuda(), acc) 629 | 630 | acc_sum = acc_fix.sum().item() 631 | 632 | return acc_sum 633 | 634 | @staticmethod 635 | def _get_loss(logits: torch.LongTensor, 636 | targets: torch.LongTensor, 637 | target_mask: torch.LongTensor) -> torch.Tensor: 638 | """ 639 | Compute loss. 640 | 641 | Takes logits (unnormalized outputs from the decoder) of size (batch_size, 642 | num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1) 643 | and corresponding masks of size (batch_size, num_decoding_steps+1) and computes cross 644 | entropy loss while taking the mask into account. 645 | 646 | The length of ``targets`` is expected to be greater than that of ``logits`` because the 647 | decoder does not need to compute the output corresponding to the last timestep of 648 | ``targets``. This method aligns the inputs appropriately to compute the loss. 649 | 650 | During training, we want the logit corresponding to timestep i to be similar to the target 651 | token from timestep i + 1. That is, the targets should be shifted by one timestep for 652 | appropriate comparison. Consider a single example where the target has 3 words, and 653 | padding is to 7 tokens. 654 | The complete sequence would correspond to <S> w1 w2 w3 <E> <P> <P> 655 | and the mask would be 1 1 1 1 1 0 0 656 | and let the logits be l1 l2 l3 l4 l5 l6 657 | We actually need to compare: 658 | the sequence w1 w2 w3 <E> <P> <P> 659 | with masks 1 1 1 1 0 0 660 | against l1 l2 l3 l4 l5 l6 661 | (where the input was) <S> w1 w2 w3 <E> <P>
662 | """ 663 | # shape: (batch_size, num_decoding_steps) 664 | relevant_targets = targets[:, 1:].contiguous() 665 | 666 | # shape: (batch_size, num_decoding_steps) 667 | relevant_mask = target_mask[:, 1:].contiguous() 668 | 669 | return util.sequence_cross_entropy_with_logits(logits, relevant_targets, relevant_mask) 670 | 671 | @overrides 672 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 673 | all_metrics: Dict[str, float] = {} 674 | if self._bleu and not self.training: 675 | all_metrics.update(self._bleu.get_metric(reset=reset)) 676 | # all_metrics.update({'acc': self._acc.get_metric(reset=reset)}) 677 | all_metrics.update({'acc': self._acc.get_metric(reset=reset)}) 678 | if len(self.new_acc) != 0: 679 | all_metrics.update({'new_acc': sum(self.new_acc)/len(self.new_acc)}) 680 | print('Num of total, angle, len, other', len(self.new_acc), len(self.angle), len(self.length), len(self.other)) 681 | if len(self.angle) != 0: 682 | all_metrics.update({'angle_acc': sum(self.angle)/len(self.angle)}) 683 | if len(self.length) != 0: 684 | all_metrics.update({'length_acc': sum(self.length)/len(self.length)}) 685 | if len(self.other) != 0: 686 | all_metrics.update({'other_acc': sum(self.other)/len(self.other)}) 687 | all_metrics.update({'no_result': self._no_result.get_metric(reset=reset)}) 688 | 689 | # if len(self.point_acc_list) != 0: 690 | # all_metrics.update({'point_acc': sum(self.point_acc_list) / len(self.point_acc_list)}) 691 | 692 | return all_metrics 693 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # GeoQA 3 | 4 | Jiaqi Chen, Jianheng Tang, Jinghui Qin, Xiaodan Liang, Lingbo Liu, Eric P. Xing, Liang Lin. “GeoQA: A Geometric Question Answering Benchmark Towards Multimodal Numerical Reasoning”. Annual Meeting of the Association for Computational Linguistics (Findings of ACL 2021). [PDF] 5 | 6 | Download GeoQA benchmark: [Google Drive] 7 | 8 | 01/10/2022 Update: We find some minor errors in data annotation. These errors have been fixed, please download the latest GeoQA benchmark. Besides, both arXiv paper and this project have also been updated. 9 | 10 | ## Environment 11 | python=3.6 12 | 13 | allennlp==0.9.0 14 | 15 | Document for allennlp 16 | ## Usage of NGS-Auxiliary 17 | 18 | 19 | ### Preparing 20 | 21 | git clone https://github.com/chen-judge/GeoQA.git 22 | 23 | cd GeoQA 24 | 25 | pip install -r requirements.txt 26 | 27 | Download the data.zip, move it to GeoQA path, and unzip it. 
28 | 29 | 30 | ### Training 31 | 32 | allennlp train config/NGS_Aux.json --include-package NGS_Aux -s save/test 33 | 34 | ### Evaluation 35 | Evaluate your trained model: 36 | 37 | allennlp evaluate save/test data/GeoQA3/test.pk --include-package NGS_Aux_test --cuda-device 0 38 | 39 | Or, you can use our checkpoint NGS_Aux_CKPT.zip, move it to save path, unzip it, and run: 40 | 41 | allennlp evaluate save/NGS_Aux_CKPT data/GeoQA3/test.pk --include-package NGS_Aux_test --cuda-device 0 42 | 43 | The result of our checkpoint should be: 44 | 45 | | Method | Acc | Angle | Length | Other | 46 | | --- | --- | --- | --- |--- | 47 | | NGS-Auxiliary | 60.0 | 71.5 | 48.8 | 29.6 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /config/NGS_Aux.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "s2s_manual_reader", 4 | "tokenizer": { 5 | "word_splitter":{ 6 | "type": "just_spaces" 7 | } 8 | }, 9 | "source_token_indexer": { 10 | "tokens": { 11 | "type": "single_id" 12 | } 13 | }, 14 | "target_token_indexer": { 15 | "tokens": { 16 | "type": "single_id" 17 | } 18 | }, 19 | }, 20 | "train_data_path": "data/GeoQA3/train.pk", 21 | "validation_data_path": "data/GeoQA3/dev.pk", 22 | "test_data_path" : "data/GeoQA3/test.pk", 23 | "evaluate_on_test": true, 24 | "model": { 25 | "type": "geo_s2s", 26 | "max_decoding_steps": 16, 27 | "beam_size": 10, 28 | "target_embedding_dim": 512, 29 | "scheduled_sampling_ratio": 0, 30 | "resnet_pretrained": "data/pretrain/best_jigsaw_model_state_dict", 31 | "knowledge_points_ratio": 1, 32 | "source_embedder": { 33 | "token_embedders": { 34 | "tokens": { 35 | "type": "embedding", 36 | "embedding_dim": 256, 37 | "trainable": true 38 | } 39 | } 40 | }, 41 | "encoder": { 42 | "type": "lstm", 43 | "bidirectional": false, 44 | "input_size": 256, 45 | "hidden_size": 512, 46 | "num_layers": 1 47 | }, 48 | }, 49 | "iterator": { 50 | "type": "basic", 51 | "batch_size": 32 52 | }, 53 | "trainer": { 54 | "validation_metric": "+acc", 55 | "num_epochs": 100, 56 | "patience": 100, 57 | "grad_norm": 10.0, 58 | "cuda_device": 0, 59 | "optimizer": { 60 | "type": "adam", 61 | "lr": 1e-3, 62 | "parameter_groups": [ 63 | [["mcan", "channel_transform", "attflat_img", "attflat_lang", "decode_transform"], {"lr": 1e-5}], 64 | [["resnet"], {"lr": 0}] 65 | ] 66 | } 67 | } 68 | } 69 | -------------------------------------------------------------------------------- /mcan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import torchvision 5 | import math 6 | 7 | 8 | def build_model(): 9 | cnn = getattr(torchvision.models, 'resnet101')(pretrained=True) 10 | layers = [cnn.conv1, 11 | cnn.bn1, 12 | cnn.relu, 13 | cnn.maxpool] 14 | for i in range(3): 15 | name = 'layer%d' % (i + 1) 16 | layers.append(getattr(cnn, name)) 17 | model = torch.nn.Sequential(*layers) 18 | model.cuda() 19 | model.eval() 20 | return model 21 | 22 | 23 | def make_mask(feature): 24 | return (torch.sum( 25 | torch.abs(feature), 26 | dim=-1 27 | ) == 0).unsqueeze(1).unsqueeze(2) 28 | 29 | 30 | class Cfgs: 31 | def __init__(self): 32 | super(Cfgs, self).__init__() 33 | 34 | self.LAYER = 6 35 | self.HIDDEN_SIZE = 512 36 | self.BBOXFEAT_EMB_SIZE = 2048 37 | self.FF_SIZE = 2048 38 | self.MULTI_HEAD = 8 39 | self.DROPOUT_R = 0.1 40 | self.FLAT_MLP_SIZE = 512 41 | self.FLAT_GLIMPSES = 1 42 | # 
self.FLAT_OUT_SIZE = 1024 43 | self.FLAT_OUT_SIZE = 512 44 | self.USE_AUX_FEAT = False 45 | self.USE_BBOX_FEAT = False 46 | 47 | 48 | class MCA_ED(nn.Module): 49 | def __init__(self, __C): 50 | super(MCA_ED, self).__init__() 51 | self.enc_list = nn.ModuleList([SA(__C) for _ in range(__C.LAYER)]) 52 | self.dec_list = nn.ModuleList([SGA(__C) for _ in range(__C.LAYER)]) 53 | 54 | def forward(self, lang, image, lang_mask, image_mask): # lang, image 55 | for enc in self.enc_list: 56 | lang = enc(lang, lang_mask) 57 | 58 | for dec in self.dec_list: 59 | image = dec(image, lang, image_mask, lang_mask) 60 | 61 | return lang, image 62 | 63 | 64 | class SA(nn.Module): 65 | def __init__(self, __C): 66 | super(SA, self).__init__() 67 | 68 | self.mhatt = MHAtt(__C) 69 | self.ffn = FFN(__C) 70 | 71 | self.dropout1 = nn.Dropout(__C.DROPOUT_R) 72 | self.norm1 = LayerNorm(__C.HIDDEN_SIZE) 73 | 74 | self.dropout2 = nn.Dropout(__C.DROPOUT_R) 75 | self.norm2 = LayerNorm(__C.HIDDEN_SIZE) 76 | 77 | def forward(self, y, y_mask): 78 | y = self.norm1(y + self.dropout1( 79 | self.mhatt(y, y, y, y_mask) 80 | )) 81 | 82 | y = self.norm2(y + self.dropout2( 83 | self.ffn(y) 84 | )) 85 | 86 | return y 87 | 88 | 89 | class SGA(nn.Module): 90 | def __init__(self, __C): 91 | super(SGA, self).__init__() 92 | 93 | self.mhatt1 = MHAtt(__C) 94 | self.mhatt2 = MHAtt(__C) 95 | self.ffn = FFN(__C) 96 | 97 | self.dropout1 = nn.Dropout(__C.DROPOUT_R) 98 | self.norm1 = LayerNorm(__C.HIDDEN_SIZE) 99 | 100 | self.dropout2 = nn.Dropout(__C.DROPOUT_R) 101 | self.norm2 = LayerNorm(__C.HIDDEN_SIZE) 102 | 103 | self.dropout3 = nn.Dropout(__C.DROPOUT_R) 104 | self.norm3 = LayerNorm(__C.HIDDEN_SIZE) 105 | 106 | def forward(self, x, y, x_mask, y_mask): 107 | x = self.norm1(x + self.dropout1( 108 | self.mhatt1(v=x, k=x, q=x, mask=x_mask) 109 | )) 110 | 111 | x = self.norm2(x + self.dropout2( 112 | self.mhatt2(v=y, k=y, q=x, mask=y_mask) 113 | )) 114 | 115 | x = self.norm3(x + self.dropout3( 116 | self.ffn(x) 117 | )) 118 | 119 | return x 120 | 121 | 122 | class LayerNorm(nn.Module): 123 | def __init__(self, size, eps=1e-6): 124 | super(LayerNorm, self).__init__() 125 | self.eps = eps 126 | 127 | self.a_2 = nn.Parameter(torch.ones(size)) 128 | self.b_2 = nn.Parameter(torch.zeros(size)) 129 | 130 | def forward(self, x): 131 | mean = x.mean(-1, keepdim=True) 132 | std = x.std(-1, keepdim=True) 133 | 134 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 135 | 136 | 137 | class MHAtt(nn.Module): 138 | def __init__(self, __C): 139 | super(MHAtt, self).__init__() 140 | self.__C = __C 141 | 142 | self.linear_v = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 143 | self.linear_k = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 144 | self.linear_q = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 145 | self.linear_merge = nn.Linear(__C.HIDDEN_SIZE, __C.HIDDEN_SIZE) 146 | 147 | self.dropout = nn.Dropout(__C.DROPOUT_R) 148 | 149 | def forward(self, v, k, q, mask): 150 | n_batches = q.size(0) 151 | 152 | v = self.linear_v(v).view( 153 | n_batches, 154 | -1, 155 | self.__C.MULTI_HEAD, 156 | int(self.__C.HIDDEN_SIZE / self.__C.MULTI_HEAD) 157 | ).transpose(1, 2) 158 | 159 | k = self.linear_k(k).view( 160 | n_batches, 161 | -1, 162 | self.__C.MULTI_HEAD, 163 | int(self.__C.HIDDEN_SIZE / self.__C.MULTI_HEAD) 164 | ).transpose(1, 2) 165 | 166 | q = self.linear_q(q).view( 167 | n_batches, 168 | -1, 169 | self.__C.MULTI_HEAD, 170 | int(self.__C.HIDDEN_SIZE / self.__C.MULTI_HEAD) 171 | ).transpose(1, 2) 172 | 173 | atted = self.att(v, k, q, mask) 174 | atted = 
atted.transpose(1, 2).contiguous().view( 175 | n_batches, 176 | -1, 177 | self.__C.HIDDEN_SIZE 178 | ) 179 | 180 | atted = self.linear_merge(atted) 181 | 182 | return atted 183 | 184 | def att(self, value, key, query, mask): 185 | d_k = query.size(-1) 186 | 187 | scores = torch.matmul( 188 | query, key.transpose(-2, -1) 189 | ) / math.sqrt(d_k) 190 | 191 | if mask is not None: 192 | scores = scores.masked_fill(mask, -1e9) 193 | 194 | att_map = F.softmax(scores, dim=-1) 195 | att_map = self.dropout(att_map) 196 | 197 | return torch.matmul(att_map, value) 198 | 199 | 200 | class FFN(nn.Module): 201 | def __init__(self, __C): 202 | super(FFN, self).__init__() 203 | 204 | self.mlp = MLP( 205 | in_size=__C.HIDDEN_SIZE, 206 | mid_size=__C.FF_SIZE, 207 | out_size=__C.HIDDEN_SIZE, 208 | dropout_r=__C.DROPOUT_R, 209 | use_relu=True 210 | ) 211 | 212 | def forward(self, x): 213 | return self.mlp(x) 214 | 215 | 216 | class MLP(nn.Module): 217 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True): 218 | super(MLP, self).__init__() 219 | 220 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu) 221 | self.linear = nn.Linear(mid_size, out_size) 222 | 223 | def forward(self, x): 224 | return self.linear(self.fc(x)) 225 | 226 | 227 | class FC(nn.Module): 228 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True): 229 | super(FC, self).__init__() 230 | self.dropout_r = dropout_r 231 | self.use_relu = use_relu 232 | 233 | self.linear = nn.Linear(in_size, out_size) 234 | 235 | if use_relu: 236 | self.relu = nn.ReLU(inplace=True) 237 | 238 | if dropout_r > 0: 239 | self.dropout = nn.Dropout(dropout_r) 240 | 241 | def forward(self, x): 242 | x = self.linear(x) 243 | 244 | if self.use_relu: 245 | x = self.relu(x) 246 | 247 | if self.dropout_r > 0: 248 | x = self.dropout(x) 249 | 250 | return x 251 | 252 | 253 | class AttFlat(nn.Module): 254 | def __init__(self, __C): 255 | super(AttFlat, self).__init__() 256 | self.__C = __C 257 | 258 | self.mlp = MLP( 259 | in_size=__C.HIDDEN_SIZE, 260 | mid_size=__C.FLAT_MLP_SIZE, 261 | out_size=__C.FLAT_GLIMPSES, 262 | dropout_r=__C.DROPOUT_R, 263 | use_relu=True 264 | ) 265 | 266 | self.linear_merge = nn.Linear( 267 | __C.HIDDEN_SIZE * __C.FLAT_GLIMPSES, 268 | __C.FLAT_OUT_SIZE 269 | ) 270 | 271 | def forward(self, x, x_mask): 272 | att = self.mlp(x) 273 | att = att.masked_fill( 274 | x_mask.squeeze(1).squeeze(1).unsqueeze(2), 275 | -1e9 276 | ) 277 | att = F.softmax(att, dim=1) 278 | 279 | att_list = [] 280 | for i in range(self.__C.FLAT_GLIMPSES): 281 | att_list.append( 282 | torch.sum(att[:, :, i: i + 1] * x, dim=1) 283 | ) 284 | 285 | x_atted = torch.cat(att_list, dim=1) 286 | x_atted = self.linear_merge(x_atted) 287 | 288 | return x_atted -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | alabaster==0.7.12 2 | allennlp==0.9.0 3 | attrs==21.2.0 4 | Babel==2.9.1 5 | blis==0.2.4 6 | boto3==1.17.91 7 | botocore==1.20.91 8 | cached-property==1.5.2 9 | certifi==2021.5.30 10 | chardet==4.0.0 11 | click==8.0.1 12 | conllu==1.3.1 13 | cycler==0.10.0 14 | cymem==2.0.5 15 | Cython==0.29.23 16 | dataclasses==0.8 17 | distro==1.5.0 18 | docutils==0.17.1 19 | editdistance==0.5.3 20 | flaky==3.7.0 21 | Flask==2.0.1 22 | Flask-Cors==3.0.10 23 | ftfy==6.0.3 24 | gevent==21.1.2 25 | greenlet==1.1.0 26 | h5py==3.1.0 27 | idna==2.10 28 | imagesize==1.2.0 29 | importlib-metadata==4.5.0 30 | 
iniconfig==1.1.1 31 | itsdangerous==2.0.1 32 | jieba==0.42.1 33 | Jinja2==3.0.1 34 | jmespath==0.10.0 35 | joblib==1.0.1 36 | jsonnet==0.17.0 37 | jsonpickle==2.0.0 38 | kiwisolver==1.3.1 39 | MarkupSafe==2.0.1 40 | matplotlib==3.3.4 41 | murmurhash==1.0.5 42 | nltk==3.6.2 43 | numpy==1.19.5 44 | numpydoc==1.1.0 45 | opencv-python==4.2.0.32 46 | overrides==3.1.0 47 | packaging==20.9 48 | parsimonious==0.8.1 49 | Pillow==8.2.0 50 | plac==0.9.6 51 | pluggy==0.13.1 52 | preshed==2.0.1 53 | protobuf==3.17.3 54 | py==1.10.0 55 | Pygments==2.9.0 56 | pyparsing==2.4.7 57 | pytest==6.2.4 58 | python-dateutil==2.8.1 59 | pytorch-pretrained-bert==0.6.2 60 | pytorch-transformers==1.1.0 61 | pytz==2021.1 62 | regex==2021.4.4 63 | requests==2.25.1 64 | responses==0.13.3 65 | s3transfer==0.4.2 66 | scikit-build==0.11.1 67 | scikit-learn==0.24.2 68 | scipy==1.5.4 69 | sentencepiece==0.1.95 70 | six==1.16.0 71 | snowballstemmer==2.1.0 72 | spacy==2.1.9 73 | Sphinx==4.0.2 74 | sphinxcontrib-applehelp==1.0.2 75 | sphinxcontrib-devhelp==1.0.2 76 | sphinxcontrib-htmlhelp==2.0.0 77 | sphinxcontrib-jsmath==1.0.1 78 | sphinxcontrib-qthelp==1.0.3 79 | sphinxcontrib-serializinghtml==1.1.5 80 | sqlparse==0.4.1 81 | srsly==1.0.5 82 | tensorboardX==2.2 83 | thinc==7.0.8 84 | threadpoolctl==2.1.0 85 | toml==0.10.2 86 | torch==1.2.0 87 | torchvision==0.4.0 88 | tqdm==4.61.0 89 | typing-extensions==3.10.0.0 90 | typing-utils==0.1.0 91 | Unidecode==1.2.0 92 | urllib3==1.26.5 93 | wasabi==0.8.2 94 | wcwidth==0.2.5 95 | Werkzeug==2.0.1 96 | word2number==1.1 97 | zipp==3.4.1 98 | zope.event==4.5.0 99 | zope.interface==5.4.0 100 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 7 | 'wide_resnet50_2', 'wide_resnet101_2'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=dilation, groups=groups, bias=False, dilation=dilation) 14 | 15 | 16 | def conv1x1(in_planes, out_planes, stride=1): 17 | """1x1 convolution""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 19 | 20 | 21 | class BasicBlock(nn.Module): 22 | expansion = 1 23 | __constants__ = ['downsample'] 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 26 | base_width=64, dilation=1, norm_layer=None): 27 | super(BasicBlock, self).__init__() 28 | if norm_layer is None: 29 | norm_layer = nn.BatchNorm2d 30 | if groups != 1 or base_width != 64: 31 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 32 | if dilation > 1: 33 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 34 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = norm_layer(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = norm_layer(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | identity = x 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if 
self.downsample is not None: 52 | identity = self.downsample(x) 53 | 54 | out += identity 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | __constants__ = ['downsample'] 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 65 | base_width=64, dilation=1, norm_layer=None): 66 | super(Bottleneck, self).__init__() 67 | if norm_layer is None: 68 | norm_layer = nn.BatchNorm2d 69 | width = int(planes * (base_width / 64.)) * groups 70 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 71 | self.conv1 = conv1x1(inplanes, width) 72 | self.bn1 = norm_layer(width) 73 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 74 | self.bn2 = norm_layer(width) 75 | self.conv3 = conv1x1(width, planes * self.expansion) 76 | self.bn3 = norm_layer(planes * self.expansion) 77 | self.relu = nn.ReLU(inplace=True) 78 | self.downsample = downsample 79 | self.stride = stride 80 | 81 | def forward(self, x): 82 | identity = x 83 | 84 | out = self.conv1(x) 85 | out = self.bn1(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv2(out) 89 | out = self.bn2(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv3(out) 93 | out = self.bn3(out) 94 | 95 | if self.downsample is not None: 96 | identity = self.downsample(x) 97 | 98 | out += identity 99 | out = self.relu(out) 100 | 101 | return out 102 | 103 | 104 | class ResNet(nn.Module): 105 | 106 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 107 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 108 | norm_layer=None): 109 | super(ResNet, self).__init__() 110 | if norm_layer is None: 111 | norm_layer = nn.BatchNorm2d 112 | self._norm_layer = norm_layer 113 | 114 | self.inplanes = 64 115 | self.dilation = 1 116 | if replace_stride_with_dilation is None: 117 | # each element in the tuple indicates if we should replace 118 | # the 2x2 stride with a dilated convolution instead 119 | replace_stride_with_dilation = [False, False, False] 120 | if len(replace_stride_with_dilation) != 3: 121 | raise ValueError("replace_stride_with_dilation should be None " 122 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 123 | self.groups = groups 124 | self.base_width = width_per_group 125 | self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3, 126 | bias=False) 127 | self.bn1 = norm_layer(self.inplanes) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 130 | self.layer1 = self._make_layer(block, 64, layers[0]) 131 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 132 | dilate=replace_stride_with_dilation[0]) 133 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 134 | dilate=replace_stride_with_dilation[1]) 135 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 136 | dilate=replace_stride_with_dilation[2]) 137 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 138 | self.fc = nn.Linear(512 * block.expansion, num_classes) 139 | 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 143 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 144 | nn.init.constant_(m.weight, 1) 145 | nn.init.constant_(m.bias, 0) 146 | 147 | # Zero-initialize the last BN in each residual branch, 148 | # so that the residual branch starts with zeros, and each residual block behaves like an 
identity. 149 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 150 | if zero_init_residual: 151 | for m in self.modules(): 152 | if isinstance(m, Bottleneck): 153 | nn.init.constant_(m.bn3.weight, 0) 154 | elif isinstance(m, BasicBlock): 155 | nn.init.constant_(m.bn2.weight, 0) 156 | 157 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 158 | norm_layer = self._norm_layer 159 | downsample = None 160 | previous_dilation = self.dilation 161 | if dilate: 162 | self.dilation *= stride 163 | stride = 1 164 | if stride != 1 or self.inplanes != planes * block.expansion: 165 | downsample = nn.Sequential( 166 | conv1x1(self.inplanes, planes * block.expansion, stride), 167 | norm_layer(planes * block.expansion), 168 | ) 169 | 170 | layers = [] 171 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 172 | self.base_width, previous_dilation, norm_layer)) 173 | self.inplanes = planes * block.expansion 174 | for _ in range(1, blocks): 175 | layers.append(block(self.inplanes, planes, groups=self.groups, 176 | base_width=self.base_width, dilation=self.dilation, 177 | norm_layer=norm_layer)) 178 | 179 | return nn.Sequential(*layers) 180 | 181 | def _forward_impl(self, x): 182 | # See note [TorchScript super()] 183 | x = self.conv1(x) 184 | x = self.bn1(x) 185 | x = self.relu(x) 186 | x = self.maxpool(x) 187 | 188 | x = self.layer1(x) 189 | x = self.layer2(x) 190 | x = self.layer3(x) 191 | x = self.layer4(x) 192 | 193 | x = self.avgpool(x) 194 | x = torch.flatten(x, 1) 195 | x = self.fc(x) 196 | 197 | return x 198 | 199 | def forward(self, x): 200 | return self._forward_impl(x) 201 | 202 | 203 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 204 | model = ResNet(block, layers, **kwargs) 205 | return model 206 | 207 | 208 | def resnet18(pretrained=False, progress=True, **kwargs): 209 | r"""ResNet-18 model from 210 | `"Deep Residual Learning for Image Recognition" `_ 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | progress (bool): If True, displays a progress bar of the download to stderr 215 | """ 216 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 217 | **kwargs) 218 | 219 | 220 | def resnet34(pretrained=False, progress=True, **kwargs): 221 | r"""ResNet-34 model from 222 | `"Deep Residual Learning for Image Recognition" `_ 223 | 224 | Args: 225 | pretrained (bool): If True, returns a model pre-trained on ImageNet 226 | progress (bool): If True, displays a progress bar of the download to stderr 227 | """ 228 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 229 | **kwargs) 230 | 231 | 232 | def resnet50(pretrained=False, progress=True, **kwargs): 233 | r"""ResNet-50 model from 234 | `"Deep Residual Learning for Image Recognition" `_ 235 | 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | progress (bool): If True, displays a progress bar of the download to stderr 239 | """ 240 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 241 | **kwargs) 242 | 243 | 244 | def resnet101(pretrained=False, progress=True, **kwargs): 245 | r"""ResNet-101 model from 246 | `"Deep Residual Learning for Image Recognition" `_ 247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | progress (bool): If True, displays a progress bar of the download to stderr 251 | """ 252 | return _resnet('resnet101', Bottleneck, [3, 4, 
23, 3], pretrained, progress, 253 | **kwargs) 254 | 255 | 256 | def resnet152(pretrained=False, progress=True, **kwargs): 257 | r"""ResNet-152 model from 258 | `"Deep Residual Learning for Image Recognition" `_ 259 | 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | progress (bool): If True, displays a progress bar of the download to stderr 263 | """ 264 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 265 | **kwargs) 266 | 267 | 268 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 269 | r"""ResNeXt-50 32x4d model from 270 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 271 | 272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | progress (bool): If True, displays a progress bar of the download to stderr 275 | """ 276 | kwargs['groups'] = 32 277 | kwargs['width_per_group'] = 4 278 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 279 | pretrained, progress, **kwargs) 280 | 281 | 282 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 283 | r"""ResNeXt-101 32x8d model from 284 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 285 | 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | progress (bool): If True, displays a progress bar of the download to stderr 289 | """ 290 | kwargs['groups'] = 32 291 | kwargs['width_per_group'] = 8 292 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 293 | pretrained, progress, **kwargs) 294 | 295 | 296 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 297 | r"""Wide ResNet-50-2 model from 298 | `"Wide Residual Networks" `_ 299 | 300 | The model is the same as ResNet except for the bottleneck number of channels 301 | which is twice larger in every block. The number of channels in outer 1x1 302 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 303 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 304 | 305 | Args: 306 | pretrained (bool): If True, returns a model pre-trained on ImageNet 307 | progress (bool): If True, displays a progress bar of the download to stderr 308 | """ 309 | kwargs['width_per_group'] = 64 * 2 310 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 311 | pretrained, progress, **kwargs) 312 | 313 | 314 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 315 | r"""Wide ResNet-101-2 model from 316 | `"Wide Residual Networks" `_ 317 | 318 | The model is the same as ResNet except for the bottleneck number of channels 319 | which is twice larger in every block. The number of channels in outer 1x1 320 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 321 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
322 | 323 | Args: 324 | pretrained (bool): If True, returns a model pre-trained on ImageNet 325 | progress (bool): If True, displays a progress bar of the download to stderr 326 | """ 327 | kwargs['width_per_group'] = 64 * 2 328 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 329 | pretrained, progress, **kwargs) 330 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from allennlp.data.fields import * 2 | from allennlp.data.instance import Instance 3 | from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer 4 | from allennlp.nn.util import get_text_field_mask 5 | from allennlp.data.tokenizers import Token 6 | from allennlp.models import BasicClassifier, Model 7 | from allennlp.training.metrics.fbeta_measure import FBetaMeasure 8 | from allennlp.data import Vocabulary 9 | from allennlp.models.model import Model 10 | from allennlp.modules import Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder 11 | from allennlp.nn import InitializerApplicator, RegularizerApplicator 12 | from allennlp.training.metrics import F1Measure, Average, Metric 13 | from allennlp.common.params import Params 14 | from allennlp.commands.train import train_model 15 | from allennlp.data import Instance 16 | from allennlp.data.dataset_readers import DatasetReader 17 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 18 | from allennlp.data.tokenizers import Tokenizer 19 | from allennlp.common.util import START_SYMBOL, END_SYMBOL 20 | from allennlp.training.metrics.metric import Metric 21 | from allennlp.nn import util 22 | 23 | from typing import * 24 | from overrides import overrides 25 | import jieba 26 | import numpy as np 27 | import pickle 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | import resnet 32 | import cv2 as cv 33 | import os 34 | 35 | torch.manual_seed(123) 36 | 37 | 38 | def process_image(img, min_side=224): # resize proportionally and pad 39 | size = img.shape 40 | h, w = size[0], size[1] 41 | # scale the longer side to min_side 42 | scale = max(w, h) / float(min_side) 43 | new_w, new_h = int(w/scale), int(h/scale) 44 | resize_img = cv.resize(img, (new_w, new_h)) 45 | # pad to min_side * min_side 46 | # pad on the bottom and right 47 | top, bottom, left, right = 0, min_side-new_h, 0, min_side-new_w 48 | 49 | pad_img = cv.copyMakeBorder(resize_img, int(top), int(bottom), int(left), int(right), 50 | cv.BORDER_CONSTANT, value=[255,255,255]) # number of pixels to pad outward from the image border (top, bottom, left, right) 51 | 52 | return pad_img 53 | 54 | 55 | 56 | 57 | @DatasetReader.register("s2s_manual_reader") 58 | class SeqReader(DatasetReader): 59 | def __init__(self, 60 | tokenizer: Tokenizer = None, 61 | source_token_indexer: Dict[str, TokenIndexer] = None, 62 | target_token_indexer: Dict[str, TokenIndexer] = None, 63 | model_name: str = None) -> None: 64 | super().__init__(lazy=False) 65 | self._tokenizer = tokenizer 66 | self._source_token_indexer = source_token_indexer 67 | self._target_token_indexer = target_token_indexer 68 | self._model_name = model_name 69 | 70 | sub_dict_path = "data/sub_dataset_dict.pk" # problem types 71 | with open(sub_dict_path, 'rb') as file: 72 | subset_dict = pickle.load(file) 73 | self.subset_dict = subset_dict 74 | 75 | self.all_points = ['切线', '垂径定理', '勾股定理', '同位角', '平行线', '三角形内角和', '三角形中位线', '平行四边形', 76 | '相似三角形', '正方形', '圆周角', '直角三角形', '距离', '邻补角', '圆心角', '圆锥的计算', '三角函数', 77 | '矩形', '旋转', '等腰三角形', '外接圆', '内错角', '菱形', '多边形', '对顶角', '三角形的外角', '角平分线', 78 | '对称',
'立体图形', '三视图', '圆内接四边形', '垂直平分线', '垂线', '扇形面积', '等边三角形', '平移', 79 | '含30度角的直角三角形', '仰角', '三角形的外接圆与外心', '方向角', '坡角', '直角三角形斜边上的中线', '位似', 80 | '平行线分线段成比例', '坐标与图形性质', '圆柱的计算', '俯角', '射影定理', '黄金分割', '钟面角'] 81 | 82 | @overrides 83 | def _read(self, file_path: str): 84 | with open(file_path, 'rb') as f: 85 | dataset = pickle.load(f) 86 | for sample in dataset: 87 | yield self.text_to_instance(sample) 88 | 89 | @overrides 90 | def text_to_instance(self, sample) -> Instance: 91 | fields = {} 92 | 93 | image = sample['image'] 94 | image = process_image(image) 95 | image = image/255 96 | img_rgb = np.zeros((3, image.shape[0], image.shape[1])) 97 | for i in range(3): 98 | img_rgb[i, :, :] = image 99 | fields['image'] = ArrayField(img_rgb) 100 | 101 | s_token = self._tokenizer.tokenize(' '.join(sample['token_list'])) 102 | fields['source_tokens'] = TextField(s_token, self._source_token_indexer) 103 | t_token = self._tokenizer.tokenize(' '.join(sample['manual_program'])) 104 | t_token.insert(0, Token(START_SYMBOL)) 105 | t_token.append(Token(END_SYMBOL)) 106 | fields['target_tokens'] = TextField(t_token, self._target_token_indexer) 107 | fields['source_nums'] = MetadataField(sample['numbers']) 108 | fields['choice_nums'] = MetadataField(sample['choice_nums']) 109 | fields['label'] = MetadataField(sample['label']) 110 | 111 | type = self.subset_dict[sample['id']] 112 | fields['type'] = MetadataField(type) 113 | fields['data_id'] = MetadataField(sample['id']) 114 | 115 | equ_list = [] 116 | 117 | equ = sample['manual_program'] 118 | equ_token = self._tokenizer.tokenize(' '.join(equ)) 119 | equ_token.insert(0, Token(START_SYMBOL)) 120 | equ_token.append(Token(END_SYMBOL)) 121 | equ_token = TextField(equ_token, self._source_token_indexer) 122 | equ_list.append(equ_token) 123 | 124 | fields['equ_list'] = ListField(equ_list) 125 | fields['manual_program'] = MetadataField(sample['manual_program']) 126 | 127 | point_label = np.zeros(50, np.float32) 128 | exam_points = sample['formal_point'] 129 | for point in exam_points: 130 | point_id = self.all_points.index(point) 131 | point_label[point_id] = 1 132 | fields['point_label'] = ArrayField(np.array(point_label)) 133 | 134 | return Instance(fields) 135 | 136 | --------------------------------------------------------------------------------
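A quick sanity check for the image preprocessing above: `process_image` scales the longer side of a diagram to `min_side` and pads the bottom/right with white so every image comes out square. A minimal sketch, assuming the repository root is on the Python path and the dependencies in requirements.txt are installed (the synthetic image below is made up):

import numpy as np
from utils import process_image

img = np.full((120, 300), 128, dtype=np.uint8)   # wide grayscale diagram
padded = process_image(img, min_side=224)
print(padded.shape)                              # (224, 224): long side scaled, bottom/right padded white
assert padded.shape[:2] == (224, 224)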