├── LICENSE
├── README.md
├── data_mc
│   └── RACE
│       └── ignore.rtf
├── output_mc
│   └── ignore.rtf
└── src_mc
    ├── cvc-iv.sh
    ├── cvc-mv.sh
    ├── main.py
    ├── model.py
    ├── post_model.py
    ├── post_train.py
    ├── train.sh
    ├── train_eval.py
    └── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 FFishYU
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CVC-QA
2 | 
3 | ## Counterfactual Variable Control for Robust and Interpretable Question Answering
4 | This repository contains the code for the following paper:
5 | * Sicheng Yu, Yulei Niu, Shuohang Wang, Jing Jiang, Qianru Sun. "Counterfactual Variable Control for Robust and Interpretable Question Answering" (https://arxiv.org/abs/2010.05581)
6 | 
7 | 
8 | ## Requirements
9 | * torch 1.3.1
10 | * transformers 2.1.1
11 | * apex 0.1
12 | * tensorboardX 1.8
13 | * prettytable 0.7.2
14 | 
15 | ## Multiple-Choice Question Answering
16 | Here we use RACE with BERT-base as an example for the MCQA task.
17 | 
18 | ### Download data
19 | - Step 1: Download the original dataset via this link (http://www.cs.cmu.edu/~glai1/data/race/) and store it in the directory `/data_mc/RACE`.
20 | - Step 2: Download the adversarial sets via this link (https://drive.google.com/drive/folders/1ufPl0aP-QglVdsDtlKq9kTnt_0Fqmw2i?usp=sharing) and store them in the same directory as Step 1.
21 | 
22 | ### CVC Training
23 | ```sh
24 | cd src_mc
25 | bash train.sh
26 | ```
27 | You can visualize the loss trend with tensorboardX in the directory `/src_mc/runs`.
28 | 
29 | ### CVC-IV inference
30 | Please change `--time_stamp` according to your training time.
31 | ```sh
32 | bash cvc-iv.sh
33 | ```
34 | 
35 | ### CVC-MV inference (including training of the c-adaptor)
36 | Please change `--pre_model_dir` according to the model you selected.
37 | ```sh
38 | bash cvc-mv.sh
39 | ```
40 | ### MCQA models trained by us
41 | You can download the CVC models trained by us (CVC-MV is not included) and use them to reproduce the results reported in our paper: https://drive.google.com/drive/folders/14ZMUwW_bxnpaDX4HbjdUxGwcNBzFIYR6?usp=sharing
42 | 
43 | ## Span-Extraction Question Answering
44 | Coming Soon! 
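## Note: the CVC fusion at a glance
For readers who want the core computation without digging through the sources, below is a minimal sketch of the probability-level fusion implemented in `src_mc/model.py` (training and CVC-IV) and `src_mc/post_model.py` (CVC-MV). The function name `cvc_fusion` and the shape comments are ours, for illustration only:

```python
import torch
import torch.nn.functional as F

def cvc_fusion(zk, zb, c):
    # zk: logits of the content (robust) branch, shape (batch, num_choices)
    # zb: logits of the shortcut/bias branch,    shape (batch, num_choices)
    # c:  per-example debiasing weight, shape (batch, 1); the Jensen-Shannon
    #     divergence between the two distributions in model.py, or the output
    #     of the learned c-adaptor in post_model.py
    pk, pb = F.softmax(zk, dim=-1), F.softmax(zb, dim=-1)
    train_score = torch.log(pk * pb)  # factual fusion used for the training loss
    tie_score = pk * pb - c * pb      # counterfactual (TIE) score used by CVC-MV
    return train_score, tie_score
```

At CVC-IV inference time only the content branch is scored (see `inference_IV` in `src_mc/model.py`), while CVC-MV reports the TIE score above.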
45 | -------------------------------------------------------------------------------- /data_mc/RACE/ignore.rtf: -------------------------------------------------------------------------------- 1 | {\rtf1\ansi\ansicpg936\cocoartf2511 2 | \cocoatextscaling0\cocoaplatform0{\fonttbl\f0\fswiss\fcharset0 Helvetica;} 3 | {\colortbl;\red255\green255\blue255;} 4 | {\*\expandedcolortbl;;} 5 | \paperw11900\paperh16840\margl1440\margr1440\vieww10800\viewh8400\viewkind0 6 | \pard\tx566\tx1133\tx1700\tx2267\tx2834\tx3401\tx3968\tx4535\tx5102\tx5669\tx6236\tx6803\pardirnatural\partightenfactor0 7 | 8 | \f0\fs24 \cf0 Ignore this file.} -------------------------------------------------------------------------------- /output_mc/ignore.rtf: -------------------------------------------------------------------------------- 1 | {\rtf1\ansi\ansicpg936\cocoartf2511 2 | \cocoatextscaling0\cocoaplatform0{\fonttbl\f0\fswiss\fcharset0 Helvetica;} 3 | {\colortbl;\red255\green255\blue255;} 4 | {\*\expandedcolortbl;;} 5 | \paperw11900\paperh16840\margl1440\margr1440\vieww10800\viewh8400\viewkind0 6 | \pard\tx566\tx1133\tx1700\tx2267\tx2834\tx3401\tx3968\tx4535\tx5102\tx5669\tx6236\tx6803\pardirnatural\partightenfactor0 7 | 8 | \f0\fs24 \cf0 Ignore this file.} -------------------------------------------------------------------------------- /src_mc/cvc-iv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES=0 3 | python main.py --model_type bert \ 4 | --model_name_or_path bert-base-uncased \ 5 | --do_lower_case \ 6 | --max_seq_length 384 \ 7 | --per_gpu_eval_batch_size 8 \ 8 | --do_test \ 9 | --eval_all_checkpoints \ 10 | --task_name RACE \ 11 | --time_stamp 03-23-14-25 12 | 13 | -------------------------------------------------------------------------------- /src_mc/cvc-mv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export CUDA_VISIBLE_DEVICES=0,1 3 | python post_train.py --model_type bert \ 4 | --model_name_or_path bert-base-uncased \ 5 | --do_train \ 6 | --do_test \ 7 | --do_lower_case \ 8 | --learning_rate 3e-5 \ 9 | --num_train_epochs 1 \ 10 | --max_seq_length 384 \ 11 | --per_gpu_eval_batch_size=4 \ 12 | --per_gpu_train_batch_size=12 \ 13 | --gradient_accumulation_steps 1 \ 14 | --fp16 \ 15 | --task_name RACE \ 16 | --pre_model_dir 2020-03-23-14-25-checkpoint-118044-star \ 17 | --output_dir ../output_mc 18 | -------------------------------------------------------------------------------- /src_mc/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import random 7 | import glob 8 | import apex 9 | from prettytable import PrettyTable 10 | import numpy as np 11 | import torch 12 | from train_eval import train, set_seed, evaluate 13 | 14 | from transformers import (WEIGHTS_NAME, BertConfig, 15 | BertForMultipleChoice, BertTokenizer, 16 | XLNetConfig, XLNetForMultipleChoice, 17 | XLNetTokenizer, RobertaConfig, 18 | RobertaForMultipleChoice, RobertaTokenizer) 19 | from utils import load_and_cache_examples, MULTIPLE_CHOICE_TASKS_NUM_LABELS, load_and_cache_several_examples 20 | from train_eval import train, evaluate 21 | from model import BertForMultipleChoice_CVC 22 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ()) 23 | 24 | 
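# NOTE: only BERT is wired up with the CVC head below; the XLNet/RoBERTa classes
# above are imported only so that ALL_MODELS can list their shortcut names.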
MODEL_CLASSES = {
25 |     'bert': (BertConfig, BertForMultipleChoice_CVC, BertTokenizer),
26 | }
27 | logging.basicConfig(level=logging.INFO)
28 | logger = logging.getLogger(__name__)
29 | 
30 | 
31 | def main():
32 |     parser = argparse.ArgumentParser()
33 | 
34 |     parser.add_argument("--model_type", default='bert', type=str,
35 |                         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
36 |     parser.add_argument("--model_name_or_path", default='bert-base-uncased', type=str,
37 |                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
38 |     parser.add_argument("--output_dir", default='../output_mc', type=str,
39 |                         help="The output directory where the model checkpoints and predictions will be written.")
40 |     parser.add_argument("--raw_data_dir", default='../data_mc', type=str)
41 |     parser.add_argument("--config_name", default="", type=str,
42 |                         help="Pretrained config name or path if not the same as model_name")
43 |     parser.add_argument("--tokenizer_name", default="", type=str,
44 |                         help="Pretrained tokenizer name or path if not the same as model_name")
45 |     parser.add_argument("--max_seq_length", default=384, type=int,
46 |                         help="The maximum total input sequence length after WordPiece tokenization. Sequences "
47 |                              "longer than this will be truncated, and sequences shorter than this will be padded.")
48 |     parser.add_argument("--task_name", default='DREAM')
49 |     parser.add_argument("--do_train", action='store_true',
50 |                         help="Whether to run training.")
51 |     parser.add_argument("--do_eval", action='store_true',
52 |                         help="Whether to run eval on the dev set.")
53 |     parser.add_argument("--do_test", action='store_true',
54 |                         help='Whether to run test on the test set')
55 |     parser.add_argument("--evaluate_during_training", action='store_true',
56 |                         help="Run evaluation during training at each logging step.")
57 |     parser.add_argument("--do_lower_case", action='store_true',
58 |                         help="Set this flag if you are using an uncased model.")
59 | 
60 |     parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
61 |                         help="Batch size per GPU/CPU for training.")
62 |     parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
63 |                         help="Batch size per GPU/CPU for evaluation.")
64 |     parser.add_argument("--learning_rate", default=3e-5, type=float,
65 |                         help="The initial learning rate for Adam.")
66 |     parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
67 |                         help="Number of update steps to accumulate before performing a backward/update pass.")
68 |     parser.add_argument("--weight_decay", default=0.0, type=float,
69 |                         help="Weight decay if we apply some.")
70 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float,
71 |                         help="Epsilon for Adam optimizer.")
72 |     parser.add_argument("--max_grad_norm", default=1.0, type=float,
73 |                         help="Max gradient norm.")
74 |     parser.add_argument("--num_train_epochs", default=2.0, type=float,
75 |                         help="Total number of training epochs to perform.")
76 |     parser.add_argument("--max_steps", default=-1, type=int,
77 |                         help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
78 |     parser.add_argument("--warmup_proportion", default=0.1, type=float,
79 |                         help="Linear warmup over warmup_proportion * total steps.")
80 |     parser.add_argument("--time_stamp", default='', type=str)
81 |     parser.add_argument("--verbose_logging", action='store_true',
82 |                         help="If true, all of the warnings related to data processing will be printed. "
83 |                              "A number of warnings are expected for a normal evaluation.")
84 |     parser.add_argument("--eval_all_checkpoints", action='store_true',
85 |                         help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number")
86 |     parser.add_argument("--no_cuda", action='store_true',
87 |                         help="Whether not to use CUDA when available")
88 |     parser.add_argument('--overwrite_output_dir', action='store_true',
89 |                         help="Overwrite the content of the output directory")
90 |     parser.add_argument('--overwrite_cache', action='store_true',
91 |                         help="Overwrite the cached training and evaluation sets")
92 |     parser.add_argument('--seed', type=int, default=42,
93 |                         help="random seed for initialization")
94 | 
95 |     parser.add_argument("--local_rank", type=int, default=-1,
96 |                         help="local_rank for distributed training on gpus")
97 |     parser.add_argument('--fp16', action='store_true',
98 |                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
99 |     parser.add_argument('--fp16_opt_level', type=str, default='O1',
100 |                         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
101 |                              "See details at https://nvidia.github.io/apex/amp.html")
102 |     args = parser.parse_args()
103 | 
104 |     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
105 |         logger.info("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overwrite it.".format(args.output_dir))
106 |     # Setup CUDA, GPU & distributed training
107 |     if args.local_rank == -1 or args.no_cuda:
108 |         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
109 |         args.n_gpu = torch.cuda.device_count()
110 |     else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
111 |         torch.cuda.set_device(args.local_rank)
112 |         device = torch.device("cuda", args.local_rank)
113 |         torch.distributed.init_process_group(backend='nccl')
114 |         args.n_gpu = 1
115 |     args.device = device
116 |     # Set seed
117 |     set_seed(args)
118 |     args.model_type = args.model_type.lower()
119 |     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
120 |     config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
121 |     config.output_hidden_states = True
122 |     config.num_options = int(MULTIPLE_CHOICE_TASKS_NUM_LABELS[args.task_name.lower()])
123 |     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
124 |                                                 do_lower_case=args.do_lower_case)
125 |     model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path),
126 |                                         config=config)
127 |     model.MLP.copy_from_bert(model.bert)
128 |     model.to(args.device)
129 |     logger.info("Training/evaluation parameters %s", args)
130 | 
131 |     if args.fp16:
132 |         try:
133 |             import apex
134 |             apex.amp.register_half_function(torch, 'einsum')
135 |         except ImportError:
136 |             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
137 | 
138 |     if args.do_train:
139 |         logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)  # Reduce logging
140 |         train_dataset = load_and_cache_examples(args, task=args.task_name, tokenizer=tokenizer, evaluate=False)
141 |         global_step, tr_loss = train(args, train_dataset, model, tokenizer)
142 |         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
143 | 
144 |     time_stamp = args.time_stamp
145 |     # Evaluation
146 |     # We do not use the dev set
147 |     if args.do_eval and args.local_rank in [-1, 0]:
148 |         checkpoints = [args.output_dir]
149 |         if args.eval_all_checkpoints:
150 |             checkpoints = list(
151 |                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
152 |             logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
153 |             logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)  # Reduce logging
154 |             logging.getLogger("transformers.configuration_utils").setLevel(logging.WARN)  # Reduce logging
155 |         checkpoints = [checkpoint for checkpoint in checkpoints if time_stamp in checkpoint]
156 |         logger.info("Evaluate the following checkpoints for validation: %s", checkpoints)
157 |         best_ckpt = 0
158 |         best_acc = 0
159 |         for checkpoint in checkpoints:
160 |             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
161 |             prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
162 |             logger.info("Load the model: %s", checkpoint)
163 | 
164 |             model = model_class.from_pretrained(checkpoint)
165 |             model.to(args.device)
166 |             result = evaluate(args, [args.task_name], model, tokenizer, prefix=prefix)  # evaluate() expects a list of task names
167 |             if result[0]['eval_acc'] > best_acc:
168 |                 best_ckpt = checkpoint
169 |                 best_acc = result[0]['eval_acc']
170 |     if args.do_test and args.local_rank in [-1, 0]:
171 |         try:  # reuse the best dev checkpoint if --do_eval was run
172 |             checkpoints = [best_ckpt]
173 |         except NameError:  # best_ckpt is undefined when --do_eval was skipped
174 |             checkpoints = list(
175 |                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
176 |             checkpoints = [checkpoint for checkpoint in checkpoints if time_stamp in checkpoint]
177 |         logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
178 |         logging.getLogger("transformers.configuration_utils").setLevel(logging.WARN)  # Reduce logging
179 |         logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)  # Reduce logging
180 | 
181 |         logger.info("Evaluate the following checkpoints for final testing: %s", checkpoints)
182 |         for checkpoint in checkpoints:
183 |             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
184 |             prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
185 |             logger.info("Load the model: %s", checkpoint)
186 |             model = model_class.from_pretrained(checkpoint)
187 |             model.to(args.device)
188 |             task_string = ['', '-Add1OtherTruth2Opt', '-Add2OtherTruth2Opt', '-Add1PasSent2Opt', '-Add1NER2Pass']
189 |             task_string = [args.task_name + item for item in task_string]
190 |             result = evaluate(args, task_string, model, tokenizer, prefix=prefix, test=True)
191 | 
192 | if __name__ == "__main__":
193 |     main()
194 | 
--------------------------------------------------------------------------------
/src_mc/model.py:
--------------------------------------------------------------------------------
1 | from transformers.modeling_bert import BertPreTrainedModel, BertForMultipleChoice, BertModel, BertLayer, BertConfig, gelu, BertPooler
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from torch.nn import CrossEntropyLoss, MSELoss
6 | import copy
7 | 
8 | class Fusion(nn.Module):
9 |     def __init__(self,
10 |                  fusion_name, config):
11 |         super(Fusion, self).__init__()
12 |         FUSION_METHOD = {
13 |             'prob_rubi': self.prob_rubi_fusion
14 |         }
15 |         self.fusion_method = FUSION_METHOD[fusion_name]
16 | 
17 |     def prob_rubi_fusion(self, zk, zb, 
hidden=None, cf=False): 18 | if not cf: 19 | prob_zk = F.softmax(zk, dim=-1) 20 | prob_zb = F.softmax(zb, dim=-1) 21 | fusion_prob = prob_zk * prob_zb 22 | log_fusion_prob = torch.log(fusion_prob) 23 | return log_fusion_prob 24 | else: 25 | prob_zk = F.softmax(zk, dim=-1) 26 | prob_zb = F.softmax(zb, dim=-1) 27 | num_choice = prob_zk.size()[-1] 28 | similarity = Jensen_Shannon_Div(prob_zk, prob_zb).unsqueeze(-1) 29 | # c = torch.mean(prob_zk, dim=-1).unsqueeze(-1).repeat(1, num_choice) 30 | c = similarity 31 | log_fusion_prob = prob_zk * prob_zb - c * prob_zb 32 | return log_fusion_prob, c 33 | 34 | def forward(self, zk, zb, hidden=None, cf=False): 35 | return self.fusion_method(zk, zb, hidden, cf) 36 | 37 | class BertForMultipleChoice_CVC(BertPreTrainedModel): 38 | def __init__(self, config): 39 | super(BertForMultipleChoice_CVC, self).__init__(config) 40 | 41 | self.bert = BertModel(config) 42 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 43 | self.classifier = nn.Linear(config.hidden_size, 1) 44 | self.classifier_ = nn.Linear(config.hidden_size, 1) 45 | if config.num_hidden_layers == 12: 46 | self.num_shared_layers = 10 47 | elif config.num_hidden_layers == 24: 48 | self.num_shared_layers = 20 49 | self.fusion_name = 'prob_rubi' 50 | self.Fusion = Fusion(self.fusion_name, config) 51 | self.MLP = transformer_block(config, num_shared_layers=self.num_shared_layers, num_layers=config.num_hidden_layers-self.num_shared_layers) 52 | self.init_weights() 53 | 54 | def forward(self, input_ids, attention_mask=None, token_type_ids=None, 55 | input_ids_np=None, attention_mask_np=None, token_type_ids_np=None, 56 | position_ids=None, head_mask=None, labels=None): 57 | num_choices = input_ids.shape[1] 58 | 59 | input_ids = input_ids.view(-1, input_ids.size(-1)) 60 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 61 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 62 | position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 63 | 64 | outputs = self.bert(input_ids, 65 | attention_mask=attention_mask, 66 | token_type_ids=token_type_ids, 67 | position_ids=position_ids, 68 | head_mask=head_mask) 69 | 70 | pooled_output = outputs[1] 71 | 72 | pooled_output = self.dropout(pooled_output) 73 | logits = self.classifier(pooled_output).view(-1, num_choices) 74 | 75 | outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here 76 | 77 | if labels is not None: 78 | input_ids_np = input_ids_np.view(-1, input_ids_np.size( 79 | -1)) if input_ids_np is not None else None 80 | attention_mask_np = attention_mask_np.view(-1, attention_mask_np.size( 81 | -1)) if attention_mask_np is not None else None 82 | token_type_ids_np = token_type_ids_np.view(-1, token_type_ids_np.size( 83 | -1)) if token_type_ids_np is not None else None 84 | 85 | outputs_np = self.bert(input_ids_np, 86 | attention_mask=attention_mask_np, 87 | token_type_ids=token_type_ids_np, 88 | position_ids=position_ids, 89 | head_mask=head_mask) 90 | 91 | # Use Transformer block on top of bias branch 92 | outputs_np = outputs_np[2][self.num_shared_layers - 1] 93 | outputs_np = grad_mul_const(outputs_np, 0.0) 94 | logits_np = self.MLP(outputs_np, attention_mask_np) 95 | logits_np = self.classifier_(logits_np).view(-1, num_choices) 96 | 97 | logits_np_ = logits_np*1 98 | fusion_logits = self.Fusion(logits, logits_np_.detach()) 99 | 100 | loss_fct = CrossEntropyLoss() 101 | 
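            # CVC training objective: fusion_loss trains the content branch
            # through the fused (content x detached-bias) scores, while np_loss
            # trains the bias branch alone; grad_mul_const(..., 0.0) above keeps
            # the bias branch from updating the shared BERT layers.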
fusion_loss = loss_fct(fusion_logits, labels) 102 | np_loss = loss_fct(logits_np, labels) 103 | 104 | 105 | outputs = (fusion_loss, np_loss) + outputs 106 | return outputs # (loss), reshaped_logits, (hidden_states), (attentions) 107 | 108 | def inference_IV(self, input_ids, attention_mask=None, token_type_ids=None, 109 | input_ids_np=None, attention_mask_np=None, token_type_ids_np=None, 110 | position_ids=None, head_mask=None, labels=None): 111 | num_choices = input_ids.shape[1] 112 | 113 | input_ids = input_ids.view(-1, input_ids.size(-1)) 114 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 115 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 116 | position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 117 | 118 | outputs = self.bert(input_ids, 119 | attention_mask=attention_mask, 120 | token_type_ids=token_type_ids, 121 | position_ids=position_ids, 122 | head_mask=head_mask) 123 | pooled_output = outputs[1] 124 | logits = self.classifier(pooled_output).view(-1, num_choices) 125 | return logits 126 | 127 | class transformer_block(nn.Module): 128 | def __init__(self, config, num_shared_layers, num_layers): 129 | super(transformer_block, self).__init__() 130 | self.num_layers = num_layers 131 | self.num_shared_layers = num_shared_layers 132 | self.bert_layers = nn.ModuleList([BertLayer(config) for _ in range(num_layers)]) 133 | self.pooler = BertPooler(config) 134 | 135 | def copy_from_bert(self, bert): 136 | for i, layer in enumerate(self.bert_layers): 137 | self.bert_layers[i] = copy.deepcopy(bert.encoder.layer[i+self.num_shared_layers]) 138 | self.pooler = copy.deepcopy(bert.pooler) 139 | 140 | def forward(self, x, attention_mask): 141 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 142 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 143 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 144 | for bert_layer in self.bert_layers: 145 | x = bert_layer(x, extended_attention_mask) 146 | x = x[0] # only the hidden part, bert layer will also output attention 147 | x = self.pooler(x) 148 | return x 149 | 150 | class GradMulConst(torch.autograd.Function): 151 | """ 152 | This layer is used to create an adversarial loss. 
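    Forward is the identity; backward multiplies the incoming gradient by `const`.
    With const=0.0 (as used in BertForMultipleChoice_CVC.forward) it blocks the
    bias branch's gradients from reaching the shared BERT layers.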
153 | """ 154 | @staticmethod 155 | def forward(ctx, x, const): 156 | ctx.const = const 157 | return x.view_as(x) 158 | 159 | @staticmethod 160 | def backward(ctx, grad_output): 161 | return grad_output * ctx.const, None 162 | 163 | def grad_mul_const(x, const): 164 | return GradMulConst.apply(x, const) 165 | 166 | def logits_norm(x): 167 | return F.softmax(x, dim=-1) 168 | 169 | def Jensen_Shannon_Div(p1, p2): 170 | batch_size = p1.size()[0] 171 | result = [] 172 | for i in range(batch_size): 173 | p1_, p2_ = p1[i], p2[i] 174 | JS_div = 0.5*KL_Div(p1_, (p1_+p2_)/2) + 0.5*KL_Div(p2_, (p1_+p2_)/2) 175 | result.append(JS_div.unsqueeze(0)) 176 | output = torch.cat(result, dim=0) 177 | return output 178 | 179 | 180 | def KL_Div(P, Q): 181 | output = (P * (P / Q).log()).sum() 182 | return output -------------------------------------------------------------------------------- /src_mc/post_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from torch.nn import CrossEntropyLoss 6 | from transformers.modeling_bert import BertPreTrainedModel 7 | from model import BertForMultipleChoice_CVC, Jensen_Shannon_Div 8 | from transformers import PreTrainedModel 9 | WEIGHTS_NAME = "pytorch_model.bin" 10 | 11 | class Post_MV(BertPreTrainedModel): 12 | def __init__(self, args, config): 13 | super(Post_MV, self).__init__(config) 14 | self.args = args 15 | self.config = config 16 | self.pre_model = BertForMultipleChoice_CVC.from_pretrained(args.checkpoint) 17 | self.config_class = self.pre_model.config_class 18 | 19 | for name, p in self.pre_model.named_parameters(): 20 | p.requires_grad = False 21 | self.TIE = MV(config) 22 | 23 | def forward(self, input_ids, attention_mask=None, token_type_ids=None, 24 | input_ids_np=None, attention_mask_np=None, token_type_ids_np=None, 25 | position_ids=None, head_mask=None, labels=None): 26 | num_choices = input_ids.shape[1] 27 | 28 | input_ids = input_ids.view(-1, input_ids.size(-1)) 29 | attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None 30 | token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None 31 | position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None 32 | outputs = self.pre_model.bert(input_ids, 33 | attention_mask=attention_mask, 34 | token_type_ids=token_type_ids, 35 | position_ids=position_ids, 36 | head_mask=head_mask) 37 | pooled_output = outputs[1] 38 | logits = self.pre_model.classifier(pooled_output).view(-1, num_choices) 39 | 40 | input_ids_np = input_ids_np.view(-1, input_ids_np.size( 41 | -1)) if input_ids_np is not None else None 42 | attention_mask_np = attention_mask_np.view(-1, attention_mask_np.size( 43 | -1)) if attention_mask_np is not None else None 44 | token_type_ids_np = token_type_ids_np.view(-1, token_type_ids_np.size( 45 | -1)) if token_type_ids_np is not None else None 46 | outputs_np = self.pre_model.bert(input_ids_np, 47 | attention_mask=attention_mask_np, 48 | token_type_ids=token_type_ids_np, 49 | position_ids=position_ids, 50 | head_mask=head_mask) 51 | outputs_np_ = outputs_np[2][self.pre_model.num_shared_layers - 1] 52 | logits_np = self.pre_model.MLP(outputs_np_, attention_mask_np) 53 | logits_np = self.pre_model.classifier_(logits_np).view(-1, num_choices) 54 | 55 | TIE_logits, item_1, item_2 = self.TIE(logits, logits_np) 56 | loss_fct = 
CrossEntropyLoss() 57 | TIE_loss = loss_fct(TIE_logits, labels) 58 | return TIE_logits, TIE_loss, item_1, item_2 59 | 60 | class MV(nn.Module): 61 | def __init__(self, config): 62 | super(MV, self).__init__() 63 | self.config = config 64 | self.num_options = config.num_options 65 | self.compute_c_1 = nn.Linear(self.num_options*2+1, 100) 66 | self.compute_c_2 = nn.Linear(100, 1) 67 | self.tanh = nn.Tanh() 68 | 69 | def compute_c(self, x): 70 | x = self.compute_c_1(x) 71 | x = self.tanh(x) 72 | x = self.compute_c_2(x) 73 | x = F.sigmoid(x) 74 | return x 75 | 76 | def forward(self, logit, logit_np): 77 | prob_zk = F.softmax(logit, dim=-1) 78 | prob_zb = F.softmax(logit_np, dim=-1) 79 | num_choice = prob_zk.size()[-1] 80 | js = Jensen_Shannon_Div(prob_zk, prob_zb).unsqueeze(-1) 81 | c = torch.cat([prob_zk, prob_zb, js], dim=-1) 82 | # c = torch.cat([logit, logit_np, logit-logit_np, js], dim=-1) 83 | c = self.compute_c(c) 84 | TIE_logits = prob_zk * prob_zb - c * prob_zb 85 | return TIE_logits, prob_zk * prob_zb, c * prob_zb -------------------------------------------------------------------------------- /src_mc/post_train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import logging 5 | import os 6 | import random 7 | import glob 8 | import apex 9 | from prettytable import PrettyTable 10 | import numpy as np 11 | import torch 12 | from train_eval import train, set_seed, evaluate 13 | 14 | from transformers import (WEIGHTS_NAME, BertConfig, BertTokenizer 15 | ) 16 | from utils import load_and_cache_examples, MULTIPLE_CHOICE_TASKS_NUM_LABELS, load_and_cache_several_examples 17 | from train_eval import train, evaluate 18 | from post_model import Post_MV 19 | from model import BertForMultipleChoice_CVC 20 | from tqdm import tqdm, trange 21 | from datetime import datetime, timezone, timedelta 22 | from transformers import AdamW, WarmupLinearSchedule 23 | from transformers import (WEIGHTS_NAME, BertConfig, 24 | BertForMultipleChoice, BertTokenizer, 25 | XLNetConfig, XLNetForMultipleChoice, 26 | XLNetTokenizer, RobertaConfig, 27 | RobertaForMultipleChoice, RobertaTokenizer) 28 | from tensorboardX import SummaryWriter 29 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 30 | TensorDataset) 31 | 32 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ()) 33 | 34 | MODEL_CLASSES = { 35 | 'bert': (BertConfig, BertForMultipleChoice_CVC, BertTokenizer), 36 | } 37 | logging.basicConfig(level=logging.INFO) 38 | logger = logging.getLogger(__name__) 39 | 40 | def simple_accuracy(preds, labels): 41 | return (preds == labels).mean() 42 | 43 | def train_process(args, train_dataset, model, tokenizer): 44 | """ Train the model """ 45 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 46 | train_sampler = RandomSampler(train_dataset) 47 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 48 | 49 | if args.max_steps > 0: 50 | t_total = args.max_steps 51 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 52 | else: 53 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 54 | 55 | # Prepare optimizer and schedule (linear warmup and decay) 56 | no_decay = ['bias', 'LayerNorm.weight'] 57 | optimizer_grouped_parameters 
= [
58 |         {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
59 |         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
60 |     ]
61 |     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
62 |     scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_proportion * t_total, t_total=t_total)
63 |     if args.fp16:
64 |         try:
65 |             from apex import amp
66 |         except ImportError:
67 |             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
68 |         model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
69 | 
70 |     # multi-gpu training (should be after apex fp16 initialization)
71 |     if args.n_gpu > 1:
72 |         model = torch.nn.DataParallel(model)
73 | 
74 |     # Distributed training (should be after apex fp16 initialization)
75 |     if args.local_rank != -1:
76 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
77 |                                                           output_device=args.local_rank,
78 |                                                           find_unused_parameters=True)
79 | 
80 |     # Train!
81 |     logger.info("***** Running training *****")
82 |     logger.info("  Num examples = %d", len(train_dataset))
83 |     logger.info("  Num Epochs = %d", args.num_train_epochs)
84 |     logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
85 |     logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
86 |                 args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
87 |     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
88 |     logger.info("  Total optimization steps = %d", t_total)
89 | 
90 |     global_step = 0
91 |     tr_loss, logging_loss = 0.0, 0.0
92 |     model.zero_grad()
93 |     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
94 |     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
95 |     for _ in train_iterator:
96 |         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
97 |         for step, batch in enumerate(epoch_iterator):
98 |             model.train()
99 |             batch = tuple(t.to(args.device) for t in batch)
100 |             inputs = {'input_ids': batch[0],
101 |                       'attention_mask': batch[1],
102 |                       'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
103 |                       'input_ids_np': batch[3],
104 |                       'attention_mask_np': batch[4],
105 |                       'token_type_ids_np': batch[5],
106 |                       'labels': batch[6]}
107 |             outputs = model(**inputs)
108 |             loss = outputs[1]  # TIE_loss returned by Post_MV.forward
109 |             if args.n_gpu > 1:
110 |                 loss = loss.mean()  # mean() to average on multi-gpu parallel training
111 |             if args.gradient_accumulation_steps > 1:
112 |                 loss = loss / args.gradient_accumulation_steps
113 | 
114 |             if args.fp16:
115 |                 with amp.scale_loss(loss, optimizer) as scaled_loss:
116 |                     scaled_loss.backward()
117 |                 torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
118 |             else:
119 |                 loss.backward()
120 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
121 | 
122 |             tr_loss += loss.item()
123 |             if (step + 1) % args.gradient_accumulation_steps == 0:
124 |                 optimizer.step()
125 |                 scheduler.step()  # Update learning rate schedule
126 |                 model.zero_grad()
127 |                 global_step += 1
128 |     output_dir = os.path.join(args.output_dir, args.pre_model_dir+'-TIE')
129 |     if not os.path.exists(output_dir):
130 |         
os.makedirs(output_dir) 131 | model_to_save = model.module if hasattr(model, 132 | 'module') else model # Take care of distributed/parallel training 133 | model_to_save.save_pretrained(output_dir) 134 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 135 | logger.info("Saving model checkpoint to %s", output_dir) 136 | return global_step, loss.detach().cpu().numpy() 137 | 138 | 139 | def evaluate(args, eval_task_names, model, tokenizer, test=False): 140 | results = [] 141 | table = PrettyTable() 142 | table.add_column(' ', ['Accuracy']) 143 | for eval_task in eval_task_names: 144 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=not test, test=test) 145 | 146 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 147 | # Note that DistributedSampler samples randomly 148 | eval_sampler = SequentialSampler(eval_dataset) 149 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 150 | nb_eval_steps = 0 151 | preds = None 152 | out_label_ids = None 153 | 154 | for batch in eval_dataloader: 155 | model.eval() 156 | batch = tuple(t.to(args.device) for t in batch) 157 | 158 | with torch.no_grad(): 159 | inputs = {'input_ids': batch[0], 160 | 'attention_mask': batch[1], 161 | 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, 162 | # XLM don't use segment_ids 163 | 'input_ids_np': batch[3], 164 | 'attention_mask_np': batch[4], 165 | 'token_type_ids_np': batch[5], 166 | 'labels': batch[6] 167 | } 168 | output = model(**inputs) 169 | logits = output[0] 170 | nb_eval_steps += 1 171 | if preds is None: 172 | preds = logits.detach().cpu().numpy() 173 | out_label_ids = inputs['labels'].detach().cpu().numpy() 174 | else: 175 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 176 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 177 | 178 | 179 | preds = np.argmax(preds, axis=1) 180 | acc = simple_accuracy(preds, out_label_ids) 181 | result = {"task_name": eval_task, "eval_acc": acc} 182 | results.append(result) 183 | for key in sorted(result.keys()): 184 | logger.info(" %s = %s", key, str(result[key])) 185 | table.add_column(eval_task, [round(acc * 100, 2)]) 186 | print(table) 187 | return results 188 | 189 | def main(): 190 | parser = argparse.ArgumentParser() 191 | 192 | parser.add_argument("--model_type", default='bert', type=str, 193 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 194 | parser.add_argument("--model_name_or_path", default='bert-base-uncased', type=str, 195 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) 196 | parser.add_argument("--output_dir", default='../output_mc', type=str, 197 | help="The output directory where the model checkpoints and predictions will be written.") 198 | parser.add_argument("--raw_data_dir", default='../data_mc', type=str) 199 | parser.add_argument("--config_name", default="", type=str, 200 | help="Pretrained config name or path if not the same as model_name") 201 | parser.add_argument("--tokenizer_name", default="", type=str, 202 | help="Pretrained tokenizer name or path if not the same as model_name") 203 | parser.add_argument("--max_seq_length", default=384, type=int, 204 | help="The maximum total input sequence length after WordPiece tokenization. 
Sequences "
205 |                              "longer than this will be truncated, and sequences shorter than this will be padded.")
206 |     parser.add_argument("--task_name", default='DREAM')
207 |     parser.add_argument("--pre_model_dir", default='2020-03-12-10-58-checkpoint-3048')
208 |     parser.add_argument("--do_train", action='store_true',
209 |                         help="Whether to run training.")
210 |     parser.add_argument("--do_eval", action='store_true',
211 |                         help="Whether to run eval on the dev set.")
212 |     parser.add_argument("--do_test", action='store_true',
213 |                         help='Whether to run test on the test set')
214 |     parser.add_argument("--evaluate_during_training", action='store_true',
215 |                         help="Run evaluation during training at each logging step.")
216 |     parser.add_argument("--do_lower_case", action='store_true',
217 |                         help="Set this flag if you are using an uncased model.")
218 | 
219 |     parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
220 |                         help="Batch size per GPU/CPU for training.")
221 |     parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
222 |                         help="Batch size per GPU/CPU for evaluation.")
223 |     parser.add_argument("--learning_rate", default=3e-5, type=float,
224 |                         help="The initial learning rate for Adam.")
225 |     parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
226 |                         help="Number of update steps to accumulate before performing a backward/update pass.")
227 |     parser.add_argument("--weight_decay", default=0.0, type=float,
228 |                         help="Weight decay if we apply some.")
229 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float,
230 |                         help="Epsilon for Adam optimizer.")
231 |     parser.add_argument("--max_grad_norm", default=1.0, type=float,
232 |                         help="Max gradient norm.")
233 |     parser.add_argument("--num_train_epochs", default=2.0, type=float,
234 |                         help="Total number of training epochs to perform.")
235 |     parser.add_argument("--max_steps", default=-1, type=int,
236 |                         help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
237 |     parser.add_argument("--warmup_proportion", default=0.1, type=float,
238 |                         help="Linear warmup over warmup_proportion * total steps.")
239 |     parser.add_argument("--verbose_logging", action='store_true',
240 |                         help="If true, all of the warnings related to data processing will be printed. "
241 |                              "A number of warnings are expected for a normal evaluation.")
242 |     parser.add_argument("--eval_all_checkpoints", action='store_true',
243 |                         help="Evaluate all checkpoints starting with the same prefix as model_name and ending with the step number")
244 |     parser.add_argument("--no_cuda", action='store_true',
245 |                         help="Whether not to use CUDA when available")
246 |     parser.add_argument('--overwrite_output_dir', action='store_true',
247 |                         help="Overwrite the content of the output directory")
248 |     parser.add_argument('--overwrite_cache', action='store_true',
249 |                         help="Overwrite the cached training and evaluation sets")
250 |     parser.add_argument('--seed', type=int, default=42,
251 |                         help="random seed for initialization")
252 | 
253 |     parser.add_argument("--local_rank", type=int, default=-1,
254 |                         help="local_rank for distributed training on gpus")
255 |     parser.add_argument('--fp16', action='store_true',
256 |                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
257 |     parser.add_argument('--fp16_opt_level', type=str, default='O1',
258 |                         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. "
259 |                              "See details at https://nvidia.github.io/apex/amp.html")
260 |     args = parser.parse_args()
261 | 
262 |     args.checkpoint = os.path.join(args.output_dir, args.pre_model_dir)
263 |     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
264 |         logger.info("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overwrite it.".format(args.output_dir))
265 |     # Setup CUDA, GPU & distributed training
266 |     if args.local_rank == -1 or args.no_cuda:
267 |         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
268 |         args.n_gpu = torch.cuda.device_count()
269 |     else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
270 |         torch.cuda.set_device(args.local_rank)
271 |         device = torch.device("cuda", args.local_rank)
272 |         torch.distributed.init_process_group(backend='nccl')
273 |         args.n_gpu = 1
274 |     args.device = device
275 |     # Set seed
276 |     set_seed(args)
277 |     args.model_type = args.model_type.lower()
278 |     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
279 |     config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
280 |     config.output_hidden_states = True
281 |     config.num_options = int(MULTIPLE_CHOICE_TASKS_NUM_LABELS[args.task_name.lower()])
282 |     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
283 |                                                 do_lower_case=args.do_lower_case)
284 |     post_model = Post_MV(args, config)
285 |     post_model.to(args.device)
286 |     logger.info("Training/evaluation parameters %s", args)
287 | 
288 |     if args.fp16:
289 |         try:
290 |             import apex
291 |             apex.amp.register_half_function(torch, 'einsum')
292 |         except ImportError:
293 |             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
294 | 
295 |     if args.do_train:
296 |         logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)  # Reduce logging
297 |         train_dataset = load_and_cache_examples(args, task=args.task_name, tokenizer=tokenizer, evaluate=False)
298 |         global_step, tr_loss = train_process(args, train_dataset, post_model, tokenizer)
299 |         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
300 | 
301 |     if args.do_test:
302 |         logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
303 |         logging.getLogger("transformers.configuration_utils").setLevel(logging.WARN)  # Reduce logging
304 |         logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)  # Reduce logging
305 |         checkpoint = os.path.join(args.output_dir, args.pre_model_dir+'-TIE')
306 |         logger.info(" Load model from %s", checkpoint)
307 |         post_model.load_state_dict(torch.load(os.path.join(checkpoint, 'pytorch_model.bin'), map_location=args.device))  # map_location so GPU-saved weights load on any device
308 |         post_model.to(args.device)
309 |         task_string = ['', '-Add1OtherTruth2Opt', '-Add2OtherTruth2Opt', '-Add1PasSent2Opt', '-Add1NER2Pass']
310 |         task_string = [args.task_name+item for item in task_string]
311 |         result = evaluate(args, task_string, post_model, tokenizer, test=True)
312 | 
313 | if __name__ == "__main__":
314 |     main()
315 | 
316 | 
317 | 
--------------------------------------------------------------------------------
/src_mc/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # For RACE
3 | export CUDA_VISIBLE_DEVICES=0,1
4 | python main.py --model_type bert \
5 |     
--model_name_or_path bert-base-uncased \ 6 | --do_train \ 7 | --do_lower_case \ 8 | --learning_rate 3e-5 \ 9 | --num_train_epochs 36 \ 10 | --max_seq_length 384 \ 11 | --per_gpu_eval_batch_size=8 \ 12 | --per_gpu_train_batch_size=6 \ 13 | --gradient_accumulation_steps 2 \ 14 | --fp16 \ 15 | --task_name RACE \ 16 | --seed 122 -------------------------------------------------------------------------------- /src_mc/train_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 3 | TensorDataset) 4 | from tensorboardX import SummaryWriter 5 | from datetime import datetime, timezone, timedelta 6 | from transformers import AdamW, WarmupLinearSchedule 7 | import torch 8 | from prettytable import PrettyTable 9 | import logging 10 | import os 11 | import numpy as np 12 | import random 13 | from tqdm import tqdm, trange 14 | from utils import load_and_cache_examples 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | def set_seed(args): 19 | random.seed(args.seed) 20 | np.random.seed(args.seed) 21 | torch.manual_seed(args.seed) 22 | if args.n_gpu > 0: 23 | torch.cuda.manual_seed_all(args.seed) 24 | 25 | def simple_accuracy(preds, labels): 26 | return (preds == labels).mean() 27 | 28 | def train(args, train_dataset, model, tokenizer): 29 | """ Train the model """ 30 | exec_time = datetime.utcnow().astimezone(timezone(timedelta(hours=8))) \ 31 | .strftime("%Y-%m-%d-%H-%M") 32 | if args.local_rank in [-1, 0]: 33 | tb_writer = SummaryWriter(log_dir=os.path.join('runs', exec_time)) 34 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 35 | train_sampler = RandomSampler(train_dataset) 36 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 37 | 38 | if args.max_steps > 0: 39 | t_total = args.max_steps 40 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 41 | else: 42 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 43 | 44 | # Prepare optimizer and schedule (linear warmup and decay) 45 | no_decay = ['bias', 'LayerNorm.weight'] 46 | optimizer_grouped_parameters = [ 47 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 48 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 49 | ] 50 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 51 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_proportion * t_total, t_total=t_total) 52 | if args.fp16: 53 | try: 54 | from apex import amp 55 | except ImportError: 56 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 57 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 58 | 59 | # multi-gpu training (should be after apex fp16 initialization) 60 | if args.n_gpu > 1: 61 | model = torch.nn.DataParallel(model) 62 | 63 | # Distributed training (should be after apex fp16 initialization) 64 | if args.local_rank != -1: 65 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 66 | output_device=args.local_rank, 67 | find_unused_parameters=True) 68 | 69 | # 
Train!
70 |     logger.info("***** Running training *****")
71 |     logger.info("  Num examples = %d", len(train_dataset))
72 |     logger.info("  Num Epochs = %d", args.num_train_epochs)
73 |     logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
74 |     logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
75 |                 args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
76 |     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
77 |     logger.info("  Total optimization steps = %d", t_total)
78 | 
79 |     global_step = 0
80 |     tr_loss, logging_loss = 0.0, 0.0
81 |     model.zero_grad()
82 |     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
83 |     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
84 |     for _ in train_iterator:
85 |         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
86 |         for step, batch in enumerate(epoch_iterator):
87 |             model.train()
88 |             batch = tuple(t.to(args.device) for t in batch)
89 |             inputs = {'input_ids': batch[0],
90 |                       'attention_mask': batch[1],
91 |                       'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
92 |                       'input_ids_np': batch[3],
93 |                       'attention_mask_np': batch[4],
94 |                       'token_type_ids_np': batch[5],
95 |                       'labels': batch[6]}
96 |             outputs = model(**inputs)
97 |             fusion_loss, np_loss = outputs[0], outputs[1]  # model outputs are always tuple in transformers (see doc)
98 |             loss = fusion_loss + np_loss
99 |             if args.n_gpu > 1:
100 |                 loss = loss.mean()  # mean() to average on multi-gpu parallel training
101 |             if args.gradient_accumulation_steps > 1:
102 |                 loss = loss / args.gradient_accumulation_steps
103 | 
104 |             if args.fp16:
105 |                 with amp.scale_loss(loss, optimizer) as scaled_loss:
106 |                     scaled_loss.backward()
107 |                 torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
108 |             else:
109 |                 loss.backward()
110 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
111 | 
112 |             tr_loss += loss.item()
113 |             if (step + 1) % args.gradient_accumulation_steps == 0:
114 | 
115 |                 optimizer.step()
116 |                 scheduler.step()  # Update learning rate schedule
117 |                 model.zero_grad()
118 |                 global_step += 1
119 | 
120 |                 tb_writer.add_scalar('train/lr', scheduler.get_lr()[0], global_step)
121 |                 tb_writer.add_scalar('train/loss', loss, global_step)
122 |                 tb_writer.add_scalar('train/fusion_loss', fusion_loss.mean(), global_step)
123 |                 tb_writer.add_scalar('train/np_loss', np_loss.mean(), global_step)
124 | 
125 |             if args.max_steps > 0 and global_step > args.max_steps:
126 |                 epoch_iterator.close()
127 |                 break
128 | 
129 |         if args.local_rank in [-1, 0]:
130 |             # Save model checkpoint
131 |             output_dir = os.path.join(args.output_dir, '{}-checkpoint-{}'.format(exec_time, global_step))
132 |             if not os.path.exists(output_dir):
133 |                 os.makedirs(output_dir)
134 |             model_to_save = model.module if hasattr(model,
135 |                                                     'module') else model  # Take care of distributed/parallel training
136 |             model_to_save.save_pretrained(output_dir)
137 |             torch.save(args, os.path.join(output_dir, 'training_args.bin'))
138 |             logger.info("Saving model checkpoint to %s", output_dir)
139 | 
140 | 
141 |     if args.local_rank in [-1, 0]:
142 |         tb_writer.close()
143 |     return global_step, tr_loss / global_step
144 | 
145 | def evaluate(args, eval_task_names, model, tokenizer, prefix="", test=False):
146 |     
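    # CVC-IV inference: predictions below come from model.inference_IV, i.e.
    # only the content (main) branch is scored and the bias branch is bypassed.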
eval_outputs_dirs = (args.output_dir,) 147 | results = [] 148 | table = PrettyTable() 149 | table.add_column(' ', ['Accuracy']) 150 | for eval_task in eval_task_names: 151 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=not test, test=test) 152 | 153 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 154 | # Note that DistributedSampler samples randomly 155 | eval_sampler = SequentialSampler(eval_dataset) 156 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 157 | # Eval! 158 | # logger.info("***** Running evaluation {} *****".format(prefix)) 159 | # logger.info(" Num examples = %d", len(eval_dataset)) 160 | # logger.info(" Batch size = %d", args.eval_batch_size) 161 | nb_eval_steps = 0 162 | preds = None 163 | out_label_ids = None 164 | for batch in eval_dataloader: 165 | model.eval() 166 | batch = tuple(t.to(args.device) for t in batch) 167 | 168 | with torch.no_grad(): 169 | inputs = {'input_ids': batch[0], 170 | 'attention_mask': batch[1], 171 | 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids 172 | 'input_ids_np': batch[3], 173 | 'attention_mask_np': batch[4], 174 | 'token_type_ids_np': batch[5], 175 | 'labels': batch[6] 176 | } 177 | 178 | logits = model.inference_IV(**inputs) 179 | nb_eval_steps += 1 180 | if preds is None: 181 | preds = logits.detach().cpu().numpy() 182 | out_label_ids = inputs['labels'].detach().cpu().numpy() 183 | else: 184 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0) 185 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0) 186 | preds = np.argmax(preds, axis=1) 187 | acc = simple_accuracy(preds, out_label_ids) 188 | result = {"task_name": eval_task, "eval_acc": acc} 189 | results.append(result) 190 | for key in sorted(result.keys()): 191 | logger.info(" %s = %s", key, str(result[key])) 192 | table.add_column(eval_task, [round(acc*100, 2)]) 193 | print(table) 194 | return results 195 | 196 | -------------------------------------------------------------------------------- /src_mc/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import json 3 | import logging 4 | import os 5 | import glob 6 | import re 7 | from io import open 8 | import torch 9 | import tqdm 10 | from typing import List 11 | 12 | from transformers.tokenization_bert import PreTrainedTokenizer 13 | 14 | from torch.utils.data import TensorDataset 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | def load_and_cache_several_examples(args, tasks, tokenizer, evaluate=False, test=False, no_para=True): 19 | all_features = [] 20 | for task in tasks: 21 | if 'Add' not in task: 22 | if args.local_rank not in [-1, 0]: 23 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 24 | args.data_dir = os.path.join(args.raw_data_dir, task) 25 | processor = processors[task.lower()]() 26 | # Load data features from cache or dataset file 27 | if evaluate: 28 | cached_mode = 'dev' 29 | elif test: 30 | cached_mode = 'test' 31 | else: 32 | cached_mode = 'train' 33 | assert (evaluate == True and test == True) == False 34 | cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format( 35 | cached_mode, 36 | list(filter(None, args.model_name_or_path.split('/'))).pop(), 37 | 
str(args.max_seq_length), 38 | str(task))) 39 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 40 | logger.info("Loading features from cached file %s", cached_features_file) 41 | features = torch.load(cached_features_file) 42 | else: 43 | logger.info("Creating features from dataset file at %s", args.data_dir) 44 | label_list = processor.get_labels() 45 | if evaluate: 46 | examples = processor.get_dev_examples(args.data_dir) 47 | elif test: 48 | examples = processor.get_test_examples(args.data_dir) 49 | else: 50 | examples = processor.get_train_examples(args.data_dir) 51 | logger.info("Total number: %s", str(len(examples))) 52 | features = convert_examples_to_features( 53 | examples, 54 | label_list, 55 | args.max_seq_length, 56 | tokenizer, 57 | pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet 58 | pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, 59 | no_para=no_para 60 | ) 61 | if args.local_rank in [-1, 0]: 62 | logger.info("Saving features into cached file %s", cached_features_file) 63 | torch.save(features, cached_features_file) 64 | else: 65 | logger.info("Evaluate on {}!".format(task)) 66 | task, adv_type = task.split('-') 67 | args.data_dir = os.path.join(args.raw_data_dir, task) 68 | processor = processors[task.lower()]() 69 | examples = processor.get_adv_examples(args.data_dir, adv_type=adv_type) 70 | label_list = processor.get_labels() 71 | logger.info("Total number: %s", str(len(examples))) 72 | features = convert_examples_to_features( 73 | examples, 74 | label_list, 75 | args.max_seq_length, 76 | tokenizer, 77 | pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet 78 | pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0, 79 | no_para=no_para 80 | ) 81 | all_features.extend(features) 82 | if args.local_rank == 0: 83 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 84 | features = all_features 85 | # Convert to Tensors and build dataset 86 | all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long) 87 | all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long) 88 | all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long) 89 | all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long) 90 | if no_para: 91 | all_input_ids_np = torch.tensor(select_field(features, 'input_ids_np'), dtype=torch.long) 92 | all_input_mask_np = torch.tensor(select_field(features, 'input_mask_np'), dtype=torch.long) 93 | all_segment_ids_np = torch.tensor(select_field(features, 'segment_ids_np'), dtype=torch.long) 94 | 95 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, 96 | all_input_ids_np, all_input_mask_np, all_segment_ids_np, 97 | all_label_ids) 98 | return dataset 99 | 100 | 101 | def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False, no_para=True): 102 | if 'Add' not in task: 103 | if args.local_rank not in [-1, 0]: 104 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache 105 | args.data_dir = os.path.join(args.raw_data_dir, task) 106 | processor = processors[task.lower()]() 107 | # Load data features from cache or dataset file 108 | if evaluate: 109 | cached_mode = 'dev' 110 | elif test: 111 | cached_mode = 'test' 112 | else: 113 | cached_mode = 'train' 114 | assert 

def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False, no_para=True):
    if 'Add' not in task:
        if args.local_rank not in [-1, 0]:
            torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
        args.data_dir = os.path.join(args.raw_data_dir, task)
        processor = processors[task.lower()]()
        # Load data features from cache or dataset file
        if evaluate:
            cached_mode = 'dev'
        elif test:
            cached_mode = 'test'
        else:
            cached_mode = 'train'
        assert not (evaluate and test)  # dev and test modes are mutually exclusive
        cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
            cached_mode,
            list(filter(None, args.model_name_or_path.split('/'))).pop(),
            str(args.max_seq_length),
            str(task)))
        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s", cached_features_file)
            features = torch.load(cached_features_file)
        else:
            logger.info("Creating features from dataset file at %s", args.data_dir)
            label_list = processor.get_labels()
            if evaluate:
                examples = processor.get_dev_examples(args.data_dir)
            elif test:
                examples = processor.get_test_examples(args.data_dir)
            else:
                examples = processor.get_train_examples(args.data_dir)
            logger.info("Total number: %s", str(len(examples)))
            features = convert_examples_to_features(
                examples,
                label_list,
                args.max_seq_length,
                tokenizer,
                pad_on_left=bool(args.model_type in ['xlnet']),  # pad on the left for xlnet
                pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
                no_para=no_para
            )
            if args.local_rank in [-1, 0]:
                logger.info("Saving features into cached file %s", cached_features_file)
                torch.save(features, cached_features_file)
    else:
        logger.info("Evaluate on {}!".format(task))
        task, adv_type = task.split('-')
        args.data_dir = os.path.join(args.raw_data_dir, task)
        processor = processors[task.lower()]()
        examples = processor.get_adv_examples(args.data_dir, adv_type=adv_type)
        label_list = processor.get_labels()
        logger.info("Total number: %s", str(len(examples)))
        features = convert_examples_to_features(
            examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            pad_on_left=bool(args.model_type in ['xlnet']),  # pad on the left for xlnet
            pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
            no_para=no_para
        )

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
    all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
    all_label_ids = torch.tensor([f.label for f in features], dtype=torch.long)
    if no_para:  # the _np tensors below assume no_para=True (the default throughout this repo)
        all_input_ids_np = torch.tensor(select_field(features, 'input_ids_np'), dtype=torch.long)
        all_input_mask_np = torch.tensor(select_field(features, 'input_mask_np'), dtype=torch.long)
        all_segment_ids_np = torch.tensor(select_field(features, 'segment_ids_np'), dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                            all_input_ids_np, all_input_mask_np, all_segment_ids_np,
                            all_label_ids)
    return dataset
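
# Illustration only, not executed: the TensorDataset built above yields 7-tuples,
# so each batch unpacks positionally, matching the `inputs` dict assembled in
# main.py's evaluate():
#   batch[0] input_ids          batch[3] input_ids_np
#   batch[1] attention_mask     batch[4] attention_mask_np
#   batch[2] token_type_ids     batch[5] token_type_ids_np
#   batch[6] labels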

class InputExample(object):
    """A single training/test example for multiple choice."""

    def __init__(self, example_id, question, contexts, endings, label=None):
        """Constructs an InputExample.

        Args:
            example_id: Unique id for the example.
            contexts: list of str. The untokenized text of the first sequence (context of the corresponding question).
            question: string. The untokenized text of the second sequence (question).
            endings: list of str. The multiple-choice options; must have the same length as contexts.
            label: (Optional) string. The label of the example. This should be
                specified for train and dev examples, but not for test examples.
        """
        self.example_id = example_id
        self.question = question
        self.contexts = contexts
        self.endings = endings
        self.label = label


class InputFeatures(object):
    def __init__(self,
                 example_id,
                 choices_features,
                 label):
        self.example_id = example_id
        self.choices_features = [
            {
                'input_ids': input_ids,
                'input_mask': input_mask,
                'segment_ids': segment_ids,
                'input_ids_np': input_ids_np,
                'input_mask_np': input_mask_np,
                'segment_ids_np': segment_ids_np,
            }
            for input_ids, input_mask, segment_ids,
                input_ids_np, input_mask_np, segment_ids_np in choices_features
        ]
        self.label = label
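
# Illustration only, not executed: a hand-written InputExample in the layout the
# processors below produce (the passage and option strings here are made up):
#   InputExample(
#       example_id='train-race-0001',
#       question='What is the passage mainly about?',
#       contexts=[passage, passage, passage, passage],  # one copy per option
#       endings=['friendship', 'school life', 'travel', 'music'],
#       label='1')  # index of the correct option; a string for RACE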

class DataProcessor(object):
    """Base class for data converters for multiple choice data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the test set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()


class RaceProcessor(DataProcessor):
    """Processor for the RACE data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} train".format(data_dir))
        high = os.path.join(data_dir, 'train/high')
        middle = os.path.join(data_dir, 'train/middle')
        high = self._read_txt(high)
        middle = self._read_txt(middle)
        return self._create_examples(high + middle, 'train')

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} dev".format(data_dir))
        high = os.path.join(data_dir, 'dev/high')
        middle = os.path.join(data_dir, 'dev/middle')
        high = self._read_txt(high)
        middle = self._read_txt(middle)
        return self._create_examples(high + middle, 'dev')

    def get_test_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} test".format(data_dir))
        high = os.path.join(data_dir, 'test/high')
        middle = os.path.join(data_dir, 'test/middle')
        high = self._read_txt(high)
        middle = self._read_txt(middle)
        return self._create_examples(high + middle, 'test')

    def get_adv_examples(self, data_dir, adv_type):
        path = os.path.join(data_dir, adv_type + '.pkl')
        with open(path, 'rb') as f:
            examples = torch.load(f)
        return examples

    def get_labels(self):
        """See base class."""
        return ["0", "1", "2", "3"]

    def _read_txt(self, input_dir):
        lines = []
        files = glob.glob(input_dir + "/*txt")
        for file in tqdm.tqdm(files, desc="read files"):
            with open(file, 'r', encoding='utf-8') as fin:
                data_raw = json.load(fin)
                data_raw["race_id"] = file
                lines.append(data_raw)
        return lines

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev, and test sets."""
        examples = []
        for data_raw in lines:
            race_id = "%s-%s" % (set_type, data_raw["race_id"])
            article = data_raw["article"]
            for i in range(len(data_raw["answers"])):
                truth = str(ord(data_raw['answers'][i]) - ord('A'))
                question = data_raw['questions'][i]
                options = data_raw['options'][i]

                examples.append(
                    InputExample(
                        example_id=race_id,
                        question=question,
                        contexts=[article, article, article, article],  # this is not efficient but convenient
                        endings=[options[0], options[1], options[2], options[3]],
                        label=truth))
        return examples
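
# Illustration only, not executed: adversarial evaluation sets are addressed by
# a hyphenated task name. In the loaders above, any name containing 'Add'
# (e.g. a hypothetical 'RACE-AddSent') is split into the base task and the
# adversarial pickle to load:
#   task, adv_type = 'RACE-AddSent'.split('-')   # -> ('RACE', 'AddSent')
#   examples = RaceProcessor().get_adv_examples('../data_mc/RACE', adv_type='AddSent')
#   # reads ../data_mc/RACE/AddSent.pkl via torch.load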
%s", str(truth)) 354 | return None 355 | 356 | examples = [] 357 | three_choice = 0 358 | four_choice = 0 359 | five_choice = 0 360 | other_choices = 0 361 | # we deleted example which has more than or less than four choices 362 | for line in tqdm.tqdm(lines, desc="read arc data"): 363 | data_raw = json.loads(line.strip("\n")) 364 | if len(data_raw["question"]["choices"]) == 3: 365 | three_choice += 1 366 | continue 367 | elif len(data_raw["question"]["choices"]) == 5: 368 | five_choice += 1 369 | continue 370 | elif len(data_raw["question"]["choices"]) != 4: 371 | other_choices += 1 372 | continue 373 | four_choice += 1 374 | truth = str(normalize(data_raw["answerKey"])) 375 | assert truth != "None" 376 | question_choices = data_raw["question"] 377 | question = question_choices["stem"] 378 | id = data_raw["id"] 379 | options = question_choices["choices"] 380 | if len(options) == 4: 381 | examples.append( 382 | InputExample( 383 | example_id = id, 384 | question=question, 385 | contexts=[options[0]["para"].replace("_", ""), options[1]["para"].replace("_", ""), 386 | options[2]["para"].replace("_", ""), options[3]["para"].replace("_", "")], 387 | endings=[options[0]["text"], options[1]["text"], options[2]["text"], options[3]["text"]], 388 | label=truth)) 389 | 390 | if type == "train": 391 | assert len(examples) > 1 392 | assert examples[0].label is not None 393 | logger.info("len examples: %s}", str(len(examples))) 394 | logger.info("Three choices: %s", str(three_choice)) 395 | logger.info("Five choices: %s", str(five_choice)) 396 | logger.info("Other choices: %s", str(other_choices)) 397 | logger.info("four choices: %s", str(four_choice)) 398 | 399 | return examples 400 | 401 | class MctestProcessor(DataProcessor): 402 | """Processor for the MCTest data set (request from allennlp).""" 403 | 404 | def get_train_examples(self, data_dir): 405 | """See base class.""" 406 | logger.info("LOOKING AT {} train".format(data_dir)) 407 | return self._create_examples(self._read_file(data_dir, "train"), "train") 408 | 409 | def get_dev_examples(self, data_dir): 410 | """See base class.""" 411 | logger.info("LOOKING AT {} dev".format(data_dir)) 412 | return self._create_examples(self._read_file(data_dir, "dev"), "dev") 413 | 414 | def get_test_examples(self, data_dir): 415 | logger.info("LOOKING AT {} test".format(data_dir)) 416 | return self._create_examples(self._read_file(data_dir, "test"), "test") 417 | 418 | def get_adv_examples(self, data_dir, adv_type): 419 | path = os.path.join(data_dir, adv_type + '.pkl') 420 | with open(path, 'rb') as f: 421 | examples = torch.load(f) 422 | return examples 423 | 424 | def get_labels(self): 425 | """See base class.""" 426 | return [0, 1, 2, 3] 427 | 428 | def _read_file(self, data_dir, set_name): 429 | context_160_file = 'mc160.' + set_name + '.tsv' 430 | context_500_file = 'mc500.' + set_name + '.tsv' 431 | answer_160_file = 'mc160.' + set_name + '.ans' 432 | answer_500_file = 'mc500.' 

class MctestProcessor(DataProcessor):
    """Processor for the MCTest data set (request from allennlp)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} train".format(data_dir))
        return self._create_examples(self._read_file(data_dir, "train"), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} dev".format(data_dir))
        return self._create_examples(self._read_file(data_dir, "dev"), "dev")

    def get_test_examples(self, data_dir):
        logger.info("LOOKING AT {} test".format(data_dir))
        return self._create_examples(self._read_file(data_dir, "test"), "test")

    def get_adv_examples(self, data_dir, adv_type):
        path = os.path.join(data_dir, adv_type + '.pkl')
        with open(path, 'rb') as f:
            examples = torch.load(f)
        return examples

    def get_labels(self):
        """See base class."""
        return [0, 1, 2, 3]

    def _read_file(self, data_dir, set_name):
        context_160_file = 'mc160.' + set_name + '.tsv'
        context_500_file = 'mc500.' + set_name + '.tsv'
        answer_160_file = 'mc160.' + set_name + '.ans'
        answer_500_file = 'mc500.' + set_name + '.ans'
        with open(os.path.join(data_dir, context_160_file), encoding='utf-8') as f:
            context_160 = f.read()
        with open(os.path.join(data_dir, context_500_file), encoding='utf-8') as f:
            context_500 = f.read()
        with open(os.path.join(data_dir, answer_160_file), encoding='utf-8') as f:
            answer_160 = f.read()
        with open(os.path.join(data_dir, answer_500_file), encoding='utf-8') as f:
            answer_500 = f.read()
        context = (context_160, context_500)
        answer = (answer_160, answer_500)
        return (context, answer)

    def _create_examples(self, context_answer, set_name):
        raw_context, raw_answer = context_answer[0], context_answer[1]
        raw_context_160, raw_context_500 = raw_context[0], raw_context[1]
        raw_answer_160, raw_answer_500 = raw_answer[0], raw_answer[1]
        answer_160 = [ord(option) - ord('A') for option in raw_answer_160 if option in ['A', 'B', 'C', 'D']]
        context_160 = raw_context_160.split('\n')[:-1]
        answer_500 = [ord(option) - ord('A') for option in raw_answer_500 if option in ['A', 'B', 'C', 'D']]
        context_500 = raw_context_500.split('\n')[:-1]
        context = context_160 + context_500
        answer = answer_160 + answer_500
        idx = 0
        examples = []
        for i, sample in enumerate(context):
            elements = sample.split('\t')
            passage = elements[2]  # the passage text; escaped newlines and tabs are restored below
            passage = re.sub(r'\\newline', '\n', passage)
            passage = re.sub(r'\\tab', '\t', passage)
            for j in range(4):
                question_elements = elements[3 + 5 * j:3 + 5 * (j + 1)]  # get question elements
                qtype, qtext = question_elements[0].split(': ')  # get question type and text
                options = question_elements[1:5]  # get answer options
                truth = answer[idx]  # get the correct answer (from the answer file)
                idx += 1
                examples.append(
                    InputExample(
                        example_id=idx,
                        question=qtext,
                        contexts=[passage, passage, passage, passage],  # this is not efficient but convenient
                        endings=[options[0], options[1], options[2], options[3]],
                        label=truth))
        assert len(examples) == len(answer)
        return examples


class Semeval2018Processor(DataProcessor):
    """Processor for the SemEval 2018 Task 11 data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} train".format(data_dir))
        return self._create_examples(self._read_json(os.path.join(data_dir, "train-data.json")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} dev".format(data_dir))
        return self._create_examples(self._read_json(os.path.join(data_dir, "dev-data.json")), "dev")

    def get_test_examples(self, data_dir):
        logger.info("LOOKING AT {} test".format(data_dir))
        return self._create_examples(self._read_json(os.path.join(data_dir, "test-data.json")), "test")

    def get_adv_examples(self, data_dir, adv_type):
        path = os.path.join(data_dir, adv_type + '.pkl')
        with open(path, 'rb') as f:
            examples = torch.load(f)
        return examples

    def get_labels(self):
        """See base class."""
        return [0, 1]

    def _read_json(self, input_file):
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data

    def _create_examples(self, data, type):
        """Creates examples for the training and dev sets."""
        dataset = data['data']['instance']
        idx = 0
        examples = []
        corrupted_sample = 0
        for i in dataset:
            passage = i['text']
            try:
                if isinstance(i['questions']['question'], list):
                    for question in i['questions']['question']:
                        query = question['@text']
                        options = [item['@text'] for item in question['answer']]
                        truth = 0 if question['answer'][0]['@correct'] == 'True' else 1
                        examples.append(
                            InputExample(
                                example_id=idx,
                                question=query,
                                contexts=[passage, passage],  # this is not efficient but convenient
                                endings=[options[0], options[1]],
                                label=truth))
                        idx += 1
                else:
                    question = i['questions']['question']
                    query = question['@text']
                    options = [item['@text'] for item in question['answer']]
                    truth = 0 if question['answer'][0]['@correct'] == 'True' else 1
                    examples.append(
                        InputExample(
                            example_id=idx,
                            question=query,
                            contexts=[passage, passage],  # this is not efficient but convenient
                            endings=[options[0], options[1]],
                            label=truth))
                    idx += 1
            except Exception:  # skip malformed instances
                corrupted_sample += 1
                continue
        logger.info("Corrupted samples: {}".format(corrupted_sample))
        return examples


class DreamProcessor(DataProcessor):
    """Processor for the DREAM data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} train".format(data_dir))
        return self._create_examples(self._read_json(os.path.join(data_dir, "train.json")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} dev".format(data_dir))
        return self._create_examples(self._read_json(os.path.join(data_dir, "dev.json")), "dev")

    def get_test_examples(self, data_dir):
        logger.info("LOOKING AT {} test".format(data_dir))
        return self._create_examples(self._read_json(os.path.join(data_dir, "test.json")), "test")

    def get_adv_examples(self, data_dir, adv_type):
        path = os.path.join(data_dir, adv_type + '.pkl')
        with open(path, 'rb') as f:
            examples = torch.load(f)
        return examples

    def get_labels(self):
        """See base class."""
        return [0, 1, 2]

    def _read_json(self, input_file):
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return data

    def _create_examples(self, data, type):
        idx = 0
        examples = []
        for item in data:
            passage = ' '.join(item[0])
            for question in item[1]:
                query = question['question']
                options = question['choice']
                truth = options.index(question['answer'])
                examples.append(
                    InputExample(
                        example_id=idx,
                        question=query,
                        contexts=[passage, passage, passage],  # this is not efficient but convenient
                        endings=[options[0], options[1], options[2]],
                        label=truth))
                idx += 1
        return examples
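
# Illustration only, not executed: how `convert_examples_to_features` below
# builds the second segment for each option (the question/option strings here
# are made up):
#   cloze question:   "The capital of France is _." + option "Paris"
#                     -> text_b = "The capital of France is Paris."
#   regular question: "What is the capital of France?" + option "Paris"
#                     -> text_b = "What is the capital of France? Paris"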
desc="convert examples to features"): 622 | # if ex_index % 10000 == 0: 623 | # logger.info("Writing example %d of %d" % (ex_index, len(examples))) 624 | choices_features = [] 625 | for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)): 626 | text_a = context 627 | if example.question.find("_") != -1: 628 | # this is for cloze question 629 | text_b = example.question.replace("_", ending) 630 | else: 631 | text_b = example.question + " " + ending 632 | 633 | inputs = tokenizer.encode_plus( 634 | text_a, 635 | text_b, 636 | add_special_tokens=True, 637 | max_length=max_length, 638 | ) 639 | # if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0: 640 | # logger.info('Attention! you are cropping tokens (swag task is ok). ' 641 | # 'If you are training ARC and RACE and you are poping question + options,' 642 | # 'you need to try to use a bigger max seq length!') 643 | 644 | input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] 645 | 646 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 647 | # tokens are attended to. 648 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 649 | 650 | # Zero-pad up to the sequence length. 651 | padding_length = max_length - len(input_ids) 652 | if pad_on_left: 653 | input_ids = ([pad_token] * padding_length) + input_ids 654 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 655 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids 656 | else: 657 | input_ids = input_ids + ([pad_token] * padding_length) 658 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 659 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 660 | 661 | if no_para: 662 | text_a_np = ' '.join([tokenizer.pad_token]*1) 663 | text_b_np = ending 664 | inputs_np = tokenizer.encode_plus( 665 | text_a, 666 | text_b_np, 667 | add_special_tokens=True, 668 | max_length=max_length, 669 | ) 670 | input_ids_np, token_type_ids_np = inputs_np["input_ids"], inputs_np["token_type_ids"] 671 | attention_mask_np = [1 if mask_padding_with_zero else 0] * len(input_ids_np) 672 | padding_length_np = max_length - len(input_ids_np) 673 | if pad_on_left: 674 | input_ids_np = ([pad_token] * padding_length_np) + input_ids_np 675 | attention_mask_np = ([0 if mask_padding_with_zero else 1] * padding_length_np) + attention_mask_np 676 | token_type_ids_np = ([pad_token_segment_id] * padding_length_np) + token_type_ids_np 677 | else: 678 | input_ids_np = input_ids_np + ([pad_token] * padding_length_np) 679 | attention_mask_np = attention_mask_np + ([0 if mask_padding_with_zero else 1] * padding_length_np) 680 | token_type_ids_np = token_type_ids_np + ([pad_token_segment_id] * padding_length_np) 681 | assert len(input_ids_np) == max_length 682 | assert len(attention_mask_np) == max_length 683 | assert len(token_type_ids_np) == max_length 684 | 685 | assert len(input_ids) == max_length 686 | assert len(attention_mask) == max_length 687 | assert len(token_type_ids) == max_length 688 | choices_features.append((input_ids, attention_mask, token_type_ids, 689 | input_ids_np, attention_mask_np, token_type_ids_np)) 690 | 691 | 692 | label = label_map[example.label] 693 | features.append( 694 | InputFeatures( 695 | example_id=example.example_id, 696 | choices_features=choices_features, 697 | label=label, 698 | ) 699 | ) 700 | 701 | return features 702 | 703 | def select_field(features, 

def select_field(features, field):
    return [
        [
            choice[field]
            for choice in feature.choices_features
        ]
        for feature in features
    ]


processors = {
    "race": RaceProcessor,
    "arc": ArcProcessor,
    "mctest": MctestProcessor,
    "semeval2018": Semeval2018Processor,
    "dream": DreamProcessor,
}


MULTIPLE_CHOICE_TASKS_NUM_LABELS = {
    "race": 4,
    "arc": 4,
    "mctest": 4,
    "semeval2018": 2,
    "dream": 3
}
--------------------------------------------------------------------------------