├── LICENSE
├── README.md
├── measure_evidence.py
├── preprocessing.py
├── run_mrc.py
└── src
    ├── functions
    │   ├── __pycache__
    │   │   ├── evaluate_v1_0.cpython-37.pyc
    │   │   ├── hotpotqa_metric.cpython-37.pyc
    │   │   ├── processor.cpython-37.pyc
    │   │   ├── processor_sent.cpython-37.pyc
    │   │   ├── squad_metric.cpython-37.pyc
    │   │   └── utils.cpython-37.pyc
    │   ├── evaluate_v1_0.py
    │   ├── hotpotqa_metric.py
    │   ├── processor_sent.py
    │   ├── squad_metric.py
    │   └── utils.py
    └── model
        ├── __pycache__
        │   ├── attention.cpython-37.pyc
        │   ├── main_function_rnn.cpython-37.pyc
        │   ├── main_functions.cpython-37.pyc
        │   ├── model.cpython-37.pyc
        │   └── model_rnn.cpython-37.pyc
        ├── main_function_rnn.py
        └── model_rnn.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 Nicola De Cao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # XAI_EvidenceExtraction
2 | 
3 | 
4 | # Dependencies
5 | * python 3.7
6 | * PyTorch 1.6.0
7 | * Transformers 2.11.0
8 | 
9 | # Data
10 | * HotpotQA
11 | 
12 | # Train & Test
13 | * Train : `python run_mrc.py --from_init_weight True --do_train True`
14 | * Test : `python run_mrc.py --from_init_weight False --do_eval True --checkpoint [saved checkpoint global step]` (the flag names follow the arguments defined in run_mrc.py; see the note on boolean command-line flags below run_mrc.py)
15 | 
16 | 
--------------------------------------------------------------------------------
/measure_evidence.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | 
4 | 
5 | 
6 | # answer_file = open('./all_data/refine_hotpot_dev_distractor_v1.json','r',encoding='utf8')
7 | answer_file = open('./data/hotpot_dev_distractor_v1.json','r',encoding='utf8')
8 | 
9 | answers = json.load(answer_file)
10 | 
11 | prediction_dict = {}
12 | n = 0
13 | c = 0
14 | d = './first'
15 | for f_name in os.listdir(d):
16 |     if 'nbest' not in f_name:
17 |         continue
18 |     # if '38000' not in f_name:
19 |     #     continue
20 |     print('\n\n\n\n', f_name)
21 |     predict_file = open(os.path.join(d, '{}'.format(f_name)),'r',encoding='utf8')
22 |     predictions = json.load(predict_file)
23 |     for qas in predictions.keys():
24 |         # if len(predictions[qas][0]["evidence"])>1:
25 |         #     print(qas)
26 |         prediction_dict[qas] = [predictions[qas][0]["text"], predictions[qas][0]["evidence"]]
27 |         # if prediction_dict[qas][0] not in prediction_dict[qas][1][0] and prediction_dict[qas][0] in ' '.join(prediction_dict[qas][1]):
28 |         #     print(qas)
29 |         #     print("???")
30 | 
31 | 
32 | # print(c/n)
33 | 
34 | all_answer_num = 0
35 | all_predict_num = 0
36 | all_correct = 0
37 | 
38 | precision_list = []
39 | recall_list = []
40 | f1_list = []
41 | for data in answers:
42 |     answer_num = 0
43 |     predict_num = 0
44 |     correct = 0
45 |     qas_id = data["_id"]
46 |     answer_text = data["answer"]
47 |     documents = {e[0]: e[1] for e in data["context"]}
48 |     supporting_facts = data['supporting_facts']
49 |     supporting_sentences = []
50 |     for idx, support_fact in enumerate(supporting_facts):
51 |         try:
52 |             sentence = documents[support_fact[0]][support_fact[1]]
53 |             supporting_sentences.append(sentence)
54 |         except (KeyError, IndexError):
55 |             continue
56 |     # if answer_text.strip() != prediction_dict[qas_id][0].strip():
57 |     #     continue
58 |     # if len(prediction_dict[qas_id][1]) < 3:
59 |     #     continue
60 |     # if prediction_dict[qas_id][0] not in prediction_dict[qas_id][1][0]:
61 |     #     print(qas_id)
62 |     #     print("???")
63 |     try:
64 |         prediction = list(prediction_dict[qas_id][1])
65 |     except KeyError:
66 |         print(qas_id)
67 |         continue
68 |     # tmp = ''.join(prediction)
69 |     # if answer_text not in tmp:
70 |     #     print("??")
71 |     for sent in supporting_sentences:
72 |         if sent in prediction:
73 |             correct +=1
74 | 
75 |             all_correct +=1
76 | 
77 |         answer_num +=1
78 | 
79 |     predict_num += len(prediction)
80 |     all_answer_num += len(supporting_sentences)
81 |     all_predict_num += len(prediction)
82 |     if not predict_num:
83 |         predict_num = 1e-10
84 |     precision = correct / predict_num
85 |     recall = correct / answer_num
86 |     # if recall == 1.0 and precision == 1:
87 |     #     print(qas_id)
88 | 
89 |     # if recall < 0.5:
90 |     #     print(qas_id)
91 |     # if correct == 0:
92 |     #     print(data['question'])
93 |     #
94 |     #     print("??")
95 |     f1 = (2*precision*recall) / (recall + precision + 1e-10)
96 | 
97 |     precision_list.append(precision)
98 |     recall_list.append(recall)
99 |     f1_list.append(f1)
100 | 
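# The loop above fills precision_list / recall_list / f1_list with one value per question.
# The "Per ..." metrics printed below macro-average those per-question scores, while the
# "All ..." metrics are micro-averaged over the whole corpus: all_correct recovered
# supporting sentences out of all_predict_num predicted and all_answer_num gold sentences.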
101 | per_precision = sum(precision_list) / len(precision_list)
102 | per_recall = sum(recall_list) / len(recall_list)
103 | per_f1 = sum(f1_list) / len(f1_list)
104 | 
105 | 
106 | print("Per Precision : {}\tRecall : {}\tF1 : {}".format(round(per_precision, 3), round(per_recall, 3), round(per_f1, 3)))
107 | 
108 | all_precision = all_correct / all_predict_num
109 | all_recall = all_correct / all_answer_num
110 | all_f1 = (2*all_precision*all_recall) / (all_recall + all_precision + 1e-10)
111 | 
112 | print("All Precision : {}\tRecall : {}\tF1 : {}".format(round(all_precision, 3), round(all_recall, 3), round(all_f1, 3)))
113 | 
114 | 
115 | 
116 | 
117 | # To be discussed next Tuesday
118 | 
119 | 
--------------------------------------------------------------------------------
/preprocessing.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/preprocessing.py
--------------------------------------------------------------------------------
/run_mrc.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import logging
4 | from attrdict import AttrDict
5 | import torch
6 | from transformers import ElectraTokenizer, ElectraConfig
7 | from transformers.modeling_electra import ElectraModel
8 | 
9 | # from src.model.model import ElectraForQuestionAnswering as ElectraForQuestionAnswering
10 | # from src.model.main_functions import train, evaluate, predict
11 | # from src.functions.utils import init_logger, set_seed
12 | 
13 | from src.model.model_rnn import ElectraForQuestionAnswering_1208 as ElectraForQuestionAnswering
14 | from src.model.main_function_rnn import sample_train2, evaluate
15 | from src.functions.utils import init_logger, set_seed
16 | 
17 | # from src.model.model import ElectraForQuestionAnswering_graph as ElectraForQuestionAnswering
18 | # from src.model.main_functions_graph import train, evaluate
19 | # from src.graph_functions.utils import init_logger, set_seed
20 | 
21 | def create_model(args):
22 |     # Load the model configuration, either from the initial weights or from a saved checkpoint
23 |     config = ElectraConfig.from_pretrained(
24 |         args.model_name_or_path if args.from_init_weight else os.path.join(args.output_dir, "checkpoint-{}".format(args.checkpoint)),
25 |         # os.path.join("./first", "checkpoint-{}".format(args.checkpoint)) if args.from_init_weight else os.path.join(args.output_dir, "checkpoint-{}".format(args.checkpoint)),
26 | 
27 |     )
28 | 
29 |     # The tokenizer is not trained here; this only loads the vocab (etc.) that belongs to the model being loaded
30 |     tokenizer = ElectraTokenizer.from_pretrained(
31 |         args.model_name_or_path if args.from_init_weight else os.path.join(args.output_dir, "checkpoint-{}".format(args.checkpoint)),
32 |         # os.path.join("./first", "checkpoint-{}".format(args.checkpoint)) if args.from_init_weight else os.path.join( args.output_dir, "checkpoint-{}".format(args.checkpoint)),
33 |         do_lower_case=args.do_lower_case,
34 | 
35 |     )
36 |     config.max_sent_num = args.max_sent_num
37 |     config.max_dec_len = args.max_dec_len
38 |     config.num_samples = args.num_samples
39 |     model = ElectraForQuestionAnswering.from_pretrained(
40 |         # os.path.join("./first", "checkpoint-{}".format(args.checkpoint)) if args.from_init_weight else os.path.join(args.output_dir, "checkpoint-{}".format(args.checkpoint)),
41 |         args.model_name_or_path if args.from_init_weight else os.path.join(args.output_dir, "checkpoint-{}".format(args.checkpoint)),
42 |         config=config,
43 |         # from_tf= True if args.from_init_weight else False
44 |     )
45 | 
46 | 
47 | 
48 |     # Add words to the vocab
49 |     # Used to keep important words from becoming UNK, and for strings that must not be split by tokenization (e.g. HTML tags)
50 |     # If "세종대왕" is out-of-vocabulary: ['세종대왕'] --tokenize--> ['UNK'] (X)
51 |     # The HTML tag [td] must not be tokenized (it only carries meaning as a complete tag):
52 |     # ['[td]'] --tokenize--> ['[', 't', 'd', ']'] (X)
53 | 
54 |     if args.from_init_weight and args.add_vocab:
55 |         if args.from_init_weight:
56 |             add_token = {
57 |                 "additional_special_tokens": ["[td]", "추가 단어 1", "추가 단어 2"]}
58 |             # Added special tokens are not split by the tokenizer
59 |             # e.g.
60 |             # '[td]' before adding it to the vocab -> ['[', 't', 'd', ']']
61 |             # '[td]' after adding it to the vocab  -> ['[td]']
62 |             tokenizer.add_special_tokens(add_token)
63 |             model.resize_token_embeddings(len(tokenizer))
64 |     model.to(args.device)
65 |     return model, tokenizer
66 | 
67 | def main(cli_args):
68 |     # Collect the parameters from the CLI arguments
69 |     args = AttrDict(vars(cli_args))
70 |     args.device = "cuda"
71 |     logger = logging.getLogger(__name__)
72 | 
73 |     # Set up the logger and the random seed
74 |     init_logger()
75 |     set_seed(args)
76 | 
77 |     # Load the model
78 | 
79 |     model, tokenizer = create_model(args)
80 |     # Run according to the selected mode
81 |     if args.do_train:
82 |         sample_train2(args, model, tokenizer, logger)
83 |     elif args.do_eval:
84 |         model, tokenizer = create_model(args)
85 |         evaluate(args, model, tokenizer, logger)
86 |     elif args.do_predict:
87 |         predict(args, model, tokenizer)  # NOTE: `predict` is not imported above (the main_functions import is commented out)
88 | if __name__ == '__main__':
89 |     cli_parser = argparse.ArgumentParser()
90 | 
91 |     # Directory
92 |     cli_parser.add_argument("--data_dir", type=str, default="./data")
93 |     # cli_parser.add_argument("--data_dir", type=str, default="./all_data")
94 |     cli_parser.add_argument("--model_name_or_path", type=str, default="./init_weight")
95 | 
96 |     cli_parser.add_argument("--output_dir", type=str, default="./4sampled_first")
97 | 
98 |     cli_parser.add_argument("--train_file", type=str, default="hotpot_train_v1.1.json")
99 |     cli_parser.add_argument("--predict_file", type=str, default="hotpot_dev_distractor_v1.json")
100 |     # cli_parser.add_argument("--train_file", type=str, default="refine_hotpot_train_v1.1.json")
101 |     # cli_parser.add_argument("--predict_file", type=str, default="refine_hotpot_dev_fullwiki_v1.json")
102 |     cli_parser.add_argument("--checkpoint", type=str, default="26000")
103 | 
104 |     # Model Hyper Parameter
105 |     cli_parser.add_argument("--max_seq_length", type=int, default=512)
106 |     cli_parser.add_argument("--doc_stride", type=int, default=128)
107 |     cli_parser.add_argument("--max_query_length", type=int, default=64)
108 |     cli_parser.add_argument("--max_answer_length", type=int, default=30)
109 |     cli_parser.add_argument("--n_best_size", type=int, default=20)
110 | 
111 | 
112 |     # Training Parameter
113 |     cli_parser.add_argument("--learning_rate", type=float, default=5e-5)
114 |     cli_parser.add_argument("--train_batch_size", type=int, default=3)
115 |     cli_parser.add_argument("--eval_batch_size", type=int, default=3)
116 |     cli_parser.add_argument("--max_sent_num", type=int, default=40)
117 |     cli_parser.add_argument("--num_samples", type=int, default=4)
118 |     cli_parser.add_argument("--max_dec_len", type=int, default=3)
119 |     cli_parser.add_argument("--num_train_epochs", type=int, default=5)
120 | 
121 |     cli_parser.add_argument("--save_steps", type=int, default=2000)
122 |     cli_parser.add_argument("--logging_steps", type=int, default=2000)
123 |     cli_parser.add_argument("--seed", type=int, default=42)
124 |     cli_parser.add_argument("--threads", type=int, default=8)
125 | 
126 |     cli_parser.add_argument("--weight_decay", type=float, default=0.0)
127 |     cli_parser.add_argument("--adam_epsilon", type=float, default=1e-10)
128 | 
129 |     cli_parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
130 |     cli_parser.add_argument("--warmup_steps", type=int, default=0)
131 |     cli_parser.add_argument("--max_steps", type=int, default=-1)
132 |     cli_parser.add_argument("--max_grad_norm", type=float, default=1.0)
133 | 
134 |     cli_parser.add_argument("--verbose_logging", type=bool, default=False)
135 |     cli_parser.add_argument("--do_lower_case", type=bool, default=True)
136 |     cli_parser.add_argument("--no_cuda", type=bool, default=False)
137 | 
138 |     # For SQuAD v2.0 (Yes/No Question)
139 |     cli_parser.add_argument("--version_2_with_negative", type=bool, default=False)
140 |     cli_parser.add_argument("--null_score_diff_threshold", type=float, default=0.0)
141 | 
142 |     # Running Mode (caution: argparse's type=bool treats any non-empty string, including "False", as True; see the note after this file)
143 |     cli_parser.add_argument("--from_init_weight", type=bool, default=True)
144 |     cli_parser.add_argument("--add_vocab", type=bool, default=False)
145 |     cli_parser.add_argument("--do_train", type=bool, default=True)
146 |     cli_parser.add_argument("--do_eval", type=bool, default=True)
147 |     cli_parser.add_argument("--do_predict", type=bool, default=False)
148 |     cli_args = cli_parser.parse_args()
149 | 
150 | 
151 |     main(cli_args)
--------------------------------------------------------------------------------
/src/functions/__pycache__/evaluate_v1_0.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/functions/__pycache__/evaluate_v1_0.cpython-37.pyc
--------------------------------------------------------------------------------
/src/functions/__pycache__/hotpotqa_metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/functions/__pycache__/hotpotqa_metric.cpython-37.pyc
--------------------------------------------------------------------------------
/src/functions/__pycache__/processor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/functions/__pycache__/processor.cpython-37.pyc
--------------------------------------------------------------------------------
/src/functions/__pycache__/processor_sent.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/functions/__pycache__/processor_sent.cpython-37.pyc
--------------------------------------------------------------------------------
/src/functions/__pycache__/squad_metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/functions/__pycache__/squad_metric.cpython-37.pyc
--------------------------------------------------------------------------------
/src/functions/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/functions/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
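A note on the Running Mode flags in run_mrc.py above (and on the Train/Test commands in the README): argparse converts each raw command-line string with the callable given as `type`, and `bool("False")` is `True` because any non-empty string is truthy, so an invocation such as `--from_init_weight False` or `--do_train False` does not actually turn the option off. The sketch below only illustrates this standard Python/argparse behaviour; the `str2bool` helper is an illustrative workaround and is not part of this repository.

import argparse

def str2bool(v):
    # Illustrative helper (not defined in this repo): map common textual spellings to a real boolean.
    return str(v).lower() in ("1", "true", "yes", "y")

parser = argparse.ArgumentParser()
parser.add_argument("--do_train", type=bool, default=True)     # same pattern as run_mrc.py
parser.add_argument("--do_eval", type=str2bool, default=True)  # possible workaround

args = parser.parse_args(["--do_train", "False", "--do_eval", "False"])
print(args.do_train)  # True  -- bool("False") is True, so the flag was not disabled
print(args.do_eval)   # False -- str2bool("False") returns False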
/src/functions/evaluate_v1_0.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from collections import Counter
3 | import string
4 | import re
5 | import argparse
6 | import json
7 | import sys
8 | import os
9 | def normalize_answer(s):
10 |     def remove_(text):
11 |         ''' Remove unnecessary symbols '''
12 |         text = re.sub("'", " ", text)
13 |         text = re.sub('"', " ", text)
14 |         text = re.sub('《', " ", text)
15 |         text = re.sub('》', " ", text)
16 |         text = re.sub('<', " ", text)
17 |         text = re.sub('>', " ", text)
18 |         text = re.sub('〈', " ", text)
19 |         text = re.sub('〉', " ", text)
20 |         text = re.sub("\(", " ", text)
21 |         text = re.sub("\)", " ", text)
22 |         text = re.sub("‘", " ", text)
23 |         text = re.sub("’", " ", text)
24 |         return text
25 | 
26 |     def white_space_fix(text):
27 |         return ' '.join(text.split())
28 | 
29 |     def remove_punc(text):
30 |         exclude = set(string.punctuation)
31 |         return ''.join(ch for ch in text if ch not in exclude)
32 | 
33 |     def lower(text):
34 |         return text.lower()
35 | 
36 |     return white_space_fix(remove_punc(lower(remove_(s))))
37 | 
38 | 
39 | def f1_score(prediction, ground_truth):
40 |     prediction_tokens = normalize_answer(prediction).split()
41 |     ground_truth_tokens = normalize_answer(ground_truth).split()
42 | 
43 |     # F1 by character
44 |     prediction_Char = []
45 |     for tok in prediction_tokens:
46 |         now = [a for a in tok]
47 |         prediction_Char.extend(now)
48 |     ground_truth_Char = []
49 |     for tok in ground_truth_tokens:
50 |         now = [a for a in tok]
51 |         ground_truth_Char.extend(now)
52 |     common = Counter(prediction_Char) & Counter(ground_truth_Char)
53 |     num_same = sum(common.values())
54 |     if num_same == 0:
55 |         return 0
56 | 
57 |     precision = 1.0 * num_same / len(prediction_Char)
58 |     recall = 1.0 * num_same / len(ground_truth_Char)
59 |     f1 = (2 * precision * recall) / (precision + recall)
60 | 
61 |     return f1
62 | 
63 | def exact_match_score(prediction, ground_truth):
64 |     return (normalize_answer(prediction) == normalize_answer(ground_truth))
65 | 
66 | 
67 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
68 |     scores_for_ground_truths = []
69 |     for ground_truth in ground_truths:
70 |         score = metric_fn(prediction, ground_truth)
71 |         scores_for_ground_truths.append(score)
72 |     return max(scores_for_ground_truths)
73 | 
74 | 
75 | def evaluate(dataset, predictions):
76 |     f1 = exact_match = total = 0
77 |     for article in dataset:
78 |         title = article["_id"]
79 |         q_id = article['_id']
80 |         if q_id not in predictions:
81 |             message = 'Unanswered question ' + title + \
82 |                       ' will receive score 0.'
83 | print(message, file=sys.stderr) 84 | continue 85 | total += 1 86 | ground_truths = [article["answer"]] 87 | prediction = predictions[q_id] 88 | e = metric_max_over_ground_truths( 89 | exact_match_score, prediction, ground_truths) 90 | exact_match += metric_max_over_ground_truths( 91 | exact_match_score, prediction, ground_truths) 92 | f1 += metric_max_over_ground_truths( 93 | f1_score, prediction, ground_truths) 94 | 95 | exact_match = 100.0 * exact_match / total 96 | f1 = 100.0 * f1 / total 97 | return {'official_exact_match': exact_match, 'official_f1': f1} 98 | 99 | 100 | def eval_during_train(args, global_step): 101 | expected_version = 'KorQuAD_v1.0' 102 | 103 | dataset_file = os.path.join(args.data_dir, args.predict_file) 104 | prediction_file = os.path.join(args.output_dir, 'predictions_{}.json'.format(global_step)) 105 | 106 | with open(dataset_file) as dataset_f: 107 | dataset_json = json.load(dataset_f) 108 | 109 | dataset = dataset_json 110 | with open(prediction_file) as prediction_f: 111 | predictions = json.load(prediction_f) 112 | 113 | return evaluate(dataset, predictions) 114 | 115 | 116 | if __name__ == '__main__': 117 | expected_version = 'KorQuAD_v1.0' 118 | parser = argparse.ArgumentParser( 119 | description='Evaluation for KorQuAD ' + expected_version) 120 | parser.add_argument('--dataset_file', default="../../data/hotpot_dev_distractor_v1.json") 121 | parser.add_argument('--prediction_file',default="../../proposed_model_1019/predictions_14000.json") 122 | 123 | args = parser.parse_args() 124 | with open(args.dataset_file) as dataset_file: 125 | dataset_json = json.load(dataset_file) 126 | # read_version = "_".join(dataset_json['version'].split("_")[:-1]) 127 | # if (read_version != expected_version): 128 | # print('Evaluation expects ' + expected_version + 129 | # ', but got dataset with ' + read_version, 130 | # file=sys.stderr) 131 | dataset = dataset_json 132 | with open(args.prediction_file) as prediction_file: 133 | predictions = json.load(prediction_file) 134 | print(json.dumps(evaluate(dataset, predictions))) 135 | -------------------------------------------------------------------------------- /src/functions/hotpotqa_metric.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ujson as json 3 | import re 4 | import string 5 | from collections import Counter 6 | import pickle 7 | 8 | def normalize_answer(s): 9 | 10 | def remove_articles(text): 11 | return re.sub(r'\b(a|an|the)\b', ' ', text) 12 | 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | 26 | def f1_score(prediction, ground_truth): 27 | normalized_prediction = normalize_answer(prediction) 28 | normalized_ground_truth = normalize_answer(ground_truth) 29 | 30 | ZERO_METRIC = (0, 0, 0) 31 | 32 | if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 33 | return ZERO_METRIC 34 | if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth: 35 | return ZERO_METRIC 36 | 37 | prediction_tokens = normalized_prediction.split() 38 | ground_truth_tokens = normalized_ground_truth.split() 39 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 40 | num_same = 
sum(common.values()) 41 | if num_same == 0: 42 | return ZERO_METRIC 43 | precision = 1.0 * num_same / len(prediction_tokens) 44 | recall = 1.0 * num_same / len(ground_truth_tokens) 45 | f1 = (2 * precision * recall) / (precision + recall) 46 | return f1, precision, recall 47 | 48 | 49 | def exact_match_score(prediction, ground_truth): 50 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 51 | 52 | def update_answer(metrics, prediction, gold): 53 | em = exact_match_score(prediction, gold) 54 | f1, prec, recall = f1_score(prediction, gold) 55 | metrics['em'] += float(em) 56 | metrics['f1'] += f1 57 | metrics['prec'] += prec 58 | metrics['recall'] += recall 59 | return em, prec, recall 60 | 61 | def update_sp(metrics, prediction, gold): 62 | cur_sp_pred = set(map(tuple, prediction)) 63 | gold_sp_pred = set(map(tuple, gold)) 64 | tp, fp, fn = 0, 0, 0 65 | for e in cur_sp_pred: 66 | if e in gold_sp_pred: 67 | tp += 1 68 | else: 69 | fp += 1 70 | for e in gold_sp_pred: 71 | if e not in cur_sp_pred: 72 | fn += 1 73 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 74 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 75 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 76 | em = 1.0 if fp + fn == 0 else 0.0 77 | metrics['sp_em'] += em 78 | metrics['sp_f1'] += f1 79 | metrics['sp_prec'] += prec 80 | metrics['sp_recall'] += recall 81 | return em, prec, recall 82 | 83 | def eval(prediction_file, gold_file): 84 | with open(prediction_file) as f: 85 | prediction = json.load(f) 86 | prediction = {"answer": prediction, "sp": {}} 87 | with open(gold_file) as f: 88 | gold = json.load(f) 89 | 90 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 91 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0, 92 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0} 93 | for dp in gold: 94 | cur_id = dp['_id'] 95 | can_eval_joint = True 96 | 97 | if cur_id not in prediction['answer']: 98 | print('missing answer {}'.format(cur_id)) 99 | can_eval_joint = False 100 | else: 101 | em, prec, recall = update_answer( 102 | metrics, prediction['answer'][cur_id], dp['answer']) 103 | if cur_id not in prediction['sp']: 104 | #print('missing sp fact {}'.format(cur_id)) 105 | can_eval_joint = False 106 | else: 107 | sp_em, sp_prec, sp_recall = update_sp( 108 | metrics, prediction['sp'][cur_id], dp['supporting_facts']) 109 | 110 | if can_eval_joint: 111 | joint_prec = prec * sp_prec 112 | joint_recall = recall * sp_recall 113 | if joint_prec + joint_recall > 0: 114 | joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall) 115 | else: 116 | joint_f1 = 0. 
117 | joint_em = em * sp_em 118 | 119 | metrics['joint_em'] += joint_em 120 | metrics['joint_f1'] += joint_f1 121 | metrics['joint_prec'] += joint_prec 122 | metrics['joint_recall'] += joint_recall 123 | 124 | N = len(gold) 125 | for k in metrics.keys(): 126 | metrics[k] /= N 127 | 128 | print(metrics) 129 | 130 | if __name__ == '__main__': 131 | # eval(sys.argv[1], sys.argv[2]) 132 | eval("../../rnn_model__/predictions_.json", "../../all_data/hotpot_dev_distractor_v1.json") 133 | -------------------------------------------------------------------------------- /src/functions/processor_sent.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from functools import partial 5 | from multiprocessing import Pool, cpu_count 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | import nltk 10 | from transformers.file_utils import is_tf_available, is_torch_available 11 | from transformers.tokenization_bert import whitespace_tokenize 12 | from transformers.data.processors.utils import DataProcessor 13 | 14 | if is_torch_available(): 15 | import torch 16 | from torch.utils.data import TensorDataset 17 | max_sent_num = 0 18 | if is_tf_available(): 19 | import tensorflow as tf 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): 25 | """Returns tokenized answer spans that better match the annotated answer.""" 26 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 27 | 28 | for new_start in range(input_start, input_end + 1): 29 | for new_end in range(input_end, new_start - 1, -1): 30 | text_span = " ".join(doc_tokens[new_start: (new_end + 1)]) 31 | if text_span == tok_answer_text: 32 | return (new_start, new_end) 33 | 34 | return (input_start, input_end) 35 | 36 | 37 | def _check_is_max_context(doc_spans, cur_span_index, position): 38 | """Check if this is the 'max context' doc span for the token.""" 39 | best_score = None 40 | best_span_index = None 41 | for (span_index, doc_span) in enumerate(doc_spans): 42 | end = doc_span.start + doc_span.length - 1 43 | if position < doc_span.start: 44 | continue 45 | if position > end: 46 | continue 47 | num_left_context = position - doc_span.start 48 | num_right_context = end - position 49 | score = min(num_left_context, num_right_context) + 0.01 * doc_span.length 50 | if best_score is None or score > best_score: 51 | best_score = score 52 | best_span_index = span_index 53 | 54 | return cur_span_index == best_span_index 55 | 56 | 57 | def _new_check_is_max_context(doc_spans, cur_span_index, position): 58 | """Check if this is the 'max context' doc span for the token.""" 59 | # if len(doc_spans) == 1: 60 | # return True 61 | best_score = None 62 | best_span_index = None 63 | for (span_index, doc_span) in enumerate(doc_spans): 64 | end = doc_span["start"] + doc_span["length"] - 1 65 | if position < doc_span["start"]: 66 | continue 67 | if position > end: 68 | continue 69 | num_left_context = position - doc_span["start"] 70 | num_right_context = end - position 71 | score = min(num_left_context, num_right_context) + 0.01 * doc_span["length"] 72 | if best_score is None or score > best_score: 73 | best_score = score 74 | best_span_index = span_index 75 | 76 | return cur_span_index == best_span_index 77 | 78 | 79 | def _is_whitespace(c): 80 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 81 | return True 82 | return False 83 | 84 | 85 | def 
squad_convert_example_to_features(examples, max_seq_length, doc_stride, max_query_length, is_training):
86 |     features = []
87 |     refine_examples = []
88 |     for ex_id, example in enumerate(examples):
89 |         if is_training and not example.is_impossible:
90 |             # Get start and end position
91 |             start_position = example.start_position
92 |             end_position = example.end_position
93 | 
94 |             # If the answer cannot be found in the text, then skip this example.
95 |             actual_text = " ".join(example.doc_tokens[start_position: (end_position + 1)])
96 |             cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
97 |             if actual_text.find(cleaned_answer_text) == -1:
98 |                 logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
99 |                 return [], []
100 | 
101 |         tok_to_orig_index = []
102 |         orig_to_tok_index = []
103 |         all_doc_tokens = []
104 |         tok_to_sent_index = []
105 |         # doc_sent_tokens = []
106 |         #
107 |         # for sentence in example.doc_sentences:
108 |         #     doc_sent_tokens.append([])
109 |         #     for (i, token) in enumerate(sentence.split(' ')):
110 |         #         sub_tokens = tokenizer.tokenize(token)
111 |         #         # sub_tokens => the word split into WordPiece tokens
112 |         #         for sub_token in sub_tokens:
113 |         #             doc_sent_tokens[-1].append(sub_token)
114 |         # if is_training:
115 |         #     example.doc_sentences = None
116 |         # example.doc_sent_tokens = doc_sent_tokens
117 | 
118 |         refine_examples.append(example)
119 |         for (i, token) in enumerate(example.doc_tokens):
120 |             # doc_tokens ==> the document (context) as a list of whitespace-separated words
121 |             # token ==> one word
122 |             # i = word index
123 | 
124 |             orig_to_tok_index.append(len(all_doc_tokens))
125 |             # store the current length of all_doc_tokens, i.e. the position of this word's first sub-token
126 | 
127 |             sub_tokens = tokenizer.tokenize(token)
128 |             # sub_tokens => the word split into WordPiece tokens
129 |             for sub_token in sub_tokens:
130 |                 tok_to_orig_index.append(i)
131 |                 all_doc_tokens.append(sub_token)
132 |                 tok_to_sent_index.append(example.word_to_sent_offset[i])
133 | 
134 |         if is_training and not example.is_impossible:
135 |             tok_start_position = orig_to_tok_index[example.start_position]
136 |             if example.end_position < len(example.doc_tokens) - 1:
137 |                 tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
138 |             else:
139 |                 tok_end_position = len(all_doc_tokens) - 1
140 | 
141 |             (tok_start_position, tok_end_position) = _improve_answer_span(
142 |                 all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text
143 |             )
144 | 
145 |         spans = []
146 | 
147 |         truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
148 |         sequence_added_tokens = (
149 |             tokenizer.max_len - tokenizer.max_len_single_sentence + 1
150 |             if "roberta" in str(type(tokenizer)) or "camembert" in str(type(tokenizer))
151 |             else tokenizer.max_len - tokenizer.max_len_single_sentence
152 |         )
153 |         sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
154 |         # []
155 |         span_doc_tokens = all_doc_tokens
156 |         while len(spans) * doc_stride < len(all_doc_tokens):
157 | 
158 |             encoded_dict = tokenizer.encode_plus(
159 |                 truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
160 |                 span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
161 |                 max_length=max_seq_length,
162 |                 return_overflowing_tokens=True,
163 |                 pad_to_max_length=True,
164 |                 stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
165 |                 truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
166 |                 return_token_type_ids=True,
167 |             )
168 | 
169 |             paragraph_len = min(
len(all_doc_tokens) - len(spans) * doc_stride, 171 | max_seq_length - len(truncated_query) - sequence_pair_added_tokens, 172 | ) 173 | 174 | if tokenizer.pad_token_id in encoded_dict["input_ids"]: 175 | if tokenizer.padding_side == "right": 176 | non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)] 177 | else: 178 | last_padding_id_position = ( 179 | len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id) 180 | ) 181 | non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1:] 182 | 183 | else: 184 | non_padded_ids = encoded_dict["input_ids"] 185 | 186 | tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) 187 | 188 | token_to_orig_map = {} 189 | cur_sent_to_orig_sent_map = {} 190 | sent_mask = [0]*(len(truncated_query) + sequence_added_tokens) 191 | sent_offset = tok_to_sent_index[len(spans) * doc_stride] 192 | 193 | cur_sent_range = [[] for _ in range(40)] 194 | cur_sent_range[0] = [e for e in range(len(sent_mask))] 195 | for i in range(paragraph_len): 196 | cur_sent_num = tok_to_sent_index[len(spans) * doc_stride + i] - sent_offset + 1 197 | orig_sent_num = tok_to_sent_index[len(spans) * doc_stride + i] 198 | 199 | index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i 200 | token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i] 201 | 202 | sent_mask.append(cur_sent_num) 203 | cur_sent_range[cur_sent_num].append(len(sent_mask)-1) 204 | cur_sent_to_orig_sent_map[cur_sent_num] = orig_sent_num 205 | encoded_dict["paragraph_len"] = paragraph_len 206 | encoded_dict["tokens"] = tokens 207 | encoded_dict["token_to_orig_map"] = token_to_orig_map 208 | encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens 209 | encoded_dict["token_is_max_context"] = {} 210 | encoded_dict["start"] = len(spans) * doc_stride 211 | encoded_dict["length"] = paragraph_len 212 | 213 | encoded_dict["sent_mask"] = sent_mask + [0]*(max_seq_length-len(sent_mask)) 214 | encoded_dict["cur_sent_to_orig_sent"] = cur_sent_to_orig_sent_map 215 | encoded_dict["example_id"] = ex_id 216 | encoded_dict["truncated_query"] = truncated_query 217 | encoded_dict["cur_sent_range"] = cur_sent_range 218 | spans.append(encoded_dict) 219 | 220 | if "overflowing_tokens" not in encoded_dict: 221 | break 222 | span_doc_tokens = encoded_dict["overflowing_tokens"] 223 | 224 | for doc_span_index in range(len(spans)): 225 | for j in range(spans[doc_span_index]["paragraph_len"]): 226 | is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j) 227 | index = ( 228 | j 229 | if tokenizer.padding_side == "left" 230 | else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j 231 | ) 232 | spans[doc_span_index]["token_is_max_context"][index] = is_max_context 233 | 234 | for span in spans: 235 | # Identify the position of the CLS token 236 | cls_index = span["input_ids"].index(tokenizer.cls_token_id) 237 | 238 | 239 | 240 | pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id) 241 | special_token_indices = np.asarray( 242 | tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True) 243 | ).nonzero() 244 | 245 | 246 | 247 | span_is_impossible = example.is_impossible 248 | start_position = 0 249 | end_position = 0 250 | 251 | if is_training and not span_is_impossible: 252 | # For training, if our document chunk does not contain an 
annotation 253 | # we throw it out, since there is nothing to predict. 254 | doc_start = span["start"] 255 | doc_end = span["start"] + span["length"] - 1 256 | out_of_span = False 257 | 258 | if not (tok_start_position >= doc_start and tok_end_position <= doc_end): 259 | out_of_span = True 260 | 261 | if out_of_span: 262 | start_position = cls_index 263 | end_position = cls_index 264 | span_is_impossible = True 265 | else: 266 | if tokenizer.padding_side == "left": 267 | doc_offset = 0 268 | else: 269 | doc_offset = len(truncated_query) + sequence_added_tokens 270 | 271 | start_position = tok_start_position - doc_start + doc_offset 272 | end_position = tok_end_position - doc_start + doc_offset 273 | features.append( 274 | SquadFeatures( 275 | span["input_ids"], 276 | span["attention_mask"], 277 | span["token_type_ids"], 278 | span["cur_sent_range"], 279 | cls_index, 280 | example_index=0, 281 | # Can not set unique_id and example_index here. They will be set after multiple processing. 282 | unique_id=0, 283 | paragraph_len=span["paragraph_len"], 284 | token_is_max_context=span["token_is_max_context"], 285 | tokens=span["tokens"], 286 | 287 | token_to_orig_map=span["token_to_orig_map"], 288 | start_position=start_position, 289 | end_position=end_position, 290 | sent_mask = span["sent_mask"], 291 | cur_sent_to_orig_sent= span["cur_sent_to_orig_sent"], 292 | is_impossible=span_is_impossible, 293 | qas_id=example.qas_id, 294 | example_id = encoded_dict["example_id"], 295 | truncated_query=span['truncated_query'], 296 | question_type=example.q_type 297 | ) 298 | ) 299 | return refine_examples, features 300 | 301 | 302 | def squad_convert_example_to_features_init(tokenizer_for_convert): 303 | global tokenizer 304 | tokenizer = tokenizer_for_convert 305 | 306 | 307 | def squad_convert_examples_to_features( 308 | examples, 309 | tokenizer, 310 | max_seq_length, 311 | doc_stride, 312 | max_query_length, 313 | is_training, 314 | return_dataset=False, 315 | threads=1, 316 | tqdm_enabled=True, 317 | ): 318 | """ 319 | Converts a list of examples into a list of features that can be directly given as input to a model. 320 | It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs. 321 | 322 | Args: 323 | examples: list of :class:`~transformers.data.processors.squad.SquadExample` 324 | tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer` 325 | max_seq_length: The maximum sequence length of the inputs. 326 | doc_stride: The stride used when the context is too large and is split across several features. 327 | max_query_length: The maximum length of the query. 328 | is_training: whether to create features for model evaluation or model training. 329 | return_dataset: Default False. Either 'pt' or 'tf'. 
330 |             if 'pt': returns a torch.data.TensorDataset,
331 |             if 'tf': returns a tf.data.Dataset
332 |         threads: multiple processing threads
333 | 
334 | 
335 |     Returns:
336 |         list of :class:`~transformers.data.processors.squad.SquadFeatures`
337 | 
338 |     Example::
339 | 
340 |         processor = SquadV2Processor()
341 |         examples = processor.get_dev_examples(data_dir)
342 | 
343 |         features = squad_convert_examples_to_features(
344 |             examples=examples,
345 |             tokenizer=tokenizer,
346 |             max_seq_length=args.max_seq_length,
347 |             doc_stride=args.doc_stride,
348 |             max_query_length=args.max_query_length,
349 |             is_training=not evaluate,
350 |         )
351 |     """
352 | 
353 |     # Defining helper methods
354 |     features = []
355 |     threads = min(threads, cpu_count())
356 |     with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p:
357 |         annotate_ = partial(
358 |             squad_convert_example_to_features,
359 |             max_seq_length=max_seq_length,
360 |             doc_stride=doc_stride,
361 |             max_query_length=max_query_length,
362 |             is_training=is_training,
363 |         )
364 |         features = list(
365 |             tqdm(
366 |                 p.imap(annotate_, examples, chunksize=32),
367 |                 total=len(examples),
368 |                 desc="convert squad examples to features",
369 |                 disable=not tqdm_enabled,
370 |             )
371 |         )
372 |     refine_examples = []
373 |     new_features = []
374 |     unique_id = 1000000000
375 | 
376 |     example_index = 0
377 |     for example_features in tqdm(
378 |         features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled
379 |     ):
380 |         example, example_features = example_features
381 | 
382 |         if not example_features:
383 |             continue
384 |         refine_examples.append(example)
385 | 
386 |         new_feature = []
387 |         for example_feature in example_features:
388 |             example_feature.example_index = example_index
389 |             example_feature.unique_id = unique_id
390 |             new_feature.append(example_feature)
391 |             unique_id += 1
392 |         example_index += 1
393 | 
394 |         new_features.append(new_feature)
395 |     features = new_features
396 |     del new_features
397 |     global max_sent_num
398 |     print(max_sent_num)
399 |     if return_dataset == "pt":
400 |         if not is_torch_available():
401 |             raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.")
402 |     return refine_examples, features
403 | 
404 | 
405 | class SquadProcessor(DataProcessor):
406 |     """
407 |     Processor for the SQuAD data set.
408 |     Overridden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and version 2.0 of SQuAD, respectively.
409 | """ 410 | 411 | train_file = None 412 | dev_file = None 413 | 414 | def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False): 415 | if not evaluate: 416 | answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8") 417 | answer_start = tensor_dict["answers"]["answer_start"][0].numpy() 418 | answers = [] 419 | else: 420 | answers = [ 421 | {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")} 422 | for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"]) 423 | ] 424 | 425 | answer = None 426 | answer_start = None 427 | 428 | return SquadExample( 429 | qas_id=tensor_dict["id"].numpy().decode("utf-8"), 430 | question_text=tensor_dict["question"].numpy().decode("utf-8"), 431 | context_text=tensor_dict["context"].numpy().decode("utf-8"), 432 | answer_text=answer, 433 | start_position_character=answer_start, 434 | title=tensor_dict["title"].numpy().decode("utf-8"), 435 | answers=answers, 436 | ) 437 | 438 | def get_examples_from_dataset(self, dataset, evaluate=False): 439 | """ 440 | Creates a list of :class:`~transformers.data.processors.squad.SquadExample` using a TFDS dataset. 441 | 442 | Args: 443 | dataset: The tfds dataset loaded from `tensorflow_datasets.load("squad")` 444 | evaluate: boolean specifying if in evaluation mode or in training mode 445 | 446 | Returns: 447 | List of SquadExample 448 | 449 | Examples:: 450 | 451 | import tensorflow_datasets as tfds 452 | dataset = tfds.load("squad") 453 | 454 | training_examples = get_examples_from_dataset(dataset, evaluate=False) 455 | evaluation_examples = get_examples_from_dataset(dataset, evaluate=True) 456 | """ 457 | 458 | if evaluate: 459 | dataset = dataset["validation"] 460 | else: 461 | dataset = dataset["train"] 462 | 463 | examples = [] 464 | for tensor_dict in tqdm(dataset): 465 | examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate)) 466 | 467 | return examples 468 | 469 | def get_train_examples(self, data_dir, filename=None, tokenizer = None): 470 | """ 471 | Returns the training examples from the data directory. 472 | 473 | Args: 474 | data_dir: Directory containing the data files used for training and evaluating. 475 | filename: None by default, specify this if the training file has a different name than the original one 476 | which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. 477 | 478 | """ 479 | if data_dir is None: 480 | data_dir = "" 481 | 482 | if self.train_file is None: 483 | raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") 484 | 485 | with open( 486 | os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8" 487 | ) as reader: 488 | input_data = json.load(reader) 489 | return self._create_examples(input_data, 'train', tokenizer) 490 | 491 | def get_dev_examples(self, data_dir, filename=None, tokenizer = None): 492 | """ 493 | Returns the evaluation example from the data directory. 494 | 495 | Args: 496 | data_dir: Directory containing the data files used for training and evaluating. 497 | filename: None by default, specify this if the evaluation file has a different name than the original one 498 | which is `train-v1.1.json` and `train-v2.0.json` for squad versions 1.1 and 2.0 respectively. 
499 | """ 500 | if data_dir is None: 501 | data_dir = "" 502 | 503 | if self.dev_file is None: 504 | raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor") 505 | 506 | with open( 507 | os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8" 508 | ) as reader: 509 | input_data = json.load(reader) 510 | return self._create_examples(input_data, "dev", tokenizer) 511 | 512 | def example_from_input(self, question, context): 513 | return SquadExample( 514 | qas_id="sample", 515 | question_text=question, 516 | context_text=context, 517 | answer_text=None, 518 | start_position_character=None, 519 | title="sample", 520 | is_impossible=False, 521 | answers=[], 522 | ) 523 | 524 | def get_example_from_input(self, input_dictionary): 525 | # context, question, id, title 526 | context_text = input_dictionary["context"] 527 | question_text = input_dictionary["question"] 528 | qas_id = input_dictionary["id"] 529 | start_position_character = None 530 | is_impossible = False 531 | answer_text = None 532 | answers = [] 533 | 534 | examples = [SquadExample( 535 | qas_id=qas_id, 536 | question_text=question_text, 537 | context_text=context_text, 538 | answer_text=answer_text, 539 | start_position_character=start_position_character, 540 | title="", 541 | is_impossible=is_impossible, 542 | answers=answers, 543 | )] 544 | return examples 545 | 546 | def _create_examples(self, input_data, set_type, tokenizer): 547 | is_training = set_type == "train" 548 | num = 0 549 | examples = [] 550 | for entry in tqdm(input_data): 551 | qas_id = entry["_id"] 552 | question_text = entry["question"] 553 | level = entry["level"] 554 | question_type = entry["type"] 555 | if 'question_type' in entry.keys(): 556 | q_type = entry['question_type'] 557 | else: 558 | q_type = None 559 | data_examples = [] 560 | support_facts = {e[0]:e[1] for e in entry["supporting_facts"]} 561 | 562 | for context in entry["context"]: 563 | start_position_character = None 564 | answer_text = None 565 | answers = [] 566 | title = context[0] 567 | is_impossible = False if context[2] > 0 else True 568 | 569 | 570 | if is_training: 571 | answer_text = entry["answer"] 572 | start_position_character = context[2] 573 | else: 574 | answer_text = entry["answer"] 575 | doc_sentences = context[1] 576 | if title in support_facts.keys(): 577 | support_fact = [1 for e in range(len(doc_sentences))] 578 | for e in range(len(doc_sentences)): 579 | for t, idx, in entry["supporting_facts"]: 580 | if t == title and e == idx: 581 | support_fact[idx] = 2 582 | else: 583 | support_fact = [1 for e in range(len(doc_sentences))] 584 | context_text = ''.join(doc_sentences) 585 | example = SquadExample( 586 | qas_id=qas_id, 587 | question_text=question_text, 588 | context_text=context_text, 589 | q_type = q_type, 590 | doc_sentences=doc_sentences, 591 | support_fact=support_fact, 592 | answer_text=answer_text, 593 | start_position_character=start_position_character, 594 | title=title, 595 | is_impossible=is_impossible, 596 | answers=answers, 597 | level=level, 598 | question_type=question_type, 599 | tokenizer = tokenizer 600 | ) 601 | data_examples.append(example) 602 | 603 | examples.append(data_examples) 604 | if len(examples) > 50000: 605 | break 606 | return examples 607 | 608 | 609 | class SquadV1Processor(SquadProcessor): 610 | train_file = "train-v1.1.json" 611 | dev_file = "dev-v1.1.json" 612 | 613 | 614 | class SquadV2Processor(SquadProcessor): 615 | train_file = "train-v2.0.json" 616 | 
dev_file = "dev-v2.0.json" 617 | 618 | 619 | class SquadExample(object): 620 | """ 621 | A single training/test example for the Squad dataset, as loaded from disk. 622 | 623 | Args: 624 | qas_id: The example's unique identifier 625 | question_text: The question string 626 | context_text: The context string 627 | answer_text: The answer string 628 | start_position_character: The character position of the start of the answer 629 | title: The title of the example 630 | answers: None by default, this is used during evaluation. Holds answers as well as their start positions. 631 | is_impossible: False by default, set to True if the example has no possible answer. 632 | """ 633 | 634 | def __init__( 635 | self, 636 | qas_id, 637 | question_text, 638 | context_text, 639 | doc_sentences, 640 | q_type, 641 | answer_text, 642 | support_fact, 643 | start_position_character, 644 | title, 645 | level, 646 | question_type, 647 | answers=[], 648 | is_impossible=False, 649 | tokenizer=None 650 | ): 651 | self.qas_id = qas_id 652 | self.q_type = q_type 653 | self.question_text = question_text 654 | 655 | self.level = level 656 | self.question_type = question_type 657 | self.answer_text = answer_text 658 | self.title = title 659 | self.support_fact = support_fact 660 | self.is_impossible = is_impossible 661 | self.answers = answers 662 | self.doc_sentences = doc_sentences 663 | self.doc_sent_tokens = None 664 | self.start_position, self.end_position = 0, 0 665 | 666 | doc_tokens = [] 667 | char_to_word_offset = [] 668 | 669 | if q_type == 'yn': 670 | if answer_text == 'yes': 671 | self.q_type = 0 672 | else: 673 | self.q_type = 1 674 | else: 675 | self.q_type = 2 676 | # Split on whitespace so that different tokens may be attributed to their original position. 677 | prev_is_whitespace = True 678 | for sent_num in range(len(doc_sentences)): 679 | 680 | for c_idx, c in enumerate(doc_sentences[sent_num]): 681 | if _is_whitespace(c): 682 | prev_is_whitespace = True 683 | if c_idx == 0: 684 | char_to_word_offset.append(len(doc_tokens)) 685 | else: 686 | char_to_word_offset.append(len(doc_tokens)-1) 687 | else: 688 | if prev_is_whitespace: 689 | doc_tokens.append(c) 690 | else: 691 | doc_tokens[-1] += c 692 | prev_is_whitespace = False 693 | char_to_word_offset.append(len(doc_tokens) - 1) 694 | 695 | self.doc_tokens = doc_tokens 696 | char_to_word_offset = char_to_word_offset 697 | char_to_sent_offset = [] 698 | 699 | for sent_id, sentence in enumerate(doc_sentences): 700 | char_to_sent_offset += [sent_id] * len(sentence) 701 | 702 | self.word_to_sent_offset = {char_to_word_offset[e]: char_to_sent_offset[e] for e in 703 | range(len(char_to_word_offset))} 704 | # Start and end positions only has a value during evaluation. 705 | if start_position_character is not None and not is_impossible: 706 | self.start_position = char_to_word_offset[start_position_character] 707 | self.end_position = char_to_word_offset[ 708 | min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1) 709 | ] 710 | 711 | 712 | class SquadFeatures(object): 713 | """ 714 | Single squad example features to be fed to a model. 715 | Those features are model-specific and can be crafted from :class:`~transformers.data.processors.squad.SquadExample` 716 | using the :method:`~transformers.data.processors.squad.squad_convert_examples_to_features` method. 717 | 718 | Args: 719 | input_ids: Indices of input sequence tokens in the vocabulary. 720 | attention_mask: Mask to avoid performing attention on padding token indices. 
721 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 722 | cls_index: the index of the CLS token. 723 | p_mask: Mask identifying tokens that can be answers vs. tokens that cannot. 724 | Mask with 1 for tokens than cannot be in the answer and 0 for token that can be in an answer 725 | example_index: the index of the example 726 | unique_id: The unique Feature identifier 727 | paragraph_len: The length of the context 728 | token_is_max_context: List of booleans identifying which tokens have their maximum context in this feature object. 729 | If a token does not have their maximum context in this feature object, it means that another feature object 730 | has more information related to that token and should be prioritized over this feature for that token. 731 | tokens: list of tokens corresponding to the input ids 732 | token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer. 733 | start_position: start of the answer token index 734 | end_position: end of the answer token index 735 | """ 736 | 737 | def __init__( 738 | self, 739 | input_ids, 740 | attention_mask, 741 | token_type_ids, 742 | cur_sent_range, 743 | cls_index, 744 | 745 | example_index, 746 | unique_id, 747 | paragraph_len, 748 | token_is_max_context, 749 | tokens, 750 | 751 | token_to_orig_map, 752 | start_position, 753 | end_position, 754 | is_impossible, 755 | sent_mask, 756 | cur_sent_to_orig_sent, 757 | qas_id: str = None, 758 | example_id:int = 0, 759 | truncated_query='', 760 | question_type=None 761 | ): 762 | self.input_ids = input_ids 763 | self.truncated_query = truncated_query 764 | self.attention_mask = attention_mask 765 | self.token_type_ids = token_type_ids 766 | self.cls_index = cls_index 767 | self.cur_sent_range = cur_sent_range 768 | 769 | self.sent_mask = sent_mask 770 | self.question_type = question_type 771 | self.cur_sent_to_orig_sent = cur_sent_to_orig_sent 772 | self.example_index = example_index 773 | self.unique_id = unique_id 774 | self.example_id = example_id 775 | self.paragraph_len = paragraph_len 776 | self.token_is_max_context = token_is_max_context 777 | self.tokens = tokens 778 | self.token_to_orig_map = token_to_orig_map 779 | 780 | self.start_position = start_position 781 | self.end_position = end_position 782 | self.is_impossible = is_impossible 783 | self.qas_id = qas_id 784 | 785 | 786 | class SquadResult(object): 787 | """ 788 | Constructs a SquadResult which can be used to evaluate a model's output on the SQuAD dataset. 789 | 790 | Args: 791 | unique_id: The unique identifier corresponding to that example. 
792 | start_logits: The logits corresponding to the start of the answer 793 | end_logits: The logits corresponding to the end of the answer 794 | """ 795 | 796 | def __init__(self, unique_id, start_logits, end_logits, evidence=None, start_top_index=None, end_top_index=None, cls_logits=None): 797 | self.start_logits = start_logits 798 | self.end_logits = end_logits 799 | self.unique_id = unique_id 800 | self.evidence = evidence 801 | if start_top_index: 802 | self.start_top_index = start_top_index 803 | self.end_top_index = end_top_index 804 | self.cls_logits = cls_logits 805 | -------------------------------------------------------------------------------- /src/functions/squad_metric.py: -------------------------------------------------------------------------------- 1 | """ Very heavily inspired by the official evaluation script for SQuAD version 2.0 which was 2 | modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0 3 | 4 | In addition to basic functionality, we also compute additional statistics and 5 | plot precision-recall curves if an additional na_prob.json file is provided. 6 | This file is expected to map question ID's to the model's predicted probability 7 | that a question is unanswerable. 8 | """ 9 | 10 | 11 | import collections 12 | import json 13 | import logging 14 | import math 15 | import re 16 | import string 17 | 18 | from transformers.tokenization_bert import BasicTokenizer 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def normalize_answer(s): 25 | """Lower text and remove punctuation, articles and extra whitespace.""" 26 | 27 | def remove_articles(text): 28 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 29 | return re.sub(regex, " ", text) 30 | 31 | def white_space_fix(text): 32 | return " ".join(text.split()) 33 | 34 | def remove_punc(text): 35 | exclude = set(string.punctuation) 36 | return "".join(ch for ch in text if ch not in exclude) 37 | 38 | def lower(text): 39 | return text.lower() 40 | 41 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 42 | 43 | 44 | def get_tokens(s): 45 | if not s: 46 | return [] 47 | return normalize_answer(s).split() 48 | 49 | 50 | def compute_exact(a_gold, a_pred): 51 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 52 | 53 | 54 | def compute_f1(a_gold, a_pred): 55 | gold_toks = get_tokens(a_gold) 56 | pred_toks = get_tokens(a_pred) 57 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 58 | num_same = sum(common.values()) 59 | if len(gold_toks) == 0 or len(pred_toks) == 0: 60 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 61 | return int(gold_toks == pred_toks) 62 | if num_same == 0: 63 | return 0 64 | precision = 1.0 * num_same / len(pred_toks) 65 | recall = 1.0 * num_same / len(gold_toks) 66 | f1 = (2 * precision * recall) / (precision + recall) 67 | return f1 68 | 69 | 70 | def get_raw_scores(examples, preds): 71 | """ 72 | Computes the exact and f1 scores from the examples and the model predictions 73 | """ 74 | exact_scores = {} 75 | f1_scores = {} 76 | 77 | for example in examples: 78 | qas_id = example.qas_id 79 | gold_answers = [answer["text"] for answer in example.answers if normalize_answer(answer["text"])] 80 | 81 | if not gold_answers: 82 | # For unanswerable questions, only correct answer is empty string 83 | gold_answers = [""] 84 | 85 | if qas_id not in preds: 86 | print("Missing prediction for %s" % qas_id) 87 | continue 88 | 89 | prediction = preds[qas_id] 90 | exact_scores[qas_id] = 
max(compute_exact(a, prediction) for a in gold_answers) 91 | f1_scores[qas_id] = max(compute_f1(a, prediction) for a in gold_answers) 92 | 93 | return exact_scores, f1_scores 94 | 95 | 96 | def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh): 97 | new_scores = {} 98 | for qid, s in scores.items(): 99 | pred_na = na_probs[qid] > na_prob_thresh 100 | if pred_na: 101 | new_scores[qid] = float(not qid_to_has_ans[qid]) 102 | else: 103 | new_scores[qid] = s 104 | return new_scores 105 | 106 | 107 | def make_eval_dict(exact_scores, f1_scores, qid_list=None): 108 | if not qid_list: 109 | total = len(exact_scores) 110 | return collections.OrderedDict( 111 | [ 112 | ("exact", 100.0 * sum(exact_scores.values()) / total), 113 | ("f1", 100.0 * sum(f1_scores.values()) / total), 114 | ("total", total), 115 | ] 116 | ) 117 | else: 118 | total = len(qid_list) 119 | return collections.OrderedDict( 120 | [ 121 | ("exact", 100.0 * sum(exact_scores[k] for k in qid_list) / total), 122 | ("f1", 100.0 * sum(f1_scores[k] for k in qid_list) / total), 123 | ("total", total), 124 | ] 125 | ) 126 | 127 | 128 | def merge_eval(main_eval, new_eval, prefix): 129 | for k in new_eval: 130 | main_eval["%s_%s" % (prefix, k)] = new_eval[k] 131 | 132 | 133 | def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): 134 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 135 | cur_score = num_no_ans 136 | best_score = cur_score 137 | best_thresh = 0.0 138 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 139 | for i, qid in enumerate(qid_list): 140 | if qid not in scores: 141 | continue 142 | if qid_to_has_ans[qid]: 143 | diff = scores[qid] 144 | else: 145 | if preds[qid]: 146 | diff = -1 147 | else: 148 | diff = 0 149 | cur_score += diff 150 | if cur_score > best_score: 151 | best_score = cur_score 152 | best_thresh = na_probs[qid] 153 | 154 | has_ans_score, has_ans_cnt = 0, 0 155 | for qid in qid_list: 156 | if not qid_to_has_ans[qid]: 157 | continue 158 | has_ans_cnt += 1 159 | 160 | if qid not in scores: 161 | continue 162 | has_ans_score += scores[qid] 163 | 164 | return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt 165 | 166 | 167 | def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 168 | best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans) 169 | best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans) 170 | main_eval["best_exact"] = best_exact 171 | main_eval["best_exact_thresh"] = exact_thresh 172 | main_eval["best_f1"] = best_f1 173 | main_eval["best_f1_thresh"] = f1_thresh 174 | main_eval["has_ans_exact"] = has_ans_exact 175 | main_eval["has_ans_f1"] = has_ans_f1 176 | 177 | 178 | def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): 179 | num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) 180 | cur_score = num_no_ans 181 | best_score = cur_score 182 | best_thresh = 0.0 183 | qid_list = sorted(na_probs, key=lambda k: na_probs[k]) 184 | for _, qid in enumerate(qid_list): 185 | if qid not in scores: 186 | continue 187 | if qid_to_has_ans[qid]: 188 | diff = scores[qid] 189 | else: 190 | if preds[qid]: 191 | diff = -1 192 | else: 193 | diff = 0 194 | cur_score += diff 195 | if cur_score > best_score: 196 | best_score = cur_score 197 | best_thresh = na_probs[qid] 198 | return 100.0 * best_score / len(scores), best_thresh 199 | 200 | 201 | def find_all_best_thresh(main_eval, 
preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): 202 | best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) 203 | best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) 204 | 205 | main_eval["best_exact"] = best_exact 206 | main_eval["best_exact_thresh"] = exact_thresh 207 | main_eval["best_f1"] = best_f1 208 | main_eval["best_f1_thresh"] = f1_thresh 209 | 210 | 211 | def squad_evaluate(examples, preds, no_answer_probs=None, no_answer_probability_threshold=1.0): 212 | qas_id_to_has_answer = {example.qas_id: bool(example.answers) for example in examples} 213 | has_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if has_answer] 214 | no_answer_qids = [qas_id for qas_id, has_answer in qas_id_to_has_answer.items() if not has_answer] 215 | 216 | if no_answer_probs is None: 217 | no_answer_probs = {k: 0.0 for k in preds} 218 | 219 | exact, f1 = get_raw_scores(examples, preds) 220 | 221 | exact_threshold = apply_no_ans_threshold( 222 | exact, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold 223 | ) 224 | f1_threshold = apply_no_ans_threshold(f1, no_answer_probs, qas_id_to_has_answer, no_answer_probability_threshold) 225 | 226 | evaluation = make_eval_dict(exact_threshold, f1_threshold) 227 | 228 | if has_answer_qids: 229 | has_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=has_answer_qids) 230 | merge_eval(evaluation, has_ans_eval, "HasAns") 231 | 232 | if no_answer_qids: 233 | no_ans_eval = make_eval_dict(exact_threshold, f1_threshold, qid_list=no_answer_qids) 234 | merge_eval(evaluation, no_ans_eval, "NoAns") 235 | 236 | if no_answer_probs: 237 | find_all_best_thresh(evaluation, preds, exact, f1, no_answer_probs, qas_id_to_has_answer) 238 | 239 | return evaluation 240 | 241 | 242 | def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): 243 | """Project the tokenized prediction back to the original text.""" 244 | 245 | # When we created the data, we kept track of the alignment between original 246 | # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So 247 | # now `orig_text` contains the span of our original text corresponding to the 248 | # span that we predicted. 249 | # 250 | # However, `orig_text` may contain extra characters that we don't want in 251 | # our prediction. 252 | # 253 | # For example, let's say: 254 | # pred_text = steve smith 255 | # orig_text = Steve Smith's 256 | # 257 | # We don't want to return `orig_text` because it contains the extra "'s". 258 | # 259 | # We don't want to return `pred_text` because it's already been normalized 260 | # (the SQuAD eval script also does punctuation stripping/lower casing but 261 | # our tokenizer does additional normalization like stripping accent 262 | # characters). 263 | # 264 | # What we really want to return is "Steve Smith". 265 | # 266 | # Therefore, we have to apply a semi-complicated alignment heuristic between 267 | # `pred_text` and `orig_text` to get a character-to-character alignment. This 268 | # can fail in certain cases in which case we just return `orig_text`. 
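    # Illustrative walk-through (assumed values, following the "Steve Smith"
    # case above with do_lower_case=True):
    #   orig_text = "Steve Smith's"   -> spaces stripped -> "SteveSmith's"
    #   tok_text  = "steve smith ' s" -> spaces stripped -> "stevesmith's"
    # Both space-stripped strings have the same length, so character i of one
    # aligns with character i of the other; the span found for pred_text inside
    # tok_text is then projected through the ns_to_s_map dictionaries built by
    # _strip_spaces below back onto orig_text.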
269 | 270 | def _strip_spaces(text): 271 | ns_chars = [] 272 | ns_to_s_map = collections.OrderedDict() 273 | for (i, c) in enumerate(text): 274 | if c == " ": 275 | continue 276 | ns_to_s_map[len(ns_chars)] = i 277 | ns_chars.append(c) 278 | ns_text = "".join(ns_chars) 279 | return (ns_text, ns_to_s_map) 280 | 281 | # We first tokenize `orig_text`, strip whitespace from the result 282 | # and `pred_text`, and check if they are the same length. If they are 283 | # NOT the same length, the heuristic has failed. If they are the same 284 | # length, we assume the characters are one-to-one aligned. 285 | tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 286 | 287 | tok_text = " ".join(tokenizer.tokenize(orig_text)) 288 | 289 | start_position = tok_text.find(pred_text) 290 | if start_position == -1: 291 | if verbose_logging: 292 | logger.info("Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) 293 | return orig_text 294 | end_position = start_position + len(pred_text) - 1 295 | 296 | (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text) 297 | (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) 298 | 299 | if len(orig_ns_text) != len(tok_ns_text): 300 | if verbose_logging: 301 | logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) 302 | return orig_text 303 | 304 | # We then project the characters in `pred_text` back to `orig_text` using 305 | # the character-to-character alignment. 306 | tok_s_to_ns_map = {} 307 | for (i, tok_index) in tok_ns_to_s_map.items(): 308 | tok_s_to_ns_map[tok_index] = i 309 | 310 | orig_start_position = None 311 | if start_position in tok_s_to_ns_map: 312 | ns_start_position = tok_s_to_ns_map[start_position] 313 | if ns_start_position in orig_ns_to_s_map: 314 | orig_start_position = orig_ns_to_s_map[ns_start_position] 315 | 316 | if orig_start_position is None: 317 | if verbose_logging: 318 | logger.info("Couldn't map start position") 319 | return orig_text 320 | 321 | orig_end_position = None 322 | if end_position in tok_s_to_ns_map: 323 | ns_end_position = tok_s_to_ns_map[end_position] 324 | if ns_end_position in orig_ns_to_s_map: 325 | orig_end_position = orig_ns_to_s_map[ns_end_position] 326 | 327 | if orig_end_position is None: 328 | if verbose_logging: 329 | logger.info("Couldn't map end position") 330 | return orig_text 331 | 332 | output_text = orig_text[orig_start_position : (orig_end_position + 1)] 333 | return output_text 334 | 335 | 336 | def _get_best_indexes(logits, n_best_size): 337 | """Get the n-best logits from a list.""" 338 | index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True) 339 | 340 | best_indexes = [] 341 | for i in range(len(index_and_score)): 342 | if i >= n_best_size: 343 | break 344 | best_indexes.append(index_and_score[i][0]) 345 | return best_indexes 346 | 347 | 348 | def _compute_softmax(scores): 349 | """Compute softmax probability over raw logits.""" 350 | if not scores: 351 | return [] 352 | 353 | max_score = None 354 | for score in scores: 355 | if max_score is None or score > max_score: 356 | max_score = score 357 | 358 | exp_scores = [] 359 | total_sum = 0.0 360 | for score in scores: 361 | x = math.exp(score - max_score) 362 | exp_scores.append(x) 363 | total_sum += x 364 | 365 | probs = [] 366 | for score in exp_scores: 367 | probs.append(score / total_sum) 368 | return probs 369 | def restore_prediction_2(results, features, n_best_size, tokenizer, max_answer_length=30): 370 | 371 | _PrelimPrediction = collections.namedtuple( # pylint: 
disable=invalid-name 372 | "PrelimPrediction", ["tokens", "start_index", "end_index", "score"] 373 | ) 374 | 375 | prelim_predictions = [] 376 | for idx, result in enumerate(results): 377 | 378 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 379 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 380 | feature = features[idx].tokens 381 | sep_position = feature.index('[SEP]')+1 382 | for start_index in start_indexes: 383 | for end_index in end_indexes: 384 | if start_index < sep_position or end_index < sep_position: 385 | continue 386 | if start_index >= len(feature): 387 | continue 388 | if end_index >= len(feature): 389 | continue 390 | if end_index < start_index: 391 | continue 392 | length = end_index - start_index 393 | if length > max_answer_length: 394 | continue 395 | prelim_predictions.append( 396 | _PrelimPrediction( 397 | tokens = feature, 398 | start_index=start_index, 399 | end_index=end_index, 400 | score=result.start_logits[start_index]+result.end_logits[end_index] 401 | ) 402 | ) 403 | prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.score), reverse=True) 404 | if prelim_predictions: 405 | pred = prelim_predictions[0] 406 | score = pred.score 407 | tok_tokens = pred.tokens[pred.start_index: (pred.end_index + 1)] 408 | tok_text = tokenizer.convert_tokens_to_string(tok_tokens) 409 | tok_text = tok_text.strip() 410 | tok_text = " ".join(tok_text.split()) 411 | else: 412 | tok_text = '' 413 | score = 0 414 | pred = tok_text 415 | 416 | 417 | return pred, score 418 | def restore_prediction(example, features, results, n_best_size, do_lower_case, verbose_logging, tokenizer): 419 | prelim_predictions = [] 420 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 421 | "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"] 422 | ) 423 | # keep track of the minimum score of null start+end of position 0 424 | score_null = 1000000 # large and positive 425 | min_null_feature_index = 0 # the paragraph slice with min null score 426 | null_start_logit = 0 # the start logit at the slice with min null score 427 | null_end_logit = 0 # the end logit at the slice with min null score 428 | for (feature_index, feature) in enumerate(features): 429 | # 10개 문서에 종속되는 다수의 feature 430 | 431 | result = results[feature_index] 432 | 433 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 434 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 435 | 436 | # if we could have irrelevant answers, get the min score of irrelevant 437 | feature_null_score = result.start_logits[0] + result.end_logits[0] 438 | if feature_null_score < score_null: 439 | score_null = feature_null_score 440 | min_null_feature_index = feature_index 441 | null_start_logit = result.start_logits[0] 442 | null_end_logit = result.end_logits[0] 443 | 444 | for start_index in start_indexes: 445 | for end_index in end_indexes: 446 | # We could hypothetically create invalid predictions, e.g., predict 447 | # that the start of the span is in the question. We throw out all 448 | # invalid predictions. 
449 | if start_index >= len(feature.tokens): 450 | continue 451 | if end_index >= len(feature.tokens): 452 | continue 453 | if start_index not in feature.token_to_orig_map: 454 | continue 455 | if end_index not in feature.token_to_orig_map: 456 | continue 457 | if not feature.token_is_max_context.get(start_index, False): 458 | continue 459 | 460 | if end_index < start_index: 461 | continue 462 | prelim_predictions.append( 463 | _PrelimPrediction( 464 | feature_index=feature_index, 465 | start_index=start_index, 466 | end_index=end_index, 467 | start_logit=result.start_logits[start_index], 468 | end_logit=result.end_logits[end_index], 469 | ) 470 | ) 471 | 472 | prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) 473 | 474 | 475 | if prelim_predictions: 476 | pred = prelim_predictions[0] 477 | else: 478 | return '' 479 | feature = features[pred.feature_index] 480 | if pred.start_index > 0: # this is a non-null prediction 481 | tok_tokens = feature.tokens[pred.start_index: (pred.end_index + 1)] 482 | tok_text = tokenizer.convert_tokens_to_string(tok_tokens) 483 | tok_text = tok_text.strip() 484 | tok_text = " ".join(tok_text.split()) 485 | 486 | return tok_text 487 | else: 488 | return '' 489 | def restore_prediction2(tokens, results, n_best_size, tokenizer): 490 | prelim_predictions = [] 491 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 492 | "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit"] 493 | ) 494 | 495 | for result in results: 496 | # 10개 문서에 종속되는 다수의 feature 497 | 498 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 499 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 500 | 501 | for start_index in start_indexes: 502 | for end_index in end_indexes: 503 | # We could hypothetically create invalid predictions, e.g., predict 504 | # that the start of the span is in the question. We throw out all 505 | # invalid predictions. 
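                # Note: besides the usual boundary checks, spans containing
                # [SEP] or [CLS] are rejected below because the sequences scored
                # here are typically re-assembled question + evidence-sentence
                # inputs (see restore_prediction2's caller, sample_train2, in
                # src/model/main_function_rnn.py); a span crossing a special
                # token would straddle two different sentences. The hard
                # 30-token cap mirrors the default max_answer_length used
                # elsewhere in this file.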
506 | if start_index >= len(tokens): 507 | continue 508 | if end_index >= len(tokens): 509 | continue 510 | if '[SEP]' in tokens[start_index:end_index+1] or '[CLS]' in tokens[start_index:end_index+1]: 511 | continue 512 | if end_index < start_index: 513 | continue 514 | if end_index - start_index > 30: 515 | continue 516 | prelim_predictions.append( 517 | _PrelimPrediction( 518 | feature_index=0, 519 | start_index=start_index, 520 | end_index=end_index, 521 | start_logit=result.start_logits[start_index], 522 | end_logit=result.end_logits[end_index], 523 | ) 524 | ) 525 | 526 | prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) 527 | 528 | 529 | if prelim_predictions: 530 | pred = prelim_predictions[0] 531 | else: 532 | return '' 533 | 534 | if pred.start_index > 0: # this is a non-null prediction 535 | tok_tokens = tokens[pred.start_index: (pred.end_index + 1)] 536 | tok_text = tokenizer.convert_tokens_to_string(tok_tokens) 537 | tok_text = tok_text.strip() 538 | tok_text = " ".join(tok_text.split()) 539 | 540 | return tok_text 541 | else: 542 | return '' 543 | def compute_predictions_logits( 544 | all_examples, 545 | all_features, 546 | all_results, 547 | n_best_size, 548 | max_answer_length, 549 | do_lower_case, 550 | output_prediction_file, 551 | output_nbest_file, 552 | output_null_log_odds_file, 553 | verbose_logging, 554 | version_2_with_negative, 555 | null_score_diff_threshold, 556 | tokenizer, 557 | ): 558 | """Write final predictions to the json file and log-odds of null if needed.""" 559 | if output_prediction_file: 560 | logger.info(f"Writing predictions to: {output_prediction_file}") 561 | if output_nbest_file: 562 | logger.info(f"Writing nbest to: {output_nbest_file}") 563 | if output_null_log_odds_file and version_2_with_negative: 564 | logger.info(f"Writing null_log_odds to: {output_null_log_odds_file}") 565 | 566 | example_index_to_features = collections.defaultdict(list) 567 | for features in all_features: 568 | for feature in features: 569 | example_index_to_features[feature.example_index].append(feature) 570 | 571 | unique_id_to_result = {} 572 | for result in all_results: 573 | unique_id_to_result[result.unique_id] = result 574 | 575 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 576 | "PrelimPrediction", ["feature_index", "start_index", "end_index", "start_logit", "end_logit", "evidence"] 577 | ) 578 | 579 | all_predictions = collections.OrderedDict() 580 | all_nbest_json = collections.OrderedDict() 581 | scores_diff_json = collections.OrderedDict() 582 | 583 | for (example_index, examples) in enumerate(all_examples): 584 | # examples : 10개의 문서 585 | 586 | features = example_index_to_features[example_index] 587 | prelim_predictions = [] 588 | # keep track of the minimum score of null start+end of position 0 589 | score_null = 1000000 # large and positive 590 | min_null_feature_index = 0 # the paragraph slice with min null score 591 | null_start_logit = 0 # the start logit at the slice with min null score 592 | null_end_logit = 0 # the end logit at the slice with min null score 593 | for (feature_index, feature) in enumerate(features): 594 | # 10개 문서에 종속되는 다수의 feature 595 | 596 | result = unique_id_to_result[feature.unique_id] 597 | start_indexes = _get_best_indexes(result.start_logits, n_best_size) 598 | end_indexes = _get_best_indexes(result.end_logits, n_best_size) 599 | # if we could have irrelevant answers, get the min score of irrelevant 600 | if version_2_with_negative: 601 | 
feature_null_score = result.start_logits[0] + result.end_logits[0] 602 | if feature_null_score < score_null: 603 | score_null = feature_null_score 604 | min_null_feature_index = feature_index 605 | null_start_logit = result.start_logits[0] 606 | null_end_logit = result.end_logits[0] 607 | for start_index in start_indexes: 608 | for end_index in end_indexes: 609 | # We could hypothetically create invalid predictions, e.g., predict 610 | # that the start of the span is in the question. We throw out all 611 | # invalid predictions. 612 | if start_index >= len(feature.tokens): 613 | continue 614 | if end_index >= len(feature.tokens): 615 | continue 616 | if start_index not in feature.token_to_orig_map: 617 | continue 618 | if end_index not in feature.token_to_orig_map: 619 | continue 620 | if not feature.token_is_max_context.get(start_index, False): 621 | continue 622 | length = end_index-start_index 623 | if length > max_answer_length: 624 | continue 625 | if end_index < start_index: 626 | continue 627 | prelim_predictions.append( 628 | _PrelimPrediction( 629 | feature_index=feature_index, 630 | start_index=start_index, 631 | end_index=end_index, 632 | start_logit=result.start_logits[start_index], 633 | end_logit=result.end_logits[end_index], 634 | evidence=result.evidence, 635 | 636 | ) 637 | ) 638 | 639 | prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True) 640 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 641 | "NbestPrediction", ["text", "start_logit", "end_logit", "evidence"] 642 | ) 643 | 644 | seen_predictions = {} 645 | nbest = [] 646 | for pred in prelim_predictions: 647 | if len(nbest) >= n_best_size: 648 | break 649 | feature = features[pred.feature_index] 650 | example = examples[feature.example_id] 651 | if pred.start_index > 0: # this is a non-null prediction 652 | tok_tokens = feature.tokens[pred.start_index : (pred.end_index + 1)] 653 | orig_doc_start = feature.token_to_orig_map[pred.start_index] 654 | orig_doc_end = feature.token_to_orig_map[pred.end_index] 655 | orig_tokens = example.doc_tokens[orig_doc_start : (orig_doc_end + 1)] 656 | 657 | tok_text = tokenizer.convert_tokens_to_string(tok_tokens) 658 | 659 | # tok_text = " ".join(tok_tokens) 660 | # 661 | # # De-tokenize WordPieces that have been split off. 
662 | # tok_text = tok_text.replace(" ##", "") 663 | # tok_text = tok_text.replace("##", "") 664 | 665 | # Clean whitespace 666 | tok_text = tok_text.strip() 667 | tok_text = " ".join(tok_text.split()) 668 | orig_text = " ".join(orig_tokens) 669 | 670 | final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) 671 | if final_text in seen_predictions: 672 | continue 673 | 674 | seen_predictions[final_text] = True 675 | else: 676 | final_text = "" 677 | seen_predictions[final_text] = True 678 | #[example.doc_sentences[feature.cur_sent_to_orig_sent[e]] if e in feature.cur_sent_to_orig_sent.keys() else None for e in pred.evidence] 679 | evidences = [] 680 | for idx, sent_num in enumerate(pred.evidence): 681 | 682 | ex_idx = sent_num // max_answer_length 683 | sent_ids = sent_num % max_answer_length 684 | 685 | cur_feature = features[ex_idx] 686 | cur_example = examples[cur_feature.example_id] 687 | if sent_ids in cur_feature.cur_sent_to_orig_sent.keys(): 688 | evidences.append(cur_example.doc_sentences[cur_feature.cur_sent_to_orig_sent[sent_ids]]) 689 | 690 | # if pred.qt == 0: 691 | # final_text = 'yes' 692 | # elif pred.qt == 1: 693 | # final_text = 'no' 694 | nbest.append(_NbestPrediction(text=final_text, start_logit=pred.start_logit, end_logit=pred.end_logit, evidence=evidences)) 695 | # if we didn't include the empty option in the n-best, include it 696 | if version_2_with_negative: 697 | if "" not in seen_predictions: 698 | nbest.append(_NbestPrediction(text="", start_logit=null_start_logit, end_logit=null_end_logit, evidence=[])) 699 | 700 | # In very rare edge cases we could have only a single null prediction. 701 | # So we just create a nonce prediction in this case to avoid failure. 702 | if len(nbest) == 1: 703 | nbest.insert(0, _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0, evidence=[])) 704 | 705 | # In very rare edge cases we could have no valid predictions. So we 706 | # just create a nonce prediction in this case to avoid failure.
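        # Recap of the evidence decoding above: each predicted sentence id packs
        # a feature index and a sentence index (feature = id // max_answer_length,
        # sentence = id % max_answer_length; at call time max_answer_length is
        # args.max_sent_num), and cur_sent_to_orig_sent maps the sentence back to
        # the original document. Each serialized n-best entry below then looks
        # like (illustrative values only):
        #   {"text": "Steve Smith", "probability": 0.87, "start_logit": 6.1,
        #    "end_logit": 5.8, "evidence": ["<supporting sentence>", ...]}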
707 | if not nbest: 708 | nbest.append(_NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0, evidence=[None, None, None])) 709 | 710 | assert len(nbest) >= 1 711 | 712 | total_scores = [] 713 | best_non_null_entry = None 714 | for entry in nbest: 715 | total_scores.append(entry.start_logit + entry.end_logit) 716 | if not best_non_null_entry: 717 | if entry.text: 718 | best_non_null_entry = entry 719 | 720 | probs = _compute_softmax(total_scores) 721 | 722 | nbest_json = [] 723 | for (i, entry) in enumerate(nbest): 724 | output = collections.OrderedDict() 725 | output["text"] = entry.text 726 | output["probability"] = probs[i] 727 | output["start_logit"] = entry.start_logit 728 | output["end_logit"] = entry.end_logit 729 | output["evidence"] = entry.evidence 730 | nbest_json.append(output) 731 | 732 | assert len(nbest_json) >= 1 733 | 734 | if not version_2_with_negative: 735 | all_predictions[example.qas_id] = nbest_json[0]["text"] 736 | 737 | if example.qas_id not in all_nbest_json.keys(): 738 | all_nbest_json[example.qas_id] = [] 739 | all_nbest_json[example.qas_id] += nbest_json[:2] 740 | 741 | for qas_id in all_predictions.keys(): 742 | all_predictions[qas_id] = sorted(all_nbest_json[qas_id], key=lambda x: x["start_logit"] + x["end_logit"], reverse=True)[0]["text"] 743 | 744 | if output_prediction_file: 745 | with open(output_prediction_file, "w", encoding='utf8') as writer: 746 | json.dump(all_predictions, writer, indent='\t', ensure_ascii=False) 747 | 748 | if output_nbest_file: 749 | with open(output_nbest_file, "w") as writer: 750 | json.dump(all_nbest_json, writer, indent='\t', ensure_ascii=False) 751 | 752 | if output_null_log_odds_file and version_2_with_negative: 753 | with open(output_null_log_odds_file, "w") as writer: 754 | writer.write(json.dumps(scores_diff_json, indent=4) + "\n") 755 | 756 | return all_predictions 757 | -------------------------------------------------------------------------------- /src/functions/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import torch 4 | import numpy as np 5 | import os 6 | 7 | from src.functions.processor_sent import ( 8 | SquadV1Processor, 9 | squad_convert_examples_to_features 10 | ) 11 | 12 | def init_logger(): 13 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 14 | datefmt='%m/%d/%Y %H:%M:%S', 15 | level=logging.INFO) 16 | 17 | def set_seed(args): 18 | random.seed(args.seed) 19 | np.random.seed(args.seed) 20 | torch.manual_seed(args.seed) 21 | if not args.no_cuda and torch.cuda.is_available(): 22 | torch.cuda.manual_seed_all(args.seed) 23 | 24 | # helper that converts a tensor into a Python list 25 | def to_list(tensor): 26 | return tensor.detach().cpu().tolist() 27 | 28 | # loads the dataset and converts it into examples/features 29 | def load_examples(args, tokenizer, evaluate=False, output_examples=False, do_predict=False, input_dict=None): 30 | ''' 31 | 32 | :param args: hyperparameters 33 | :param tokenizer: tokenizer used for tokenization 34 | :param evaluate: True for evaluation or open test 35 | :param output_examples: True for evaluation or open test / if True, examples and features are returned together 36 | :param do_predict: True for open test 37 | :param input_dict: dictionary of the document and question given as input at open-test time 38 | :return: 39 | examples : list storing each example as raw text, independent of max_length 40 | features : list of the raw text split and tokenized according to max_length 41 | dataset : input ids split according to max_length and converted into the tensor form used directly for training 42 | ''' 43 | input_dir = args.data_dir 44 | print("Creating features from dataset file at
{}".format(input_dir)) 45 | 46 | # processor 선언 47 | processor = SquadV1Processor() 48 | 49 | # open test 시 50 | if do_predict: 51 | examples = processor.get_example_from_input(input_dict) 52 | # 평가 시 53 | elif evaluate: 54 | examples = processor.get_dev_examples(os.path.join(args.data_dir), 55 | filename=args.predict_file, tokenizer=tokenizer) 56 | # 학습 시 57 | else: 58 | examples = processor.get_train_examples(os.path.join(args.data_dir), 59 | filename=args.train_file, tokenizer=tokenizer) 60 | examples, features = squad_convert_examples_to_features( 61 | examples=examples, 62 | tokenizer=tokenizer, 63 | max_seq_length=args.max_seq_length, 64 | doc_stride=args.doc_stride, 65 | max_query_length=args.max_query_length, 66 | is_training=not evaluate, 67 | return_dataset="pt", 68 | threads=args.threads, 69 | ) 70 | 71 | if output_examples: 72 | return examples, features 73 | return features 74 | def load_input_data(args, tokenizer, question, context): 75 | processor = SquadV1Processor() 76 | example = [processor.example_from_input(question, context)] 77 | features, dataset = squad_convert_examples_to_features( 78 | examples=example, 79 | tokenizer=tokenizer, 80 | max_seq_length=args.max_seq_length, 81 | doc_stride=args.doc_stride, 82 | max_query_length=args.max_query_length, 83 | is_training=False, 84 | return_dataset="pt", 85 | threads=args.threads, 86 | ) 87 | return dataset, example, features -------------------------------------------------------------------------------- /src/model/__pycache__/attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/model/__pycache__/attention.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/main_function_rnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/model/__pycache__/main_function_rnn.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/main_functions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/model/__pycache__/main_functions.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/model/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/model_rnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/XAI_EvidenceExtraction/3bced7e815f72251614c1fbf9384ff087f08f594/src/model/__pycache__/model_rnn.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/main_function_rnn.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | import os 3 | import torch 4 | import timeit 5 | from torch.utils.data 
import DataLoader, RandomSampler, SequentialSampler 6 | from tqdm import tqdm 7 | from nltk.translate.bleu_score import sentence_bleu 8 | from transformers import ( 9 | AdamW, 10 | get_linear_schedule_with_warmup 11 | ) 12 | 13 | from src.functions.utils import load_examples, set_seed, to_list, load_input_data 14 | from src.functions.processor_sent import SquadResult 15 | from src.functions.evaluate_v1_0 import eval_during_train, f1_score 16 | from src.functions.hotpotqa_metric import eval 17 | from src.functions.squad_metric import ( 18 | compute_predictions_logits, restore_prediction, restore_prediction2 19 | ) 20 | 21 | 22 | def train(args, model, tokenizer, logger): 23 | # 학습에 사용하기 위한 dataset Load 24 | examples, features = load_examples(args, tokenizer, evaluate=False, output_examples=True) 25 | 26 | # optimization 최적화 schedule 을 위한 전체 training step 계산 27 | t_total = len(features) // args.gradient_accumulation_steps * args.num_train_epochs 28 | 29 | # Layer에 따른 가중치 decay 적용 30 | no_decay = ["bias", "LayerNorm.weight"] 31 | optimizer_grouped_parameters = [ 32 | { 33 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 34 | "weight_decay": args.weight_decay, 35 | }, 36 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 37 | ] 38 | 39 | # optimizer 및 scheduler 선언 40 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 41 | scheduler = get_linear_schedule_with_warmup( 42 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 43 | ) 44 | 45 | # Training Step 46 | logger.info("***** Running training *****") 47 | logger.info(" Num examples = %d", len(features)) 48 | logger.info(" Num Epochs = %d", args.num_train_epochs) 49 | logger.info(" Train batch size per GPU = %d", args.train_batch_size) 50 | logger.info( 51 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 52 | args.train_batch_size 53 | * args.gradient_accumulation_steps) 54 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 55 | logger.info(" Total optimization steps = %d", t_total) 56 | 57 | global_step = 1 58 | 59 | tr_loss, logging_loss = 0.0, 0.0 60 | 61 | # loss buffer 초기화 62 | model.zero_grad() 63 | 64 | set_seed(args) 65 | 66 | for epoch in range(args.num_train_epochs): 67 | for step, batch in enumerate(features): 68 | # if not args.from_init_weight: 69 | # if global_step< int(args.checkpoint): 70 | # global_step+=1 71 | # continue 72 | # try: 73 | model.train() 74 | all_input_ids = torch.tensor([feature.input_ids for feature in batch], dtype=torch.long).cuda() 75 | all_attention_masks = torch.tensor([feature.attention_mask for feature in batch], dtype=torch.long).cuda() 76 | all_token_type_ids = torch.tensor([feature.token_type_ids for feature in batch], dtype=torch.long).cuda() 77 | all_sent_masks = torch.tensor([feature.sent_mask for feature in batch], dtype=torch.long).cuda() 78 | all_start_positions = torch.tensor([feature.start_position for feature in batch], dtype=torch.long).cuda() 79 | all_end_positions = torch.tensor([feature.end_position for feature in batch], dtype=torch.long).cuda() 80 | all_sent_label = torch.tensor([feature.sent_label for feature in batch], dtype=torch.long).cuda() 81 | if torch.sum(all_start_positions).item() == 0: 82 | continue 83 | # 모델에 입력할 입력 tensor 저장 84 | inputs = { 85 | "input_ids": all_input_ids, 86 | "attention_mask": all_attention_masks, 87 | "token_type_ids": all_token_type_ids, 88 | "sent_masks": all_sent_masks, 89 | "start_positions": all_start_positions, 90 | "end_positions": all_end_positions, 91 | 92 | } 93 | 94 | # Loss 계산 및 저장 95 | outputs = model(**inputs) 96 | total_loss = outputs[0] 97 | 98 | if args.gradient_accumulation_steps > 1: 99 | total_loss = total_loss / args.gradient_accumulation_steps 100 | 101 | total_loss.backward() 102 | tr_loss += total_loss.item() 103 | 104 | # Loss 출력 105 | if (global_step + 1) % 50 == 0: 106 | print("{} step processed.. Current Loss : {}".format((global_step+1),total_loss.item())) 107 | 108 | if (step + 1) % args.gradient_accumulation_steps == 0: 109 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 110 | 111 | optimizer.step() 112 | scheduler.step() # Update learning rate schedule 113 | model.zero_grad() 114 | global_step += 1 115 | 116 | # model save 117 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 118 | # 모델 저장 디렉토리 생성 119 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 120 | if not os.path.exists(output_dir): 121 | os.makedirs(output_dir) 122 | 123 | # 학습된 가중치 및 vocab 저장 124 | model.save_pretrained(output_dir) 125 | tokenizer.save_pretrained(output_dir) 126 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 127 | logger.info("Saving model checkpoint to %s", output_dir) 128 | 129 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 130 | # Validation Test!! 
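                    # evaluate() below writes predictions_{global_step}.json and
                    # nbest_predictions_{global_step}.json to args.output_dir and
                    # logs the metrics computed by eval_during_train to
                    # eval/eval_result_<model_name>_<global_step>.txt.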
131 | logger.info("***** Eval results *****") 132 | evaluate(args, model, tokenizer, logger, global_step=global_step) 133 | # except: 134 | # print("Current Step {} Error!".format(global_step)) 135 | # continue 136 | 137 | return global_step, tr_loss / global_step 138 | 139 | def sample_train(args, model, tokenizer, logger): 140 | # 학습에 사용하기 위한 dataset Load 141 | examples, features = load_examples(args, tokenizer, evaluate=False, output_examples=True) 142 | 143 | # optimization 최적화 schedule 을 위한 전체 training step 계산 144 | t_total = len(features) // args.gradient_accumulation_steps * args.num_train_epochs 145 | 146 | # Layer에 따른 가중치 decay 적용 147 | no_decay = ["bias", "LayerNorm.weight"] 148 | optimizer_grouped_parameters = [ 149 | { 150 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 151 | "weight_decay": args.weight_decay, 152 | }, 153 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 154 | ] 155 | 156 | # optimizer 및 scheduler 선언 157 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 158 | scheduler = get_linear_schedule_with_warmup( 159 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 160 | ) 161 | 162 | # Training Step 163 | logger.info("***** Running training *****") 164 | logger.info(" Num examples = %d", len(features)) 165 | logger.info(" Num Epochs = %d", args.num_train_epochs) 166 | logger.info(" Train batch size per GPU = %d", args.train_batch_size) 167 | logger.info( 168 | " Total train batch size (w. parallel, distributed & accumulation) = %d", 169 | args.train_batch_size 170 | * args.gradient_accumulation_steps) 171 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 172 | logger.info(" Total optimization steps = %d", t_total) 173 | 174 | global_step = 1 175 | 176 | tr_loss, logging_loss = 0.0, 0.0 177 | 178 | # loss buffer 초기화 179 | model.zero_grad() 180 | 181 | set_seed(args) 182 | for name, para in model.named_parameters(): 183 | if 'gru' not in name: 184 | print(name) 185 | para.requires_grad = False 186 | for epoch in range(args.num_train_epochs): 187 | for step, batch in enumerate(features): 188 | model.train() 189 | all_input_ids = torch.tensor([feature.input_ids for feature in batch], dtype=torch.long).cuda() 190 | all_attention_masks = torch.tensor([feature.attention_mask for feature in batch], dtype=torch.long).cuda() 191 | all_token_type_ids = torch.tensor([feature.token_type_ids for feature in batch], dtype=torch.long).cuda() 192 | all_sent_masks = torch.tensor([feature.sent_mask for feature in batch], dtype=torch.long).cuda() 193 | all_start_positions = torch.tensor([feature.start_position for feature in batch], dtype=torch.long).cuda() 194 | all_end_positions = torch.tensor([feature.end_position for feature in batch], dtype=torch.long).cuda() 195 | all_sent_label = torch.tensor([feature.sent_label for feature in batch], dtype=torch.long).cuda() 196 | if torch.sum(all_start_positions).item() == 0: 197 | continue 198 | # 모델에 입력할 입력 tensor 저장 199 | inputs = { 200 | "input_ids": all_input_ids, 201 | "attention_mask": all_attention_masks, 202 | "token_type_ids": all_token_type_ids, 203 | "sent_masks": all_sent_masks, 204 | "start_positions": all_start_positions, 205 | "end_positions": all_end_positions, 206 | 207 | } 208 | 209 | outputs = model(**inputs) 210 | loss, span_loss, mse_loss, sampled_evidence_scores, start_logits, end_logits, sampled_evidence_sentence 
= outputs 211 | 212 | 213 | # if args.gradient_accumulation_steps > 1: 214 | # loss = loss / args.gradient_accumulation_steps 215 | # if loss.item() == 0: 216 | # continue 217 | # loss.backward() 218 | 219 | if args.gradient_accumulation_steps > 1: 220 | span_loss = span_loss / args.gradient_accumulation_steps 221 | mse_loss = mse_loss / args.gradient_accumulation_steps 222 | loss = loss / args.gradient_accumulation_steps 223 | mse_loss.backward() 224 | 225 | 226 | 227 | 228 | 229 | tr_loss += loss.item() 230 | 231 | # Loss 출력 232 | if (global_step + 1) % 50 == 0: 233 | print("{} step processed.. Current Loss : {}".format((global_step+1),span_loss.item())) 234 | 235 | if (step + 1) % args.gradient_accumulation_steps == 0: 236 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 237 | 238 | optimizer.step() 239 | scheduler.step() # Update learning rate schedule 240 | model.zero_grad() 241 | global_step += 1 242 | 243 | # model save 244 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 245 | # 모델 저장 디렉토리 생성 246 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 247 | if not os.path.exists(output_dir): 248 | os.makedirs(output_dir) 249 | 250 | # 학습된 가중치 및 vocab 저장 251 | model.save_pretrained(output_dir) 252 | tokenizer.save_pretrained(output_dir) 253 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 254 | logger.info("Saving model checkpoint to %s", output_dir) 255 | 256 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 257 | # Validation Test!! 258 | logger.info("***** Eval results *****") 259 | evaluate(args, model, tokenizer, logger, global_step=global_step) 260 | # except: 261 | # print("Current Step {} Error!".format(global_step)) 262 | # continue 263 | 264 | return global_step, tr_loss / global_step 265 | def sample_train2(args, model, tokenizer, logger): 266 | # 학습에 사용하기 위한 dataset Load 267 | examples, features = load_examples(args, tokenizer, evaluate=False, output_examples=True) 268 | 269 | # optimization 최적화 schedule 을 위한 전체 training step 계산 270 | t_total = len(features) // args.gradient_accumulation_steps * args.num_train_epochs 271 | 272 | # Layer에 따른 가중치 decay 적용 273 | no_decay = ["bias", "LayerNorm.weight"] 274 | optimizer_grouped_parameters = [ 275 | { 276 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 277 | "weight_decay": args.weight_decay, 278 | }, 279 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 280 | ] 281 | 282 | # optimizer 및 scheduler 선언 283 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 284 | scheduler = get_linear_schedule_with_warmup( 285 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 286 | ) 287 | 288 | # Training Step 289 | logger.info("***** Running training *****") 290 | logger.info(" Num examples = %d", len(features)) 291 | logger.info(" Num Epochs = %d", args.num_train_epochs) 292 | logger.info(" Train batch size per GPU = %d", args.train_batch_size) 293 | logger.info( 294 | " Total train batch size (w. 
parallel, distributed & accumulation) = %d", 295 | args.train_batch_size 296 | * args.gradient_accumulation_steps) 297 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 298 | logger.info(" Total optimization steps = %d", t_total) 299 | 300 | global_step = 1 301 | 302 | tr_loss, logging_loss = 0.0, 0.0 303 | 304 | # loss buffer 초기화 305 | model.zero_grad() 306 | 307 | set_seed(args) 308 | # for name, para in model.named_parameters(): 309 | # if 'gru' not in name: 310 | # print(name) 311 | # para.requires_grad = False 312 | for epoch in range(args.num_train_epochs): 313 | for step, batch in enumerate(features): 314 | model.train() 315 | all_input_ids = torch.tensor([feature.input_ids for feature in batch], dtype=torch.long).cuda() 316 | all_attention_masks = torch.tensor([feature.attention_mask for feature in batch], dtype=torch.long).cuda() 317 | all_token_type_ids = torch.tensor([feature.token_type_ids for feature in batch], dtype=torch.long).cuda() 318 | all_sent_masks = torch.tensor([feature.sent_mask for feature in batch], dtype=torch.long).cuda() 319 | all_start_positions = torch.tensor([feature.start_position for feature in batch], dtype=torch.long).cuda() 320 | all_end_positions = torch.tensor([feature.end_position for feature in batch], dtype=torch.long).cuda() 321 | all_question_type = torch.tensor([batch[0].question_type], dtype=torch.long).cuda() 322 | 323 | if torch.sum(all_start_positions).item() == 0: 324 | continue 325 | # 모델에 입력할 입력 tensor 저장 326 | inputs = { 327 | "input_ids": all_input_ids, 328 | "attention_mask": all_attention_masks, 329 | "token_type_ids": all_token_type_ids, 330 | "sent_masks": all_sent_masks, 331 | "start_positions": all_start_positions, 332 | "end_positions": all_end_positions, 333 | #"question_type": all_question_type 334 | 335 | } 336 | 337 | outputs = model(**inputs) 338 | loss, sampled_evidence_scores, mask, start_logits, end_logits, sampled_evidence_sentence = outputs 339 | predicted_answer = [] 340 | evidence_predicted_answer = [] 341 | # print("\n".join([str(e) for e in sampled_evidence_sentence.tolist()])) 342 | for path in range(num_samples): 343 | all_results = [] 344 | start_logit = start_logits[:, :, path] 345 | end_logit = end_logits[:, :, path] 346 | batch_size = start_logits.size(0) 347 | for i in range(batch_size): 348 | # feature 고유 id로 접근하여 원본 q_id 저장 349 | # 각 feature는 유일한 q_id를 갖고 있지 않음 350 | # ==> context가 긴 경우, context를 분할하여 여러 개의 데이터로 변환하기 때문! 
351 | eval_feature = batch[i] 352 | 353 | # 입력 질문에 대한 N개의 결과 저장하기위해 q_id 저장 354 | unique_id = int(eval_feature.unique_id) 355 | 356 | # outputs = [start_logits, end_logits] 357 | output = [to_list(output[i]) for output in [start_logit, end_logit]] 358 | 359 | # start_logits: [batch_size, max_length] 360 | # end_logits: [batch_size, max_length] 361 | start, end = output 362 | 363 | # q_id에 대한 예측 정답 시작/끝 위치 확률 저장 364 | result = SquadResult(unique_id, start, end) 365 | 366 | # feature에 종속되는 최종 출력 값을 리스트에 저장 367 | all_results.append(result) 368 | prediction = restore_prediction(examples[step], batch, all_results, args.n_best_size, args.do_lower_case, 369 | args.verbose_logging, tokenizer) 370 | predicted_answer.append(prediction) 371 | evidence_path = sampled_evidence_sentence[path].tolist() 372 | 373 | question = all_input_ids[0, eval_feature.cur_sent_range[0]].tolist() 374 | evidence_1_feature_index = evidence_path[0]//args.max_sent_num 375 | evidence_2_feature_index = evidence_path[1]//args.max_sent_num 376 | evidence_3_feature_index = evidence_path[2] // args.max_sent_num 377 | 378 | evidence_1_sent_num = evidence_path[0] % args.max_sent_num 379 | evidence_2_sent_num = evidence_path[1] % args.max_sent_num 380 | evidence_3_sent_num = evidence_path[2] % args.max_sent_num 381 | 382 | evidence_1_sentence = all_input_ids[ 383 | evidence_1_feature_index, batch[evidence_1_feature_index].cur_sent_range[evidence_1_sent_num]].tolist() 384 | evidence_2_sentence = all_input_ids[ 385 | evidence_2_feature_index, batch[evidence_2_feature_index].cur_sent_range[evidence_2_sent_num]].tolist() 386 | evidence_3_sentence = all_input_ids[ 387 | evidence_3_feature_index, batch[evidence_3_feature_index].cur_sent_range[evidence_3_sent_num]].tolist() 388 | 389 | tmp_input_ids = question + evidence_1_sentence + evidence_2_sentence + evidence_3_sentence 390 | tmp_input_ids = tmp_input_ids[:args.max_seq_length-1] + [tokenizer.sep_token_id] 391 | tokens = tokenizer.convert_ids_to_tokens(tmp_input_ids) 392 | tmp_attention_mask = torch.zeros([1, args.max_seq_length], dtype=torch.long) 393 | input_mask = [e for e in range(len(tmp_input_ids))] 394 | tmp_attention_mask[:, input_mask] = 1 395 | tmp_input_ids = tmp_input_ids + [tokenizer.pad_token_id] * (args.max_seq_length-len(tmp_input_ids)) 396 | 397 | tmp_sentence_mask = [0]*len(question) + [1]*len(evidence_1_sentence) + [2]*len(evidence_2_sentence) + [3]*len(evidence_3_sentence) 398 | tmp_sentence_mask = tmp_sentence_mask[:args.max_seq_length] + [0]*(args.max_seq_length-len(tmp_sentence_mask)) 399 | 400 | tmp_input_ids = torch.tensor([tmp_input_ids], dtype=torch.long).cuda() 401 | tmp_attention_mask = torch.tensor(tmp_attention_mask, dtype=torch.long).cuda() 402 | tmp_sentence_mask = torch.tensor([tmp_sentence_mask], dtype=torch.long).cuda() 403 | inputs = { 404 | "input_ids": tmp_input_ids, 405 | "attention_mask": tmp_attention_mask, 406 | "sent_masks": tmp_sentence_mask, 407 | } 408 | e_start_logits, e_end_logits, e_sampled_evidence_sentence = model(**inputs) 409 | 410 | # start_logits: [batch_size, max_length] 411 | # end_logits: [batch_size, max_length] 412 | start = e_start_logits[0] 413 | end = e_end_logits[0] 414 | 415 | # q_id에 대한 예측 정답 시작/끝 위치 확률 저장 416 | result = [SquadResult(0, start, end)] 417 | prediction = restore_prediction2(tokens, result, args.n_best_size, tokenizer) 418 | 419 | evidence_predicted_answer.append(prediction) 420 | # feature에 종속되는 최종 출력 값을 리스트에 저장 421 | num_samples = predicted_answer.size(0) 422 | f1_list = [1e-3 for _ in 
range(num_samples)] 423 | g_f1_list = [1e-3 for _ in range(num_samples)] 424 | gold = examples[step][0].answer_text.lower().split(' ') 425 | gold_list = [] 426 | for word in gold: 427 | gold_list += tokenizer.tokenize(word) 428 | gold_list = tokenizer.convert_tokens_to_string(gold_list).strip().split(' ') 429 | 430 | for path in range(num_samples): 431 | predicted = predicted_answer[path].lower().split(' ') 432 | e_predicted = evidence_predicted_answer[path].lower().split(' ') 433 | f1 = sentence_bleu([predicted], e_predicted, weights=(1.0, 0, 0, 0)) 434 | g_f1 = sentence_bleu([gold_list], e_predicted, weights=(1.0, 0, 0, 0)) 435 | 436 | f1_list[path] += f1 437 | g_f1_list[path] += g_f1 438 | f1_list = torch.tensor(f1_list, dtype=torch.float).cuda() 439 | g_f1_list = torch.tensor(g_f1_list, dtype=torch.float).cuda() 440 | #sampled_evidence_scores (10,3, 1, 400) 441 | #sampled_evidence_sentence (10, 3) 442 | 443 | sampled_evidence_scores = sampled_evidence_scores.squeeze(2) 444 | 445 | 446 | # tmp~ : sample별로 추출된 문장 idx들 [1, 0, 0, 1, 0, 0, 1, 0, 0, 0] 447 | s_sampled_evidence_sentence = torch.zeros([num_samples, args.max_dec_len, sampled_evidence_scores.size(-1)], dtype=torch.long).cuda() 448 | g_sampled_evidence_sentence = torch.zeros([num_samples, args.max_dec_len, sampled_evidence_scores.size(-1)], dtype=torch.long).cuda() 449 | for idx in range(num_samples): 450 | sampled_sampled_evidence_sentence = F.one_hot(sampled_evidence_sentence[idx, :], 451 | num_classes=sampled_evidence_scores.size(-1)).unsqueeze(0) 452 | negative_sampled_evidence_sentence = torch.sum(sampled_sampled_evidence_sentence, 1, keepdim=True) 453 | f1 = f1_list[idx] 454 | g_f1 = g_f1_list[idx] 455 | if f1.item() < 0.5: 456 | s_sampled_evidence_sentence[idx, :, :] = mask - negative_sampled_evidence_sentence 457 | f1_list[idx] = 1 - f1 458 | 459 | else: 460 | s_sampled_evidence_sentence[idx, :, :] = sampled_sampled_evidence_sentence 461 | if g_f1.item() < 0.5: 462 | g_sampled_evidence_sentence[idx, :, :] = mask - negative_sampled_evidence_sentence 463 | g_f1_list[idx] = 1- g_f1 464 | else: 465 | g_sampled_evidence_sentence[idx, :, :] = sampled_sampled_evidence_sentence 466 | e_div = torch.sum(s_sampled_evidence_sentence, -1) 467 | g_div = torch.sum(g_sampled_evidence_sentence, -1) 468 | # if e_div.item() == 3: 469 | # e_div = 1 470 | # if g_div.item() == 3: 471 | # g_div = 1 472 | evidence_nll = -F.log_softmax(sampled_evidence_scores, -1) 473 | g_evidence_nll = -F.log_softmax(sampled_evidence_scores, -1) 474 | 475 | evidence_nll = evidence_nll * s_sampled_evidence_sentence 476 | g_evidence_nll = g_evidence_nll * g_sampled_evidence_sentence 477 | f1_list[1:] = f1_list[1:]*0.25 478 | evidence_nll = torch.mean(torch.sum(evidence_nll, -1)/e_div, -1) 479 | evidence_nll = evidence_nll * f1_list 480 | evidence_nll = torch.mean(evidence_nll) 481 | 482 | g_evidence_nll = torch.mean(torch.sum(g_evidence_nll, -1)/g_div, -1) 483 | g_evidence_nll = g_evidence_nll * g_f1_list 484 | g_evidence_nll = torch.mean(g_evidence_nll) 485 | 486 | if evidence_nll.item() != 0 and evidence_nll.item() < 1000: 487 | loss = loss + 0.1 * evidence_nll 488 | if g_evidence_nll.item() != 0 and evidence_nll.item() < 1000: 489 | loss = loss + 0.1 * g_evidence_nll 490 | if args.gradient_accumulation_steps > 1: 491 | loss = loss / args.gradient_accumulation_steps 492 | loss.backward() 493 | 494 | 495 | 496 | 497 | 498 | tr_loss += loss.item() 499 | 500 | # Loss 출력 501 | if (global_step + 1) % 50 == 0: 502 | print("{} step processed.. 
Current Loss : {}".format((global_step+1),loss.item())) 503 | 504 | if (step + 1) % args.gradient_accumulation_steps == 0: 505 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 506 | 507 | optimizer.step() 508 | scheduler.step() # Update learning rate schedule 509 | model.zero_grad() 510 | global_step += 1 511 | 512 | # model save 513 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 514 | # 모델 저장 디렉토리 생성 515 | output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) 516 | if not os.path.exists(output_dir): 517 | os.makedirs(output_dir) 518 | 519 | # 학습된 가중치 및 vocab 저장 520 | model.save_pretrained(output_dir) 521 | tokenizer.save_pretrained(output_dir) 522 | torch.save(args, os.path.join(output_dir, "training_args.bin")) 523 | logger.info("Saving model checkpoint to %s", output_dir) 524 | 525 | if args.logging_steps > 0 and global_step % args.logging_steps == 0: 526 | # Validation Test!! 527 | logger.info("***** Eval results *****") 528 | evaluate(args, model, tokenizer, logger, global_step=global_step) 529 | # except: 530 | # print("Current Step {} Error!".format(global_step)) 531 | # continue 532 | 533 | return global_step, tr_loss / global_step 534 | # 정답이 사전부착된 데이터로부터 평가하기 위한 함수 535 | def evaluate(args, model, tokenizer, logger, global_step = ""): 536 | # 데이터셋 Load 537 | try: 538 | examples, features = load_examples(args, tokenizer, evaluate=True, output_examples=True) 539 | except: 540 | return None 541 | # 최종 출력 파일 저장을 위한 디렉토리 생성 542 | if not os.path.exists(args.output_dir): 543 | os.makedirs(args.output_dir) 544 | 545 | 546 | 547 | # Eval! 548 | logger.info("***** Running evaluation {} *****".format(global_step)) 549 | logger.info(" Num examples = %d", len(features)) 550 | logger.info(" Batch size = %d", args.eval_batch_size) 551 | 552 | # 모델 출력을 저장하기위한 리스트 선언 553 | all_results = [] 554 | 555 | # 평가 시간 측정을 위한 time 변수 556 | start_time = timeit.default_timer() 557 | model.eval() 558 | tmp_scores = [] 559 | for batch_idx, batch in enumerate(features): 560 | # 모델을 평가 모드로 변경 561 | 562 | all_input_ids = torch.tensor([feature.input_ids for feature in batch], dtype=torch.long).cuda() 563 | all_attention_masks = torch.tensor([feature.attention_mask for feature in batch], dtype=torch.long).cuda() 564 | all_token_type_ids = torch.tensor([feature.token_type_ids for feature in batch], dtype=torch.long).cuda() 565 | all_sent_masks = torch.tensor([feature.sent_mask for feature in batch], dtype=torch.long).cuda() 566 | 567 | with torch.no_grad(): 568 | # 평가에 필요한 입력 데이터 저장 569 | inputs = { 570 | "input_ids": all_input_ids, 571 | "attention_mask": all_attention_masks, 572 | "token_type_ids": all_token_type_ids, 573 | "sent_masks": all_sent_masks, 574 | } 575 | # outputs = (start_logits, end_logits) 576 | # start_logits: [batch_size, max_length] 577 | # end_logits: [batch_size, max_length] 578 | outputs = model(**inputs) 579 | 580 | 581 | # 1,000 582 | # 입력 데이터 별 고유 id 저장 (feature와 dataset에 종속) 583 | example_indices = batch[-1] 584 | batch_size = all_input_ids.size(0) 585 | for i in range(batch_size): 586 | # feature 고유 id로 접근하여 원본 q_id 저장 587 | # 각 feature는 유일한 q_id를 갖고 있지 않음 588 | # ==> context가 긴 경우, context를 분할하여 여러 개의 데이터로 변환하기 때문! 
589 | eval_feature = batch[i] 590 | 591 | # 입력 질문에 대한 N개의 결과 저장하기위해 q_id 저장 592 | unique_id = int(eval_feature.unique_id) 593 | 594 | # outputs = [start_logits, end_logits] 595 | output = [to_list(output[i]) for output in outputs] 596 | 597 | # start_logits: [batch_size, max_length] 598 | # end_logits: [batch_size, max_length] 599 | 600 | start_logits, end_logits, evidence = output 601 | 602 | # q_id에 대한 예측 정답 시작/끝 위치 확률 저장 603 | result = SquadResult(unique_id, start_logits, end_logits, evidence) 604 | 605 | # feature에 종속되는 최종 출력 값을 리스트에 저장 606 | all_results.append(result) 607 | # refine_input_ids = [] 608 | # refine_attention_masks = [] 609 | # eval_features = [] 610 | # evidence = outputs[2] 611 | # evidence_path = to_list(evidence)[0] 612 | # cur_example = examples[batch_idx] 613 | # cur_feature = features[batch_idx] 614 | # 615 | # eval_feature = [] 616 | # 617 | # hop_1_ex_id = evidence_path[0] // args.max_sent_num 618 | # hop_1_sent_num = (evidence_path[0] % args.max_sent_num) - 1 619 | # hop_1_evidence_sentence = cur_example[cur_feature[hop_1_ex_id].example_id].doc_sent_tokens[hop_1_sent_num] 620 | # 621 | # hop_2_ex_id = evidence_path[1] // args.max_sent_num 622 | # hop_2_sent_num = (evidence_path[1] % args.max_sent_num) - 1 623 | # 624 | # hop_2_evidence_sentence = cur_example[cur_feature[hop_2_ex_id].example_id].doc_sent_tokens[hop_2_sent_num] 625 | # 626 | # hop_3_ex_id = evidence_path[2] // args.max_sent_num 627 | # hop_3_sent_num = (evidence_path[2] % args.max_sent_num) - 1 628 | # hop_3_evidence_sentence = cur_example[cur_feature[hop_3_ex_id].example_id].doc_sent_tokens[hop_3_sent_num] 629 | # 630 | # query = [tokenizer.cls_token_id] + cur_feature[0].truncated_query 631 | # 632 | # refine_context = ['[SEP]'] + hop_1_evidence_sentence + hop_2_evidence_sentence + hop_3_evidence_sentence 633 | # eval_feature.append(query + refine_context) 634 | # eval_features.append(eval_feature) 635 | # context_token_ids = tokenizer.convert_tokens_to_ids(refine_context) 636 | # refine_input_id = query + context_token_ids 637 | # refine_input_id = refine_input_id[:args.max_seq_length - 1] + [tokenizer.sep_token_id] + [ 638 | # tokenizer.pad_token_id] * ( 639 | # args.max_seq_length - len( 640 | # refine_input_id) - 1) 641 | # refine_attention_mask = [1] * args.max_seq_length if 0 not in refine_input_id else [1] * ( 642 | # refine_input_id.index(0)) + [0] * (args.max_seq_length - refine_input_id.index(0)) 643 | # 644 | # refine_input_ids.append(refine_input_id) 645 | # refine_attention_masks.append(refine_attention_mask) 646 | # refine_input_ids = torch.tensor(refine_input_ids, dtype=torch.long).cuda() 647 | # refine_attention_masks = torch.tensor(refine_attention_masks, dtype=torch.long).cuda() 648 | # 649 | # start_logit, end_logit = model(input_ids=refine_input_ids, attention_mask=refine_attention_masks) 650 | # batch_size = start_logit.size(0) 651 | # 652 | # results = [] 653 | # for i in range(batch_size): 654 | # output = [to_list(output[i]) for output in [start_logit, end_logit]] 655 | # start, end = output 656 | # # q_id에 대한 예측 정답 시작/끝 위치 확률 저장 657 | # result = SquadResult(0, start, end, evidence=evidence_path) 658 | # results.append(result) 659 | # preds, scores = restore_prediction_2(results, eval_features, args.n_best_size, tokenizer) 660 | # for idx, pred in enumerate(preds): 661 | # f1 = f1_score(pred, cur_example[0].answer_text) 662 | # tmp_scores.append(f1) 663 | # print(len(tmp_scores)) 664 | # print(sum(tmp_scores)/len(tmp_scores)) 665 | # 평가 시간 측정을 위한 time 변수 666 | evalTime = 
timeit.default_timer() - start_time 667 | logger.info(" Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(features)) 668 | 669 | # 최종 예측 값을 저장하기 위한 파일 생성 670 | output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(global_step)) 671 | output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(global_step)) 672 | 673 | # Yes/No Question을 다룰 경우, 각 정답이 유효할 확률 저장을 위한 파일 생성 674 | if args.version_2_with_negative: 675 | output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(global_step)) 676 | else: 677 | output_null_log_odds_file = None 678 | 679 | # q_id에 대한 N개의 출력 값의 확률로 부터 가장 확률이 높은 최종 예측 값 저장 680 | predictions = compute_predictions_logits( 681 | examples, 682 | features, 683 | all_results, 684 | args.n_best_size, 685 | args.max_sent_num, 686 | args.do_lower_case, 687 | output_prediction_file, 688 | output_nbest_file, 689 | output_null_log_odds_file, 690 | args.verbose_logging, 691 | args.version_2_with_negative, 692 | args.null_score_diff_threshold, 693 | tokenizer, 694 | ) 695 | 696 | output_dir = os.path.join(args.output_dir, 'eval') 697 | if not os.path.exists(output_dir): 698 | os.makedirs(output_dir) 699 | 700 | # KorQuAD 평가 스크립트 기반 성능 저장을 위한 파일 생성 701 | output_eval_file = os.path.join(output_dir, "eval_result_{}_{}.txt".format( 702 | list(filter(None, args.model_name_or_path.split("/"))).pop(), 703 | global_step)) 704 | 705 | logger.info("***** Official Eval results *****") 706 | with open(output_eval_file, "w", encoding='utf-8') as f: 707 | # KorQuAD 평가 스크립트 기반의 성능 측정 708 | official_eval_results = eval_during_train(args, global_step) 709 | # official_eval_results = eval(output_prediction_file, os.path.join(args.data_dir, args.predict_file)) 710 | for key in sorted(official_eval_results.keys()): 711 | logger.info(" %s = %s", key, str(official_eval_results[key])) 712 | f.write(" {} = {}\n".format(key, str(official_eval_results[key]))) 713 | 714 | -------------------------------------------------------------------------------- /src/model/model_rnn.py: -------------------------------------------------------------------------------- 1 | from torch.nn import CrossEntropyLoss, KLDivLoss 2 | import torch.nn as nn 3 | import torch 4 | from torch.nn import functional as F 5 | from transformers.modeling_electra import ElectraPreTrainedModel, ElectraModel 6 | from torch.nn import TransformerEncoderLayer 7 | from src.model.attention import MultiheadAttention 8 | from random import random, randint, randrange 9 | from torch.nn import MSELoss 10 | import math 11 | 12 | class AttentionDecoder(nn.Module): 13 | def __init__(self, hidden_size): 14 | super(AttentionDecoder, self).__init__() 15 | self.dense1 = nn.Linear(in_features=hidden_size, out_features=hidden_size) 16 | self.dense2 = nn.Linear(in_features=hidden_size, out_features=hidden_size) 17 | self.dense3 = nn.Linear(in_features=hidden_size * 2, out_features=hidden_size) 18 | 19 | self.decoder = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True, num_layers=1) 20 | 21 | self.div_term = math.sqrt(hidden_size) 22 | 23 | def forward(self, last_hidden, decoder_inputs, encoder_outputs, attention_mask, is_training=True): 24 | ''' 25 | :param last_hidden: (1, batch, hidden) 26 | :param decoder_inputs: (batch, 1, hidden) 27 | :param encoder_outputs: (batch, seq_len, hidden) 28 | :return: 29 | ''' 30 | batch_size = decoder_inputs.size(0) 31 | indexes = [e for e in range(batch_size)] 32 | key_encoder_outputs = 
self.dense1(encoder_outputs) 33 | value_encoder_outputs = self.dense2(encoder_outputs) 34 | # key : (batch, seq, hidden) 35 | # value : (batch, seq, hidden) 36 | 37 | output, hidden = self.decoder(decoder_inputs, hx=last_hidden) 38 | # output : (batch, 1, hidden) 39 | # hidden : (1, batch, hidden) 40 | 41 | t_encoder_outputs = key_encoder_outputs.transpose(1, 2) 42 | # t_encoder_outputs : (batch, hidden, seq) 43 | 44 | attn_outputs = output.bmm(t_encoder_outputs) + attention_mask 45 | # attn_outputs : (batch, 1, seq_len) 46 | 47 | # attn_alignment = F.softmax(attn_outputs, -1) 48 | if is_training: 49 | # attn_alignment = F.gumbel_softmax(attn_outputs, tau=1, hard=False, dim=-1) 50 | attn_alignment = F.softmax(attn_outputs, -1) 51 | else: 52 | attn_alignment = F.softmax(attn_outputs, -1) 53 | # attn_alignment : (batch, 1, seq_len) 54 | 55 | evidence_sentence = attn_alignment.argmax(-1).squeeze(1) 56 | #if is_training: 57 | attention_mask[indexes, 0, evidence_sentence] = -1e10 58 | context = attn_alignment.bmm(value_encoder_outputs) 59 | # context : (batch, 1, hidden) 60 | 61 | hidden_states = torch.cat([context, output], -1) 62 | 63 | result = self.dense3(hidden_states) 64 | return result, hidden, evidence_sentence, attn_outputs, attention_mask 65 | 66 | class SampledAttentionDecoder1204(nn.Module): 67 | def __init__(self, hidden_size): 68 | super(SampledAttentionDecoder1204, self).__init__() 69 | self.dense1 = nn.Linear(in_features=hidden_size, out_features=hidden_size) 70 | self.dense2 = nn.Linear(in_features=hidden_size, out_features=hidden_size) 71 | self.dense3 = nn.Linear(in_features=hidden_size * 2, out_features=hidden_size) 72 | 73 | self.decoder = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True, num_layers=1) 74 | 75 | self.div_term = math.sqrt(hidden_size) 76 | 77 | def forward(self, last_hidden, decoder_inputs, encoder_outputs, attention_mask, is_training=True, is_sample=False): 78 | ''' 79 | :param last_hidden: (1, batch, hidden) 80 | :param decoder_inputs: (batch, 1, hidden) 81 | :param encoder_outputs: (batch, seq_len, hidden) 82 | :return: 83 | ''' 84 | batch_size = decoder_inputs.size(0) 85 | indexes = [e for e in range(batch_size)] 86 | key_encoder_outputs = self.dense1(encoder_outputs) 87 | value_encoder_outputs = self.dense2(encoder_outputs) 88 | # key : (batch, seq, hidden) 89 | # value : (batch, seq, hidden) 90 | 91 | output, hidden = self.decoder(decoder_inputs, hx=last_hidden) 92 | # output : (batch, 1, hidden) 93 | # hidden : (1, batch, hidden) 94 | 95 | t_encoder_outputs = key_encoder_outputs.transpose(1, 2) 96 | # t_encoder_outputs : (batch, hidden, seq) 97 | 98 | attn_outputs = output.bmm(t_encoder_outputs) + attention_mask 99 | # attn_outputs : (batch, 1, seq_len) 100 | 101 | # attn_alignment = F.softmax(attn_outputs, -1) 102 | 103 | if is_sample: 104 | attn_alignment = F.gumbel_softmax(attn_outputs, tau=1, hard=False, dim=-1) 105 | # a = torch.sum(attn_alignment1) 106 | else: 107 | attn_alignment = F.softmax(attn_outputs, -1) 108 | # b = torch.sum(attn_alignment) 109 | 110 | # attn_alignment : (batch, 1, seq_len) 111 | 112 | evidence_sentence = attn_alignment.argmax(-1).squeeze(1) 113 | #if is_training: 114 | attention_mask[indexes, 0, evidence_sentence] = -1e10 115 | context = attn_alignment.bmm(value_encoder_outputs) 116 | # context : (batch, 1, hidden) 117 | 118 | hidden_states = torch.cat([context, output], -1) 119 | 120 | result = self.dense3(hidden_states) 121 | return result, hidden, evidence_sentence, attn_outputs, attention_mask 
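Both decoder variants above share the same step contract: one GRU step per hop, an additive mask that removes already-chosen sentences, and a context vector concatenated with the GRU output before the final projection. The following is a minimal shape-level usage sketch with untrained weights and made-up sizes; it only demonstrates the calling convention and the masking behaviour, not trained behaviour.

```python
import torch

# Toy sizes; the real model derives these from the ELECTRA config
# and from the sentence masks built in forward().
hidden_size, num_sentences = 8, 5
decoder = AttentionDecoder(hidden_size)  # defined earlier in this file

decoder_inputs = torch.randn(1, 1, hidden_size)               # (batch=1, 1, hidden)
encoder_outputs = torch.randn(1, num_sentences, hidden_size)  # sentence vectors
attention_mask = torch.zeros(1, 1, num_sentences)             # 0 = selectable, -1e10 = blocked
last_hidden = None

evidence = []
for _ in range(3):  # three hops, as in the QA models below
    decoder_inputs, last_hidden, picked, scores, attention_mask = decoder(
        last_hidden, decoder_inputs, encoder_outputs, attention_mask, is_training=False)
    evidence.append(picked)  # index of the sentence chosen at this hop

# Chosen indices should not repeat, because each pick is masked with -1e10.
print(torch.stack(evidence).squeeze(-1))
```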
122 | class SampledAttentionDecoder(nn.Module): 123 | def __init__(self, hidden_size): 124 | super(SampledAttentionDecoder, self).__init__() 125 | self.dense1 = nn.Linear(in_features=hidden_size, out_features=hidden_size) 126 | self.dense2 = nn.Linear(in_features=hidden_size, out_features=hidden_size) 127 | self.dense3 = nn.Linear(in_features=hidden_size * 2, out_features=hidden_size) 128 | self.decoder = nn.GRU(input_size=hidden_size, hidden_size=hidden_size, batch_first=True, num_layers=1) 129 | 130 | self.div_term = math.sqrt(hidden_size) 131 | def forward(self, last_hidden, decoder_inputs, encoder_outputs, attention_mask): 132 | ''' 133 | :param last_hidden: (1, batch, hidden) 134 | :param decoder_inputs: (batch, 1, hidden) 135 | :param encoder_outputs: (batch, seq_len, hidden) 136 | :return: 137 | ''' 138 | batch_size = decoder_inputs.size(0) 139 | 140 | key_encoder_outputs = self.dense1(encoder_outputs) 141 | value_encoder_outputs = self.dense2(encoder_outputs) 142 | # key : (batch, seq, hidden) 143 | # value : (batch, seq, hidden) 144 | 145 | output, hidden = self.decoder(decoder_inputs, hx=last_hidden) 146 | # output : (batch, 1, hidden) 147 | # hidden : (1, batch, hidden) 148 | 149 | t_encoder_outputs = key_encoder_outputs.transpose(1, 2) 150 | # t_encoder_outputs : (batch, hidden, seq) 151 | 152 | attn_outputs = output.bmm(t_encoder_outputs) + attention_mask 153 | # attn_outputs : (batch, 1, seq_len) 154 | 155 | attn_alignment = F.gumbel_softmax(attn_outputs, tau=1, hard=False, dim=-1) 156 | # attn_alignment = F.softmax(attn_outputs, dim=-1) 157 | # attn_alignment : (batch, 1, seq_len) 158 | 159 | evidence_sentence = torch.argmax(attn_alignment, -1).squeeze(1) 160 | 161 | 162 | # attention_mask[indexes, 0, evidence_sentence] = -1e10 163 | 164 | for idx in range(len(attention_mask)): 165 | attention_mask[idx, 0, evidence_sentence[idx]] = -1e10 166 | aa = attention_mask.tolist() 167 | context = attn_alignment.bmm(value_encoder_outputs) 168 | # context : (batch, 1, hidden) 169 | 170 | hidden_states = torch.cat([context, output], -1) 171 | 172 | result = self.dense3(hidden_states) 173 | return result, hidden, evidence_sentence, attn_outputs.squeeze(1), attention_mask 174 | 175 | 176 | class ElectraForQuestionAnswering_sent_evidence_trm_sampling_1016(ElectraPreTrainedModel): 177 | def __init__(self, config): 178 | super(ElectraForQuestionAnswering_sent_evidence_trm_sampling_1016, self).__init__(config) 179 | # 분류 해야할 라벨 개수 (start/end) 180 | self.num_labels = config.num_labels 181 | self.hidden_size = config.hidden_size 182 | 183 | # ELECTRA 모델 선언 184 | self.max_seq_length = config.max_position_embeddings 185 | self.electra = ElectraModel(config) 186 | self.max_sent_num = config.max_sent_num 187 | self.num_samples = config.num_samples 188 | self.start_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 189 | self.end_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 190 | 191 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 192 | 193 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 194 | self.gru = AttentionDecoder(self.hidden_size) 195 | 196 | # ELECTRA weight 초기화 197 | self.init_weights() 198 | 199 | def extract_answerable_sent_idx(self, answer_positions, sent_masks): 200 | batch_size = sent_masks.size(0) 201 | answerable_sent_id = torch.gather(sent_masks, 1, answer_positions.unsqueeze(-1)).squeeze(-1) 202 | expanded_answerable_sent_id = F.one_hot(answerable_sent_id, 
num_classes=self.max_sent_num).reshape( 203 | batch_size * self.max_sent_num) 204 | expanded_answerable_sent_id[0::self.max_sent_num] = 0 205 | answerable_sent_ids= torch.where(expanded_answerable_sent_id ==1)[0] 206 | return answerable_sent_ids 207 | 208 | def _path_generate(self, hop_score, p_mask, n_mask, answerable_sent_num=None): 209 | if answerable_sent_num: 210 | n_negative_hop_score = hop_score + n_mask 211 | n_positive_hop_score = hop_score + p_mask 212 | _, n_hop_negative_path_idx = torch.sort(n_negative_hop_score, descending=True) 213 | n_hop_negative_path_idx = n_hop_negative_path_idx[:3] 214 | _, n_hop_positive_path_idx = torch.sort(n_positive_hop_score, descending=True) 215 | n_hop_positive_path_idx = n_hop_positive_path_idx[:answerable_sent_num] 216 | path_idx = torch.cat([n_hop_negative_path_idx, n_hop_positive_path_idx]) 217 | path_label = torch.cat([torch.ones([3]).float(), torch.ones([answerable_sent_num])]).cuda() 218 | path_logits = torch.gather(hop_score, 0, path_idx) 219 | return path_logits, path_idx, path_label 220 | 221 | 222 | else: 223 | n_hop_score = hop_score + n_mask 224 | _, n_hop_path_idx = torch.sort(n_hop_score, descending=True) 225 | n_hop_path_idx = n_hop_path_idx[:1] 226 | path_logits = torch.gather(hop_score, 0, n_hop_path_idx) 227 | return path_logits, n_hop_path_idx, None 228 | 229 | def _cross_entropy(self, logits): 230 | loss1 = -(F.log_softmax(logits, dim=-1)) 231 | # print(to_list(loss1)) 232 | loss = loss1 233 | return loss 234 | 235 | def forward( 236 | self, 237 | input_ids=None, 238 | 239 | ################## 240 | attention_mask=None, 241 | token_type_ids=None, 242 | sent_masks=None, 243 | ############## 244 | 245 | start_positions=None, 246 | end_positions=None, 247 | ): 248 | 249 | # ELECTRA output 저장 250 | # outputs : [1, batch_size, seq_length, hidden_size] 251 | # electra 선언 부분에 특정 옵션을 부여할 경우 출력은 다음과 같음 252 | # outputs : (last-layer hidden state, all hidden states, all attentions) 253 | # last-layer hidden state : [batch, seq_length, hidden_size] 254 | # all hidden states : [13, batch, seq_length, hidden_size] 255 | # 12가 아닌 13인 이유?? 
==> 토큰의 임베딩도 output에 포함되어 return 256 | # all attentions : [12, batch, num_heads, seq_length, seq_length] 257 | batch_size = input_ids.size(0) 258 | 259 | outputs = self.electra( 260 | input_ids, 261 | attention_mask=attention_mask, 262 | token_type_ids=token_type_ids, 263 | ) 264 | is_training = False 265 | if start_positions is not None: 266 | is_training=True 267 | 268 | sequence_output = outputs[0] 269 | # sequence_output : [batch_size, seq_length, hidden_size] 270 | 271 | cls_output = sequence_output[:, 0, :] 272 | # cls_output : [batch, hidden] 273 | 274 | sentence_masks = F.one_hot(sent_masks, num_classes=self.max_sent_num).transpose(1, 2).float() 275 | sentence_masks[:, 0, :] = sentence_masks[:, 0, :] * attention_mask 276 | # sentence_masks : [batch, seq_length] ==> [batch, seq_len, sent_num] ==> [batch, sent_num, seq_len] 277 | # sentence_masks : [10, 512] ==> [10, 512, 40] ==> [10, 40, 512] 278 | # [sentence_masks] = [[0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, ...]] 279 | 280 | div_term = torch.sum(sentence_masks, dim=-1, keepdim=True) 281 | # div_term : [batch, sent_num, 1] 282 | 283 | div_term = div_term.masked_fill(div_term == 0, 1e-10) 284 | 285 | attention_masks = div_term.masked_fill(div_term != 1e-10, 0).masked_fill(div_term == 1e-10, 1).view(1, 286 | batch_size * self.max_sent_num).bool() 287 | sentence_representation = sentence_masks.bmm(sequence_output) 288 | sentence_representation = sentence_representation / div_term 289 | # sentence_representation : [batch, sent_num, hidden] 290 | 291 | sentence_representation = self.dropout( 292 | sentence_representation.reshape(1, batch_size * self.max_sent_num, self.hidden_size)) 293 | 294 | attention_mask = attention_masks.float() 295 | # [0, 0, 0, 1, 1, [1, 0, 0, 0, 0 296 | # 0, 0, 0, 0, 1] 1, 0, 0, 0, 0] 297 | 298 | 299 | last_hidden = None 300 | decoder_inputs = torch.sum(sentence_representation[:, 0::self.max_sent_num, :], 1, keepdim=True) / batch_size 301 | encoder_outputs = sentence_representation 302 | attention_mask[:, 0::self.max_sent_num] = 1 303 | attention_mask = attention_mask.masked_fill(attention_mask == 1, -1e10).masked_fill(attention_mask == 0, 0) 304 | attention_mask = attention_mask.unsqueeze(0) 305 | evidence_sentences = [] 306 | for evidence_step in range(3): 307 | decoder_inputs, last_hidden, evidence_sentence,attn_outputs, attention_mask = self.gru(last_hidden, decoder_inputs, encoder_outputs, attention_mask, is_training) 308 | evidence_sentences.append(evidence_sentence) 309 | evidence_vector = decoder_inputs.squeeze().unsqueeze(-1) 310 | evidence_sentences = torch.stack(evidence_sentences, 0) 311 | 312 | evidence_sentences = evidence_sentences.unsqueeze(0).expand(batch_size, -1, -1) 313 | start_representation = self.start_dense(sequence_output) 314 | end_representation = self.end_dense(sequence_output) 315 | 316 | start_logits = start_representation.matmul(evidence_vector) 317 | end_logits = end_representation.matmul(evidence_vector) 318 | 319 | start_logits = start_logits.squeeze(-1) 320 | end_logits = end_logits.squeeze(-1) 321 | 322 | # outputs = (start_logits, end_logits) 323 | outputs = (start_logits, end_logits, evidence_sentences.squeeze(-1)) + outputs[1:] 324 | 325 | # 학습 시 326 | if start_positions is not None and end_positions is not None: 327 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 328 | # logg_fct 선언 329 | loss_fct = CrossEntropyLoss(reduction='none') 330 | 331 | # start/end에 대해 loss 계산 332 | start_loss = loss_fct(start_logits, start_positions) 333 | 
end_loss = loss_fct(end_logits, end_positions) 334 | 335 | # 최종 loss 계산 336 | span_loss = (start_loss + end_loss) / 2 337 | 338 | # outputs : (total_loss, start_logits, end_logits) 339 | outputs = (span_loss, ) + outputs 340 | 341 | return outputs # (loss), start_logits, end_logits 342 | 343 | class ElectraForQuestionAnswering_sent_evidence_trm_sampling_1028(ElectraPreTrainedModel): 344 | def __init__(self, config): 345 | super(ElectraForQuestionAnswering_sent_evidence_trm_sampling_1028, self).__init__(config) 346 | # 분류 해야할 라벨 개수 (start/end) 347 | self.num_labels = config.num_labels 348 | self.hidden_size = config.hidden_size 349 | 350 | # ELECTRA 모델 선언 351 | self.max_seq_length = config.max_position_embeddings 352 | self.electra = ElectraModel(config) 353 | self.max_sent_num = config.max_sent_num 354 | self.num_samples = config.num_samples 355 | self.start_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 356 | self.end_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 357 | 358 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 359 | 360 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 361 | self.gru = AttentionDecoder(self.hidden_size) 362 | 363 | # ELECTRA weight 초기화 364 | self.init_weights() 365 | 366 | def extract_answerable_sent_idx(self, answer_positions, sent_masks): 367 | batch_size = sent_masks.size(0) 368 | answerable_sent_id = torch.gather(sent_masks, 1, answer_positions.unsqueeze(-1)).squeeze(-1) 369 | expanded_answerable_sent_id = F.one_hot(answerable_sent_id, num_classes=self.max_sent_num).reshape( 370 | batch_size * self.max_sent_num) 371 | expanded_answerable_sent_id[0::self.max_sent_num] = 0 372 | answerable_sent_ids= torch.where(expanded_answerable_sent_id ==1)[0] 373 | return answerable_sent_ids 374 | 375 | def _path_generate(self, hop_score, p_mask, n_mask, answerable_sent_num=None): 376 | if answerable_sent_num: 377 | n_negative_hop_score = hop_score + n_mask 378 | n_positive_hop_score = hop_score + p_mask 379 | _, n_hop_negative_path_idx = torch.sort(n_negative_hop_score, descending=True) 380 | n_hop_negative_path_idx = n_hop_negative_path_idx[:3] 381 | _, n_hop_positive_path_idx = torch.sort(n_positive_hop_score, descending=True) 382 | n_hop_positive_path_idx = n_hop_positive_path_idx[:answerable_sent_num] 383 | path_idx = torch.cat([n_hop_negative_path_idx, n_hop_positive_path_idx]) 384 | path_label = torch.cat([torch.ones([3]).float(), torch.ones([answerable_sent_num])]).cuda() 385 | path_logits = torch.gather(hop_score, 0, path_idx) 386 | return path_logits, path_idx, path_label 387 | 388 | 389 | else: 390 | n_hop_score = hop_score + n_mask 391 | _, n_hop_path_idx = torch.sort(n_hop_score, descending=True) 392 | n_hop_path_idx = n_hop_path_idx[:1] 393 | path_logits = torch.gather(hop_score, 0, n_hop_path_idx) 394 | return path_logits, n_hop_path_idx, None 395 | 396 | def _cross_entropy(self, logits): 397 | loss1 = -(F.log_softmax(logits, dim=-1)) 398 | # print(to_list(loss1)) 399 | loss = loss1 400 | return loss 401 | 402 | def forward( 403 | self, 404 | input_ids=None, 405 | 406 | ################## 407 | attention_mask=None, 408 | token_type_ids=None, 409 | sent_masks=None, 410 | ############## 411 | 412 | start_positions=None, 413 | end_positions=None, 414 | ): 415 | 416 | # ELECTRA output 저장 417 | # outputs : [1, batch_size, seq_length, hidden_size] 418 | # electra 선언 부분에 특정 옵션을 부여할 경우 출력은 다음과 같음 419 | # outputs : (last-layer hidden state, all hidden states, all 
attentions) 420 | # last-layer hidden state : [batch, seq_length, hidden_size] 421 | # all hidden states : [13, batch, seq_length, hidden_size] 422 | # 12가 아닌 13인 이유?? ==> 토큰의 임베딩도 output에 포함되어 return 423 | # all attentions : [12, batch, num_heads, seq_length, seq_length] 424 | batch_size = input_ids.size(0) 425 | sequence_length = input_ids.size(1) 426 | outputs = self.electra( 427 | input_ids, 428 | attention_mask=attention_mask, 429 | token_type_ids=token_type_ids, 430 | ) 431 | is_training = False 432 | if start_positions is not None: 433 | is_training=True 434 | 435 | sequence_output = outputs[0] 436 | # sequence_output : [batch_size, seq_length, hidden_size] 437 | 438 | cls_output = sequence_output[:, 0, :] 439 | # cls_output : [batch, hidden] 440 | all_cls_output = torch.sum(cls_output, dim=0, keepdim=True) / batch_size 441 | sentence_masks = F.one_hot(sent_masks, num_classes=self.max_sent_num).transpose(1, 2).float() 442 | sentence_masks[:, 0, :] = sentence_masks[:, 0, :] * attention_mask 443 | # sentence_masks : [batch, seq_length] ==> [batch, seq_len, sent_num] ==> [batch, sent_num, seq_len] 444 | # sentence_masks : [10, 512] ==> [10, 512, 40] ==> [10, 40, 512] 445 | # [sentence_masks] = [[0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, ...]] 446 | 447 | div_term = torch.sum(sentence_masks, dim=-1, keepdim=True) 448 | # div_term : [batch, sent_num, 1] 449 | 450 | div_term = div_term.masked_fill(div_term == 0, 1e-10) 451 | 452 | attention_masks = div_term.masked_fill(div_term != 1e-10, 0).masked_fill(div_term == 1e-10, 1).view(1, 453 | batch_size * self.max_sent_num).bool() 454 | sentence_representation = sentence_masks.bmm(sequence_output) 455 | sentence_representation = sentence_representation / div_term 456 | # sentence_representation : [batch, sent_num, hidden] 457 | 458 | sentence_representation = self.dropout( 459 | sentence_representation.reshape(1, batch_size * self.max_sent_num, self.hidden_size)) 460 | 461 | attention_masks = attention_masks.float() 462 | # [0, 0, 0, 1, 1, [1, 0, 0, 0, 0 463 | # 0, 0, 0, 0, 1] 1, 0, 0, 0, 0] 464 | 465 | 466 | last_hidden = None 467 | # decoder_inputs = all_cls_output.unsqueeze(1) 468 | decoder_inputs = torch.sum(sentence_representation[:, 0::self.max_sent_num, :], 1, keepdim=True) / batch_size 469 | encoder_outputs = sentence_representation 470 | attention_masks[:, 0::self.max_sent_num] = 1 471 | mm = 1-attention_masks 472 | mm = mm.unsqueeze(1).expand(-1, 3, -1) 473 | attention_masks = attention_masks.masked_fill(attention_masks == 1, -1e10).masked_fill(attention_masks == 0, 0).unsqueeze(0) 474 | # if is_training: 475 | # decoder_inputs = decoder_inputs.expand(self.num_samples, -1, -1) 476 | # encoder_outputs = encoder_outputs.expand(self.num_samples, -1, -1) 477 | # attention_masks = attention_masks.expand(self.num_samples, -1, -1) 478 | # attention_masks = attention_masks.clone().detach() 479 | evidence_sentences = [] 480 | attention_scores = [] 481 | for evidence_step in range(3): 482 | decoder_inputs, last_hidden, evidence_sentence, attention_score, attention_masks = self.gru(last_hidden, decoder_inputs, encoder_outputs, attention_masks) 483 | evidence_sentences.append(evidence_sentence) 484 | attention_scores.append(attention_score) 485 | evidence_vector = decoder_inputs.squeeze(1).transpose(0, 1) 486 | evidence_sentences = torch.stack(evidence_sentences, 0) 487 | attention_scores = torch.stack(attention_scores, 0) 488 | 489 | 490 | evidence_sentences = evidence_sentences.transpose(0, 1) 491 | attention_scores = attention_scores.transpose(0, 
1) 492 | if is_training: 493 | evidence = evidence_sentences.unsqueeze(0) 494 | else: 495 | evidence = evidence_sentences.unsqueeze(0).expand(batch_size, -1, -1) 496 | 497 | start_representation = self.start_dense(sequence_output) 498 | end_representation = self.end_dense(sequence_output) 499 | 500 | start_logits = start_representation.matmul(evidence_vector) 501 | end_logits = end_representation.matmul(evidence_vector) 502 | 503 | start_logits = start_logits.squeeze(-1) 504 | end_logits = end_logits.squeeze(-1) 505 | 506 | 507 | # outputs = (start_logits, end_logits) 508 | outputs = (start_logits, end_logits, evidence.squeeze(1)) + outputs[1:] 509 | 510 | # 학습 시 511 | if start_positions is not None and end_positions is not None: 512 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 513 | # logg_fct 선언 514 | 515 | loss_fct = CrossEntropyLoss() 516 | 517 | # start/end에 대해 loss 계산 518 | start_loss = loss_fct(start_logits, start_positions) 519 | end_loss = loss_fct(end_logits, end_positions) 520 | 521 | # 최종 loss 계산 522 | span_loss = (start_loss + end_loss) / 2 523 | 524 | # outputs : (total_loss, start_logits, end_logits) 525 | outputs = (span_loss, attention_scores, mm) + outputs 526 | 527 | return outputs # (loss), start_logits, end_logits 528 | 529 | class ElectraForQuestionAnswering_sent_evidence_final(ElectraPreTrainedModel): 530 | def __init__(self, config): 531 | super(ElectraForQuestionAnswering_sent_evidence_final, self).__init__(config) 532 | # 분류 해야할 라벨 개수 (start/end) 533 | self.num_labels = config.num_labels 534 | self.hidden_size = config.hidden_size 535 | 536 | # ELECTRA 모델 선언 537 | self.max_seq_length = config.max_position_embeddings 538 | self.electra = ElectraModel(config) 539 | self.max_sent_num = config.max_sent_num 540 | self.num_samples = config.num_samples 541 | self.start_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 542 | self.end_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 543 | 544 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels+1) 545 | 546 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 547 | self.gru = AttentionDecoder(self.hidden_size) 548 | 549 | # ELECTRA weight 초기화 550 | self.init_weights() 551 | 552 | def extract_answerable_sent_idx(self, answer_positions, sent_masks): 553 | batch_size = sent_masks.size(0) 554 | answerable_sent_id = torch.gather(sent_masks, 1, answer_positions.unsqueeze(-1)).squeeze(-1) 555 | expanded_answerable_sent_id = F.one_hot(answerable_sent_id, num_classes=self.max_sent_num).reshape( 556 | batch_size * self.max_sent_num) 557 | expanded_answerable_sent_id[0::self.max_sent_num] = 0 558 | answerable_sent_ids= torch.where(expanded_answerable_sent_id ==1)[0] 559 | return answerable_sent_ids 560 | 561 | def _path_generate(self, hop_score, p_mask, n_mask, answerable_sent_num=None): 562 | if answerable_sent_num: 563 | n_negative_hop_score = hop_score + n_mask 564 | n_positive_hop_score = hop_score + p_mask 565 | _, n_hop_negative_path_idx = torch.sort(n_negative_hop_score, descending=True) 566 | n_hop_negative_path_idx = n_hop_negative_path_idx[:3] 567 | _, n_hop_positive_path_idx = torch.sort(n_positive_hop_score, descending=True) 568 | n_hop_positive_path_idx = n_hop_positive_path_idx[:answerable_sent_num] 569 | path_idx = torch.cat([n_hop_negative_path_idx, n_hop_positive_path_idx]) 570 | path_label = torch.cat([torch.ones([3]).float(), torch.ones([answerable_sent_num])]).cuda() 571 | path_logits = 
torch.gather(hop_score, 0, path_idx) 572 | return path_logits, path_idx, path_label 573 | 574 | 575 | else: 576 | n_hop_score = hop_score + n_mask 577 | _, n_hop_path_idx = torch.sort(n_hop_score, descending=True) 578 | n_hop_path_idx = n_hop_path_idx[:1] 579 | path_logits = torch.gather(hop_score, 0, n_hop_path_idx) 580 | return path_logits, n_hop_path_idx, None 581 | 582 | def _cross_entropy(self, logits): 583 | loss1 = -(F.log_softmax(logits, dim=-1)) 584 | # print(to_list(loss1)) 585 | loss = loss1 586 | return loss 587 | 588 | def forward( 589 | self, 590 | input_ids=None, 591 | 592 | ################## 593 | attention_mask=None, 594 | token_type_ids=None, 595 | sent_masks=None, 596 | ############## 597 | question_type=None, 598 | start_positions=None, 599 | end_positions=None, 600 | ): 601 | 602 | # ELECTRA output 저장 603 | # outputs : [1, batch_size, seq_length, hidden_size] 604 | # electra 선언 부분에 특정 옵션을 부여할 경우 출력은 다음과 같음 605 | # outputs : (last-layer hidden state, all hidden states, all attentions) 606 | # last-layer hidden state : [batch, seq_length, hidden_size] 607 | # all hidden states : [13, batch, seq_length, hidden_size] 608 | # 12가 아닌 13인 이유?? ==> 토큰의 임베딩도 output에 포함되어 return 609 | # all attentions : [12, batch, num_heads, seq_length, seq_length] 610 | batch_size = input_ids.size(0) 611 | sequence_length = input_ids.size(1) 612 | outputs = self.electra( 613 | input_ids, 614 | attention_mask=attention_mask, 615 | token_type_ids=token_type_ids, 616 | ) 617 | is_training = False 618 | if start_positions is not None: 619 | is_training=True 620 | 621 | sequence_output = outputs[0] 622 | # sequence_output : [batch_size, seq_length, hidden_size] 623 | 624 | cls_output = sequence_output[:, 0, :] 625 | # cls_output : [batch, hidden] 626 | all_cls_output = torch.sum(cls_output, dim=0, keepdim=True) / batch_size 627 | sentence_masks = F.one_hot(sent_masks, num_classes=self.max_sent_num).transpose(1, 2).float() 628 | sentence_masks[:, 0, :] = sentence_masks[:, 0, :] * attention_mask 629 | # sentence_masks : [batch, seq_length] ==> [batch, seq_len, sent_num] ==> [batch, sent_num, seq_len] 630 | # sentence_masks : [10, 512] ==> [10, 512, 40] ==> [10, 40, 512] 631 | # [sentence_masks] = [[0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, ...]] 632 | 633 | div_term = torch.sum(sentence_masks, dim=-1, keepdim=True) 634 | # div_term : [batch, sent_num, 1] 635 | 636 | div_term = div_term.masked_fill(div_term == 0, 1e-10) 637 | 638 | attention_masks = div_term.masked_fill(div_term != 1e-10, 0).masked_fill(div_term == 1e-10, 1).view(1, 639 | batch_size * self.max_sent_num).bool() 640 | sentence_representation = sentence_masks.bmm(sequence_output) 641 | sentence_representation = sentence_representation / div_term 642 | # sentence_representation : [batch, sent_num, hidden] 643 | 644 | sentence_representation = self.dropout( 645 | sentence_representation.reshape(1, batch_size * self.max_sent_num, self.hidden_size)) 646 | 647 | attention_masks = attention_masks.float() 648 | # [0, 0, 0, 1, 1, [1, 0, 0, 0, 0 649 | # 0, 0, 0, 0, 1] 1, 0, 0, 0, 0] 650 | 651 | 652 | last_hidden = None 653 | # decoder_inputs = all_cls_output.unsqueeze(1) 654 | decoder_inputs = torch.sum(sentence_representation[:, 0::self.max_sent_num, :], 1, keepdim=True) / batch_size 655 | encoder_outputs = sentence_representation 656 | attention_masks[:, 0::self.max_sent_num] = 1 657 | mm = 1-attention_masks 658 | mm = mm.unsqueeze(1).expand(-1, 3, -1) 659 | attention_masks = attention_masks.masked_fill(attention_masks == 1, 
-1e10).masked_fill(attention_masks == 0, 0).unsqueeze(0) 660 | # if is_training: 661 | # decoder_inputs = decoder_inputs.expand(self.num_samples, -1, -1) 662 | # encoder_outputs = encoder_outputs.expand(self.num_samples, -1, -1) 663 | # attention_masks = attention_masks.expand(self.num_samples, -1, -1) 664 | # attention_masks = attention_masks.clone().detach() 665 | evidence_sentences = [] 666 | attention_scores = [] 667 | for evidence_step in range(3): 668 | decoder_inputs, last_hidden, evidence_sentence, attention_score, attention_masks = self.gru(last_hidden, decoder_inputs, encoder_outputs, attention_masks) 669 | evidence_sentences.append(evidence_sentence) 670 | attention_scores.append(attention_score) 671 | evidence_vector = decoder_inputs.squeeze(1).transpose(0, 1) 672 | evidence_sentences = torch.stack(evidence_sentences, 0) 673 | attention_scores = torch.stack(attention_scores, 0) 674 | 675 | 676 | evidence_sentences = evidence_sentences.transpose(0, 1) 677 | attention_scores = attention_scores.transpose(0, 1) 678 | if is_training: 679 | evidence = evidence_sentences.unsqueeze(0) 680 | else: 681 | evidence = evidence_sentences.unsqueeze(0).expand(batch_size, -1, -1) 682 | 683 | start_representation = self.start_dense(sequence_output) 684 | end_representation = self.end_dense(sequence_output) 685 | 686 | start_logits = start_representation.matmul(evidence_vector) 687 | end_logits = end_representation.matmul(evidence_vector) 688 | 689 | start_logits = start_logits.squeeze(-1) 690 | end_logits = end_logits.squeeze(-1) 691 | 692 | qt_logits = self.qa_outputs(decoder_inputs.squeeze(1)) 693 | # outputs = (start_logits, end_logits) 694 | if not is_training: 695 | qt_logits = torch.argmax(qt_logits.expand(batch_size, -1), -1) 696 | outputs = (start_logits, end_logits, qt_logits, evidence.squeeze(1)) + outputs[1:] 697 | 698 | # 학습 시 699 | if start_positions is not None and end_positions is not None: 700 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 701 | # logg_fct 선언 702 | 703 | loss_fct = CrossEntropyLoss() 704 | 705 | # start/end에 대해 loss 계산 706 | start_loss = loss_fct(start_logits, start_positions) 707 | end_loss = loss_fct(end_logits, end_positions) 708 | qt_loss = loss_fct(qt_logits, question_type) 709 | # 최종 loss 계산 710 | span_loss = (start_loss + end_loss) / 2 + qt_loss 711 | 712 | # outputs : (total_loss, start_logits, end_logits) 713 | outputs = (span_loss, attention_scores, mm) + outputs 714 | 715 | return outputs # (loss), start_logits, end_logits 716 | 717 | class ElectraForQuestionAnswering_1204(ElectraPreTrainedModel): 718 | def __init__(self, config): 719 | super(ElectraForQuestionAnswering_1204, self).__init__(config) 720 | # 분류 해야할 라벨 개수 (start/end) 721 | self.num_labels = config.num_labels 722 | self.hidden_size = config.hidden_size 723 | 724 | # ELECTRA 모델 선언 725 | self.max_seq_length = config.max_position_embeddings 726 | self.electra = ElectraModel(config) 727 | self.max_sent_num = config.max_sent_num 728 | self.max_dec_len = config.max_dec_len 729 | self.num_samples = config.num_samples 730 | self.start_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 731 | self.end_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 732 | 733 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 734 | 735 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 736 | self.gru = SampledAttentionDecoder1204(self.hidden_size) 737 | 738 | # ELECTRA weight 초기화 739 | 
self.init_weights() 740 | 741 | def extract_answerable_sent_idx(self, answer_positions, sent_masks): 742 | batch_size = sent_masks.size(0) 743 | answerable_sent_id = torch.gather(sent_masks, 1, answer_positions.unsqueeze(-1)).squeeze(-1) 744 | expanded_answerable_sent_id = F.one_hot(answerable_sent_id, num_classes=self.max_sent_num).reshape( 745 | batch_size * self.max_sent_num) 746 | expanded_answerable_sent_id[0::self.max_sent_num] = 0 747 | answerable_sent_ids= torch.where(expanded_answerable_sent_id ==1)[0] 748 | return answerable_sent_ids 749 | 750 | def _path_generate(self, hop_score, p_mask, n_mask, answerable_sent_num=None): 751 | if answerable_sent_num: 752 | n_negative_hop_score = hop_score + n_mask 753 | n_positive_hop_score = hop_score + p_mask 754 | _, n_hop_negative_path_idx = torch.sort(n_negative_hop_score, descending=True) 755 | n_hop_negative_path_idx = n_hop_negative_path_idx[:3] 756 | _, n_hop_positive_path_idx = torch.sort(n_positive_hop_score, descending=True) 757 | n_hop_positive_path_idx = n_hop_positive_path_idx[:answerable_sent_num] 758 | path_idx = torch.cat([n_hop_negative_path_idx, n_hop_positive_path_idx]) 759 | path_label = torch.cat([torch.ones([3]).float(), torch.ones([answerable_sent_num])]).cuda() 760 | path_logits = torch.gather(hop_score, 0, path_idx) 761 | return path_logits, path_idx, path_label 762 | 763 | 764 | else: 765 | n_hop_score = hop_score + n_mask 766 | _, n_hop_path_idx = torch.sort(n_hop_score, descending=True) 767 | n_hop_path_idx = n_hop_path_idx[:1] 768 | path_logits = torch.gather(hop_score, 0, n_hop_path_idx) 769 | return path_logits, n_hop_path_idx, None 770 | 771 | def _cross_entropy(self, logits): 772 | loss1 = -(F.log_softmax(logits, dim=-1)) 773 | # print(to_list(loss1)) 774 | loss = loss1 775 | return loss 776 | 777 | def forward( 778 | self, 779 | input_ids=None, 780 | 781 | ################## 782 | attention_mask=None, 783 | token_type_ids=None, 784 | sent_masks=None, 785 | ############## 786 | 787 | start_positions=None, 788 | end_positions=None, 789 | ): 790 | 791 | # ELECTRA output 저장 792 | # outputs : [1, batch_size, seq_length, hidden_size] 793 | # electra 선언 부분에 특정 옵션을 부여할 경우 출력은 다음과 같음 794 | # outputs : (last-layer hidden state, all hidden states, all attentions) 795 | # last-layer hidden state : [batch, seq_length, hidden_size] 796 | # all hidden states : [13, batch, seq_length, hidden_size] 797 | # 12가 아닌 13인 이유?? 
==> 토큰의 임베딩도 output에 포함되어 return 798 | # all attentions : [12, batch, num_heads, seq_length, seq_length] 799 | batch_size = input_ids.size(0) 800 | sequence_length = input_ids.size(1) 801 | outputs = self.electra( 802 | input_ids, 803 | attention_mask=attention_mask, 804 | token_type_ids=token_type_ids, 805 | ) 806 | is_training = False 807 | if start_positions is not None: 808 | is_training=True 809 | 810 | sequence_output = outputs[0] 811 | # sequence_output : [batch_size, seq_length, hidden_size] 812 | 813 | cls_output = sequence_output[:, 0, :] 814 | # cls_output : [batch, hidden] 815 | all_cls_output = torch.sum(cls_output, dim=0, keepdim=True) / batch_size 816 | sentence_masks = F.one_hot(sent_masks, num_classes=self.max_sent_num).transpose(1, 2).float() 817 | sentence_masks[:, 0, :] = sentence_masks[:, 0, :] * attention_mask 818 | # sentence_masks : [batch, seq_length] ==> [batch, seq_len, sent_num] ==> [batch, sent_num, seq_len] 819 | # sentence_masks : [10, 512] ==> [10, 512, 40] ==> [10, 40, 512] 820 | # [sentence_masks] = [[0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, ...]] 821 | 822 | div_term = torch.sum(sentence_masks, dim=-1, keepdim=True) 823 | # div_term : [batch, sent_num, 1] 824 | 825 | div_term = div_term.masked_fill(div_term == 0, 1e-10) 826 | 827 | attention_masks = div_term.masked_fill(div_term != 1e-10, 0).masked_fill(div_term == 1e-10, 1).view(1, 828 | batch_size * self.max_sent_num).bool() 829 | sentence_representation = sentence_masks.bmm(sequence_output) 830 | sentence_representation = sentence_representation / div_term 831 | # sentence_representation : [batch, sent_num, hidden] 832 | 833 | sentence_representation = self.dropout( 834 | sentence_representation.reshape(1, batch_size * self.max_sent_num, self.hidden_size)) 835 | 836 | attention_masks = attention_masks.float() 837 | # [0, 0, 0, 1, 1, [1, 0, 0, 0, 0 838 | # 0, 0, 0, 0, 1] 1, 0, 0, 0, 0] 839 | 840 | 841 | last_hidden = None 842 | # decoder_inputs = all_cls_output.unsqueeze(1) 843 | decoder_inputs = torch.sum(sentence_representation[:, 0::self.max_sent_num, :], 1, keepdim=True) / batch_size 844 | encoder_outputs = sentence_representation 845 | attention_masks[:, 0::self.max_sent_num] = 1 846 | mm = 1-attention_masks 847 | mm = mm.unsqueeze(1).expand(-1, self.max_dec_len, -1) 848 | attention_masks = attention_masks.masked_fill(attention_masks == 1, -1e10).masked_fill(attention_masks == 0, 0).unsqueeze(0) 849 | if is_training: 850 | decoder_inputs = decoder_inputs.expand(self.num_samples, -1, -1) 851 | encoder_outputs = encoder_outputs.expand(self.num_samples, -1, -1) 852 | attention_masks = attention_masks.expand(self.num_samples, -1, -1) 853 | attention_masks = attention_masks.clone().detach() 854 | evidence_sentences = [] 855 | attention_scores = [] 856 | for evidence_step in range(self.max_dec_len): 857 | decoder_inputs, last_hidden, evidence_sentence, attention_score, attention_masks = self.gru(last_hidden, decoder_inputs, encoder_outputs, attention_masks) 858 | evidence_sentences.append(evidence_sentence) 859 | attention_scores.append(attention_score) 860 | evidence_vector = decoder_inputs.squeeze(1).transpose(0, 1) 861 | evidence_sentences = torch.stack(evidence_sentences, 0) 862 | attention_scores = torch.stack(attention_scores, 0) 863 | 864 | 865 | evidence_sentences = evidence_sentences.transpose(0, 1) 866 | attention_scores = attention_scores.transpose(0, 1) 867 | if not is_training: 868 | evidence = evidence_sentences.unsqueeze(0).expand(batch_size, -1, -1) 869 | else: 870 | evidence = 
evidence_sentences 871 | start_representation = self.start_dense(sequence_output) 872 | end_representation = self.end_dense(sequence_output) 873 | 874 | start_logits = start_representation.matmul(evidence_vector) 875 | end_logits = end_representation.matmul(evidence_vector) 876 | if not is_training: 877 | start_logits = start_logits.squeeze(-1) 878 | end_logits = end_logits.squeeze(-1) 879 | 880 | 881 | # outputs = (start_logits, end_logits) 882 | outputs = (start_logits, end_logits, evidence.squeeze(1)) + outputs[1:] 883 | 884 | # 학습 시 885 | if start_positions is not None and end_positions is not None: 886 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 887 | # logg_fct 선언 888 | 889 | loss_fct = CrossEntropyLoss() 890 | start_logits = start_logits.permute(2, 0, 1) 891 | end_logits = end_logits.permute(2, 0, 1) 892 | start_positions = start_positions.unsqueeze(0).expand(self.num_samples, -1) 893 | end_positions = end_positions.unsqueeze(0).expand(self.num_samples, -1) 894 | # start/end에 대해 loss 계산 895 | start_loss = loss_fct(start_logits.reshape(batch_size*self.num_samples, sequence_length), start_positions.reshape(batch_size*self.num_samples)) 896 | end_loss = loss_fct(end_logits.reshape(batch_size*self.num_samples, sequence_length), end_positions.reshape(batch_size*self.num_samples)) 897 | 898 | # 최종 loss 계산 899 | span_loss = (start_loss + end_loss) / 2 900 | 901 | # outputs : (total_loss, start_logits, end_logits) 902 | outputs = (span_loss, attention_scores, mm) + outputs 903 | 904 | return outputs # (loss), start_logits, end_logits 905 | class ElectraForQuestionAnswering_1208(ElectraPreTrainedModel): 906 | def __init__(self, config): 907 | super(ElectraForQuestionAnswering_1208, self).__init__(config) 908 | # 분류 해야할 라벨 개수 (start/end) 909 | self.num_labels = config.num_labels 910 | self.hidden_size = config.hidden_size 911 | 912 | # ELECTRA 모델 선언 913 | self.max_seq_length = config.max_position_embeddings 914 | self.electra = ElectraModel(config) 915 | self.max_sent_num = config.max_sent_num 916 | self.max_dec_len = config.max_dec_len 917 | self.num_samples = config.num_samples 918 | self.start_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 919 | self.end_dense = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size) 920 | 921 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 922 | 923 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 924 | self.gru = SampledAttentionDecoder1204(self.hidden_size) 925 | 926 | # ELECTRA weight 초기화 927 | self.init_weights() 928 | 929 | def forward( 930 | self, 931 | input_ids=None, 932 | 933 | ################## 934 | attention_mask=None, 935 | token_type_ids=None, 936 | sent_masks=None, 937 | ############## 938 | 939 | start_positions=None, 940 | end_positions=None, 941 | ): 942 | 943 | # ELECTRA output 저장 944 | # outputs : [1, batch_size, seq_length, hidden_size] 945 | # electra 선언 부분에 특정 옵션을 부여할 경우 출력은 다음과 같음 946 | # outputs : (last-layer hidden state, all hidden states, all attentions) 947 | # last-layer hidden state : [batch, seq_length, hidden_size] 948 | # all hidden states : [13, batch, seq_length, hidden_size] 949 | # 12가 아닌 13인 이유?? 
==> 토큰의 임베딩도 output에 포함되어 return 950 | # all attentions : [12, batch, num_heads, seq_length, seq_length] 951 | batch_size = input_ids.size(0) 952 | sequence_length = input_ids.size(1) 953 | outputs = self.electra( 954 | input_ids, 955 | attention_mask=attention_mask, 956 | token_type_ids=token_type_ids, 957 | ) 958 | is_training = False 959 | if start_positions is not None: 960 | is_training=True 961 | 962 | sequence_output = outputs[0] 963 | # sequence_output : [batch_size, seq_length, hidden_size] 964 | 965 | cls_output = sequence_output[:, 0, :] 966 | # cls_output : [batch, hidden] 967 | all_cls_output = torch.sum(cls_output, dim=0, keepdim=True) / batch_size 968 | sentence_masks = F.one_hot(sent_masks, num_classes=self.max_sent_num).transpose(1, 2).float() 969 | sentence_masks[:, 0, :] = sentence_masks[:, 0, :] * attention_mask 970 | # sentence_masks : [batch, seq_length] ==> [batch, seq_len, sent_num] ==> [batch, sent_num, seq_len] 971 | # sentence_masks : [10, 512] ==> [10, 512, 40] ==> [10, 40, 512] 972 | # [sentence_masks] = [[0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, ...]] 973 | 974 | div_term = torch.sum(sentence_masks, dim=-1, keepdim=True) 975 | # div_term : [batch, sent_num, 1] 976 | 977 | div_term = div_term.masked_fill(div_term == 0, 1e-10) 978 | 979 | attention_masks = div_term.masked_fill(div_term != 1e-10, 0).masked_fill(div_term == 1e-10, 1).view(1, 980 | batch_size * self.max_sent_num).bool() 981 | sentence_representation = sentence_masks.bmm(sequence_output) 982 | sentence_representation = sentence_representation / div_term 983 | # sentence_representation : [batch, sent_num, hidden] 984 | 985 | sentence_representation = self.dropout( 986 | sentence_representation.reshape(1, batch_size * self.max_sent_num, self.hidden_size)) 987 | 988 | attention_masks = attention_masks.float() 989 | # [0, 0, 0, 1, 1, [1, 0, 0, 0, 0 990 | # 0, 0, 0, 0, 1] 1, 0, 0, 0, 0] 991 | 992 | 993 | last_hidden = None 994 | sampled_last_hidden = None 995 | # decoder_inputs = all_cls_output.unsqueeze(1) 996 | decoder_inputs = torch.sum(sentence_representation[:, 0::self.max_sent_num, :], 1, keepdim=True) / batch_size 997 | sampled_decoder_inputs = torch.sum(sentence_representation[:, 0::self.max_sent_num, :], 1, keepdim=True) / batch_size 998 | encoder_outputs = sentence_representation 999 | attention_masks[:, 0::self.max_sent_num] = 1 1000 | mm = 1-attention_masks 1001 | mm = mm.unsqueeze(1).expand(-1, self.max_dec_len, -1) 1002 | attention_masks = attention_masks.masked_fill(attention_masks == 1, -1e10).masked_fill(attention_masks == 0, 0).unsqueeze(0) 1003 | sampled_attention_masks = attention_masks.clone().detach() 1004 | evidence_sentences = [] 1005 | attention_scores = [] 1006 | sampled_evidence_sentences = [] 1007 | sampled_attention_scores = [] 1008 | for evidence_step in range(self.max_dec_len): 1009 | decoder_inputs, last_hidden, evidence_sentence, attention_score, attention_masks = self.gru(last_hidden, 1010 | decoder_inputs, 1011 | encoder_outputs, 1012 | attention_masks) 1013 | evidence_sentences.append(evidence_sentence) 1014 | attention_scores.append(attention_score) 1015 | evidence_vector = decoder_inputs.squeeze(1).transpose(0, 1) 1016 | evidence_sentences = torch.stack(evidence_sentences, 0) 1017 | attention_scores = torch.stack(attention_scores, 0) 1018 | # if is_training: 1019 | # sampled_decoder_inputs = sampled_decoder_inputs.expand(self.num_samples, -1, -1) 1020 | # encoder_outputs = encoder_outputs.expand(self.num_samples, -1, -1) 1021 | # sampled_attention_masks = 
sampled_attention_masks.expand(self.num_samples, -1, -1) 1022 | # sampled_attention_masks = sampled_attention_masks.clone().detach() 1023 | # 1024 | # for evidence_step in range(self.max_dec_len): 1025 | # sampled_decoder_inputs, sampled_last_hidden, evidence_sentence, attention_score, sampled_attention_masks = self.gru( 1026 | # sampled_last_hidden, sampled_decoder_inputs, encoder_outputs, sampled_attention_masks, is_sample=True) 1027 | # sampled_evidence_sentences.append(evidence_sentence) 1028 | # sampled_attention_scores.append(attention_score) 1029 | # sampled_evidence_vector = sampled_decoder_inputs.squeeze(1).transpose(0, 1) 1030 | # sampled_evidence_sentences = torch.stack(sampled_evidence_sentences, 0) 1031 | # sampled_attention_scores = torch.stack(sampled_attention_scores, 0) 1032 | # 1033 | # evidence_vector = torch.cat([evidence_vector, sampled_evidence_vector], -1) 1034 | # evidence_sentences = torch.cat([evidence_sentences, sampled_evidence_sentences], -1) 1035 | # attention_scores = torch.cat([attention_scores, sampled_attention_scores], 1) 1036 | 1037 | evidence_sentences = evidence_sentences.transpose(0, 1) 1038 | attention_scores = attention_scores.transpose(0, 1) 1039 | if not is_training: 1040 | evidence = evidence_sentences.unsqueeze(0).expand(batch_size, -1, -1) 1041 | else: 1042 | evidence = evidence_sentences 1043 | start_representation = self.start_dense(sequence_output) 1044 | end_representation = self.end_dense(sequence_output) 1045 | 1046 | start_logits = start_representation.matmul(evidence_vector) 1047 | end_logits = end_representation.matmul(evidence_vector) 1048 | if not is_training: 1049 | start_logits = start_logits.squeeze(-1) 1050 | end_logits = end_logits.squeeze(-1) 1051 | 1052 | 1053 | # outputs = (start_logits, end_logits) 1054 | outputs = (start_logits, end_logits, evidence.squeeze(1)) + outputs[1:] 1055 | 1056 | # 학습 시 1057 | if start_positions is not None and end_positions is not None: 1058 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1059 | # logg_fct 선언 1060 | 1061 | loss_fct = CrossEntropyLoss() 1062 | start_logits = start_logits.permute(2, 0, 1) 1063 | end_logits = end_logits.permute(2, 0, 1) 1064 | start_positions = start_positions.unsqueeze(0).expand((1+self.num_samples), -1) 1065 | end_positions = end_positions.unsqueeze(0).expand((1+self.num_samples), -1) 1066 | # start/end에 대해 loss 계산 1067 | start_loss = loss_fct(start_logits.reshape(batch_size*(1+self.num_samples), sequence_length), start_positions.reshape(batch_size*(1+self.num_samples))) 1068 | end_loss = loss_fct(end_logits.reshape(batch_size*(1+self.num_samples), sequence_length), end_positions.reshape(batch_size*(1+self.num_samples))) 1069 | 1070 | # 최종 loss 계산 1071 | span_loss = (start_loss + end_loss) / 2 1072 | 1073 | # outputs : (total_loss, start_logits, end_logits) 1074 | outputs = (span_loss, attention_scores, mm) + outputs 1075 | 1076 | return outputs # (loss), start_logits, end_logits 1077 | --------------------------------------------------------------------------------
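A note on the sentence pooling shared by every model variant above: `sent_masks` labels each token with its sentence index, `F.one_hot(...).transpose(1, 2)` turns that into one selection row per sentence, and a batched matmul divided by the per-sentence token count gives mean-pooled sentence vectors. The sketch below reproduces just that pooling with toy sizes; it uses `clamp` instead of the `masked_fill` guard used above and is not a drop-in extract of the forward pass.

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration only.
batch_size, seq_len, hidden_size, max_sent_num = 2, 6, 4, 3

sequence_output = torch.randn(batch_size, seq_len, hidden_size)  # encoder token states
sent_masks = torch.tensor([[0, 1, 1, 2, 2, 2],                    # sentence id per token
                           [0, 1, 1, 1, 2, 2]])

# (batch, seq_len) -> (batch, seq_len, max_sent_num) -> (batch, max_sent_num, seq_len)
sentence_masks = F.one_hot(sent_masks, num_classes=max_sent_num).transpose(1, 2).float()

# Number of tokens per sentence; guard against empty sentences before dividing.
div_term = sentence_masks.sum(dim=-1, keepdim=True).clamp(min=1e-10)

# Mean-pooled sentence representations: (batch, max_sent_num, hidden)
sentence_representation = sentence_masks.bmm(sequence_output) / div_term
print(sentence_representation.shape)  # torch.Size([2, 3, 4])
```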
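The `_1204` and `_1208` variants differ from the earlier ones mainly in the span loss: each sampled evidence vector yields its own logit column, so the logits carry an extra sample dimension and the gold start/end positions are tiled across it before one cross-entropy call. The sketch below reproduces only that index bookkeeping with random logits and arbitrary sizes; it is not the full model.

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_samples = 2, 8, 3  # toy sizes

# (batch, seq_len, num_samples): one start-logit column per sampled evidence path.
start_logits = torch.randn(batch_size, seq_len, num_samples)
start_positions = torch.randint(0, seq_len, (batch_size,))

loss_fct = CrossEntropyLoss()

# (batch, seq_len, samples) -> (samples, batch, seq_len), then flatten samples * batch
start_logits = start_logits.permute(2, 0, 1)
start_positions = start_positions.unsqueeze(0).expand(num_samples, -1)

start_loss = loss_fct(
    start_logits.reshape(batch_size * num_samples, seq_len),
    start_positions.reshape(batch_size * num_samples),
)
print(start_loss)  # scalar: mean over every (sample, example) pair
```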