├── README.md
├── data_refine.py
├── emji_tokenizer
│   ├── emji_tokenizer-merges.txt
│   ├── emji_tokenizer-vocab.json
│   └── model.json
├── run_qg.py
└── src
    ├── functions
    │   ├── __pycache__
    │   │   ├── rouge.cpython-37.pyc
    │   │   └── utils.cpython-37.pyc
    │   ├── mrc_processor.py
    │   ├── processor.py
    │   ├── rouge.py
    │   └── utils.py
    └── model
        ├── __pycache__
        │   ├── custom_bart.cpython-37.pyc
        │   ├── main_functions.cpython-37.pyc
        │   └── model.cpython-37.pyc
        ├── custom_bart.py
        ├── main_functions.py
        └── model.py

/README.md:
--------------------------------------------------------------------------------
1 | # Question_Generation
2 | 
3 | # Dependencies
4 | * python 3.7
5 | * PyTorch 1.6.0
6 | * Transformers 4.3.3
7 | * AttrDict
8 | 
9 | # Model Architecture
10 | * KoBART (`BartForConditionalGeneration`): the encoder reads the title, answer, and passage; the decoder generates the question.
11 | 
12 | # Data
13 | * KLUE Machine Reading Comprehension [Download](https://klue-benchmark.com/tasks/72/data/download)
14 | 
15 | # Train & Test
16 | * python3.7 run_qg.py --train_file [train_file] --test_file [test_file] --from_init_weight True --do_train True
17 | * python3.7 run_qg.py --test_file [test_file] --do_evaluate True
18 | 
--------------------------------------------------------------------------------
/data_refine.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | # f_list = ['train', 'test']
3 | #
4 | # for f_name in f_list:
5 | #     with open('./data/{}.txt'.format(f_name),'r',encoding='utf8') as infile, open('./data/baseline_{}with_ans_symbol_and_title.txt'.format(f_name),'w',encoding='utf8') as outfile:
6 | #         data_dict = {}
7 | #         for line in infile:
8 | #             parsed_context, answer, question, orig_context = line.split('\t')
9 | #             title, context = parsed_context.split(" [title] ")
10 | #             title = ''.join(title.split()).replace("_", " ").strip()
11 | #             # context = ''.join(context.split()).replace("_", " ").strip()
12 | #             context = ''.join(context.split()).replace("_", " ").replace(title, "").strip()
13 | #             answer = ''.join(answer.split()).replace("_", " ").strip()
14 | #             # answer_start = context.index("[answer]")
15 | #             context = context.replace("[answer]", "")
16 | #             # question = ''.join(question.split()).replace("_", " ").strip()
17 | #             question = ''.join(question.split()).replace("_", " ").replace(title, "").strip()
18 | #             # print(title)
19 | #             # print(context)
20 | #             # print(answer)
21 | #             # print(question)
22 | #             outfile.write("\t".join([title, answer, context, question]) + '\n')
23 | # #
24 | # # from kobart import get_pytorch_kobart_model, get_kobart_tokenizer
25 | # # get_kobart_tokenizer(".")
26 | # # get_pytorch_kobart_model(cachedir=".")
27 | 
28 | ###########################################################################################################################
29 | import json
30 | import nltk
31 | 
32 | def sent_tokenizer(context):
33 |     char_to_sent_id = []
34 |     refine_context = context.split("\n")
35 |     result_context = []
36 |     for _, r_context in enumerate(refine_context):
37 |         if not r_context:
38 |             rr_context = [""]
39 |         else:
40 |             rr_context = nltk.sent_tokenize(r_context)
41 |         sent_id = len(result_context)
42 |         for offset, sub_seq in enumerate(rr_context):
43 |             char_to_sent_id += [sent_id+offset]*(len(sub_seq)+1)
44 |         result_context += rr_context
45 |     refine_context = " ".join(result_context)
46 |     return refine_context, result_context, char_to_sent_id
47 |     # a = len(context.replace("\n", " "))
48 |     # b = len(' '.join(result_context))
49 |     # for idx, r_context in enumerate(refine_context):
50 |     #     if len(context[:len(" ".join(refine_context[:idx+1]))]) != len(" ".join(refine_context[:idx+1])):
51 |     # 
print("#####################################") 52 | # print(idx) 53 | # print(context[:len(" ".join(refine_context[:idx+1]))]) 54 | # print(" ".join(refine_context[:idx+1])) 55 | # 56 | # if len(context.replace("\n", " ")) != len(' '.join(result_context)): 57 | # for e in range(len(context)): 58 | # if context.replace("\n", " ")[e] != ' '.join(result_context)[e]: 59 | # print(context.replace("\n", " ")[e-10:e+10]) 60 | # print(' '.join(result_context)[e-10:e+10]) 61 | # print(e) 62 | # print("?") 63 | def process(f_name, outfile): 64 | print('\n\n\n', f_name) 65 | 66 | num = 0 67 | nnum = 0 68 | a = 0 69 | b = 0 70 | # ofile = open('./ai_data/refine_{}.json'.format(f_name), 'w', encoding='utf8') 71 | with open('./ai_data/{}.json'.format(f_name), 'r', encoding='utf8') as infile: 72 | data_dict = json.load(infile) 73 | result_dict = {"data":[]} 74 | for document in tqdm(data_dict["data"]): 75 | title = document["title"] 76 | document_dict = {"title":title, "paragraphs":[]} 77 | for paragraph in document["paragraphs"]: 78 | 79 | context = " ".join([e for e in paragraph['context'].replace("\n", " ").split() if e]) 80 | refine_context, split_context, char_to_sent_id = sent_tokenizer(context) 81 | paragraph_dict = {"context":refine_context, "split_context":split_context, "qas":[]} 82 | for qas in paragraph["qas"]: 83 | id = qas['id'] 84 | question = qas['question'] 85 | level = qas['level'] 86 | 87 | answer = qas['answers'][0] 88 | answer_text = answer['text'] 89 | answer_start = answer['answer_start'] 90 | keyword_text = answer['keyword'] 91 | keyword_start = answer['keyword_start'] 92 | try: 93 | nnum+=1 94 | answerable_context = split_context[char_to_sent_id[answer_start]] 95 | # evidence_context = split_context[char_to_sent_id[keyword_start]] 96 | evidence_context = ' '.join(split_context[char_to_sent_id[max(0, answer_start)]-5: char_to_sent_id[answer_start]]) 97 | except: 98 | num+=1 99 | continue 100 | if answer_text not in answerable_context: 101 | continue 102 | # if keyword_text not in evidence_context: 103 | # continue 104 | if answer_text not in refine_context: 105 | continue 106 | #outfile.write("\t".join([title, answer, context, question]) + '\n') 107 | outfile.write("\t".join([id, level, title, answer_text.replace("\n", " "), evidence_context.replace("\n", " "), answerable_context.replace(answer_text, "").replace("\n", " "), question.replace("\n", " ")]) + '\n') 108 | # qas_dict = {'question':question, "id":id, "level":level, "answers":[{"text":answer_text, "answer_start":refine_context.index(answer_text), "keyword_text":keyword_text, "keyword_start":keyword_start}]} 109 | # 110 | # paragraph_dict["qas"].append(qas_dict) 111 | # document_dict["paragraphs"].append(paragraph_dict) 112 | # result_dict["data"].append(document_dict) 113 | # json.dump(result_dict, ofile, indent='\t', ensure_ascii=False) 114 | # 115 | # for data in data_dict: 116 | # q_id = data['question_id'].replace("\n", "") 117 | # question = data["question"].replace("\n", "") 118 | # if not data['evidence_sent']: 119 | # continue 120 | # evidence_sent = data['evidence_sent'][0].replace("\n", "") 121 | # 122 | # answer = answer_dict[q_id][0].replace("\n", "") 123 | # title = answer_dict[q_id][1] 124 | # context = [e.replace("\n", "") for e in data['splited_sent'] if answer in e.replace("\n", "")] 125 | # 126 | # if not context: 127 | # continue 128 | # context = " ".join(context).replace(answer, "") 129 | # num +=1 130 | # outfile.write("\t".join([title, answer, evidence_sent, context, question]) + '\n') 131 | 132 | # 
def answer_dict_load(f_name): 133 | # result_dict = {} 134 | # with open('./ai_data/{}.json'.format(f_name),'r',encoding='utf8') as infile: 135 | # data_dict = json.load(infile)["data"] 136 | # for document in data_dict: 137 | # title = document['title'] 138 | # for qas in document['qas']: 139 | # q_id = qas['id'] 140 | # answer = qas['answer']['answer_text'] 141 | # result_dict[q_id] = [answer, title] 142 | # 143 | # return result_dict 144 | f_list = ["all_outdomain_aug"] 145 | # 146 | for f_name in f_list: 147 | outfile = open('./processed_ai_data/{}.txt'.format(f_name), 'w', encoding='utf8') 148 | process(f_name, outfile) 149 | # 150 | # from kobart import get_pytorch_kobart_model, get_kobart_tokenizer 151 | # get_kobart_tokenizer(".") 152 | # get_pytorch_kobart_model(cachedir=".") -------------------------------------------------------------------------------- /run_qg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from src.model.model import BartForConditionalGeneration 3 | from transformers import PreTrainedTokenizerFast 4 | from transformers.models.bart.configuration_bart import BartConfig 5 | from src.model.main_functions import train, evaluate, make_file 6 | from attrdict import AttrDict 7 | import os 8 | 9 | from transformers import generation_utils 10 | def create_model(args): 11 | config = BartConfig.from_pretrained(args.model_path) 12 | tokenizer = PreTrainedTokenizerFast.from_pretrained(args.model_path) 13 | 14 | init_weight = args.model_path if args.from_init_weight else os.path.join(args.output_dir, "checkpoint-{}".format(args.checkpoint)) 15 | print("Init Weight From {}".format(init_weight)) 16 | model = BartForConditionalGeneration.from_pretrained(init_weight, config=config) 17 | 18 | model.to(args.device) 19 | 20 | return model, tokenizer 21 | 22 | def main(cil_args): 23 | args = AttrDict(vars(cil_args)) 24 | args.device = 'cuda' 25 | if not os.path.exists(args.output_dir): 26 | os.makedirs(args.output_dir) 27 | print(args) 28 | model, tokenizer = create_model(args) 29 | 30 | if args.do_train: 31 | train(args, model, tokenizer) 32 | elif args.do_evaluate: 33 | evaluate(args, model, tokenizer) 34 | elif args.do_predict: 35 | make_file(args, model, tokenizer) 36 | 37 | 38 | if __name__ == '__main__': 39 | cli_parser = argparse.ArgumentParser() 40 | 41 | # Path 42 | cli_parser.add_argument('--model_path', type=str, default='hyunwoongko/kobart', help='kobart model path') 43 | cli_parser.add_argument('--train_file', type=str, default='processed_ai_data/indomain_train.txt', help='train file') 44 | cli_parser.add_argument('--test_file', type=str, default='processed_ai_data/indomain_test.txt', help='test file') 45 | cli_parser.add_argument('--predict_file', type=str, default='processed_ai_data/all_outdomain_aug.txt', help='test file') 46 | cli_parser.add_argument('--tokenizer_path', type=str, default='emji_tokenizer', help='tokenizer') 47 | cli_parser.add_argument('--output_dir', type=str, default='./indomain', help='tokenizer') 48 | 49 | 50 | # Training Parameter 51 | cli_parser.add_argument("--weight_decay", type=float, default=0.0) 52 | 53 | cli_parser.add_argument('--num_workers', type=int, default=5, help='num of worker for dataloader (# of CPU Cores for Pre-processing)') 54 | cli_parser.add_argument('--learning_rate', type=float, default=5e-5, help='The initial learning rate') 55 | cli_parser.add_argument("--adam_epsilon", type=int, default=1e-10) 56 | cli_parser.add_argument("--gradient_accumulation_steps", 
type=int, default=1) 57 | cli_parser.add_argument("--max_grad_norm", type=int, default=1.0) 58 | 59 | cli_parser.add_argument("--seed", type=int, default=42) 60 | cli_parser.add_argument("--save_steps", type=int, default=2000) 61 | cli_parser.add_argument("--train_epochs", type=int, default=40) 62 | cli_parser.add_argument("--checkpoint", type=int, default=20000) 63 | cli_parser.add_argument('--batch_size', type=int, default=16, help='batch size for training') 64 | cli_parser.add_argument('--max_seq_len', type=int, default=128, help='max seq len') 65 | cli_parser.add_argument('--warmup_ratio', type=float, default=0.1, help='warmup ratio') 66 | 67 | cli_parser.add_argument('--from_init_weight', type=bool, default=False, help='init weight var') 68 | cli_parser.add_argument('--do_train', type=bool, default=False, help='Train Mode Bool Variable') 69 | cli_parser.add_argument('--do_evaluate', type=bool, default=False, help='Evaluate Mode Bool Variable') 70 | cli_parser.add_argument('--do_predict', type=bool, default=True, help='Predict Mode Bool Variable') 71 | cli_args = cli_parser.parse_args() 72 | main(cli_args) -------------------------------------------------------------------------------- /src/functions/__pycache__/rouge.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/Question_Generation/5ccd2388aa679d3ccff908efedf45b1196fcad90/src/functions/__pycache__/rouge.cpython-37.pyc -------------------------------------------------------------------------------- /src/functions/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/Question_Generation/5ccd2388aa679d3ccff908efedf45b1196fcad90/src/functions/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /src/functions/mrc_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from functools import partial 5 | from multiprocessing import Pool, cpu_count 6 | 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from transformers.file_utils import is_tf_available, is_torch_available 11 | from transformers.data.processors.utils import DataProcessor 12 | 13 | 14 | if is_torch_available(): 15 | import torch 16 | from torch.utils.data import TensorDataset 17 | 18 | if is_tf_available(): 19 | import tensorflow as tf 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text): 25 | """Returns tokenized answer spans that better match the annotated answer.""" 26 | tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text)) 27 | 28 | for new_start in range(input_start, input_end + 1): 29 | for new_end in range(input_end, new_start - 1, -1): 30 | text_span = " ".join(doc_tokens[new_start : (new_end + 1)]) 31 | if text_span == tok_answer_text: 32 | return (new_start, new_end) 33 | 34 | return (input_start, input_end) 35 | 36 | def _is_whitespace(c): 37 | if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F: 38 | return True 39 | return False 40 | 41 | 42 | def squad_convert_example_to_features( 43 | example, max_seq_length, doc_stride, max_query_length, padding_strategy, is_training 44 | ): 45 | features = [] 46 | if is_training and not example.is_impossible: 47 | # Get start and end position 48 | 
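        # example.start_position / example.end_position are word-level indices that
        # SquadExample derives from the character-level answer_start, so the block
        # below only needs to re-join those words and skip the example if the
        # annotated answer text cannot be recovered from them.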
start_position = example.start_position 49 | end_position = example.end_position 50 | 51 | # If the answer cannot be found in the text, then skip this example. 52 | actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)]) 53 | cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text)) 54 | if actual_text.find(cleaned_answer_text) == -1: 55 | logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text) 56 | return [] 57 | 58 | tok_to_orig_index = [] 59 | orig_to_tok_index = [] 60 | all_doc_tokens = [] 61 | for (i, token) in enumerate(example.doc_tokens): 62 | orig_to_tok_index.append(len(all_doc_tokens)) 63 | if tokenizer.__class__.__name__ in [ 64 | "RobertaTokenizer", 65 | "LongformerTokenizer", 66 | "BartTokenizer", 67 | "RobertaTokenizerFast", 68 | "LongformerTokenizerFast", 69 | "BartTokenizerFast", 70 | ]: 71 | sub_tokens = tokenizer.tokenize(token, add_prefix_space=True) 72 | else: 73 | sub_tokens = tokenizer.tokenize(token) 74 | for sub_token in sub_tokens: 75 | tok_to_orig_index.append(i) 76 | all_doc_tokens.append(sub_token) 77 | 78 | if is_training and not example.is_impossible: 79 | tok_start_position = orig_to_tok_index[example.start_position] 80 | if example.end_position < len(example.doc_tokens) - 1: 81 | tok_end_position = orig_to_tok_index[example.end_position + 1] - 1 82 | else: 83 | tok_end_position = len(all_doc_tokens) - 1 84 | 85 | (tok_start_position, tok_end_position) = _improve_answer_span( 86 | all_doc_tokens, tok_start_position, tok_end_position, tokenizer, example.answer_text 87 | ) 88 | 89 | spans = [] 90 | 91 | truncated_query = tokenizer.encode( 92 | example.question_text, add_special_tokens=False, truncation=True, max_length=max_query_length 93 | ) 94 | 95 | # Tokenizers who insert 2 SEP tokens in-between & need to have special handling 96 | # in the way they compute mask of added tokens. 
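    # A worked example of the bookkeeping below, assuming a BERT-style tokenizer with
    # model_max_length=512: max_len_single_sentence is 510 and max_len_sentences_pair
    # is 509, so sequence_added_tokens = 2 ([CLS], [SEP]) and
    # sequence_pair_added_tokens = 3 ([CLS], [SEP], [SEP]).
    # The while-loop further down slides a window over the document tokens: each span
    # keeps at most max_seq_length - len(truncated_query) - sequence_pair_added_tokens
    # document tokens, and consecutive spans start doc_stride tokens apart, so they overlap.
    # Note that MULTI_SEP_TOKENS_TOKENIZERS_SET, TruncationStrategy and whitespace_tokenize
    # are used below but not imported in this file; they are defined in (or re-exported by)
    # transformers' data/processors/squad module, from which this function was adapted.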
97 | tokenizer_type = type(tokenizer).__name__.replace("Tokenizer", "").lower() 98 | sequence_added_tokens = ( 99 | tokenizer.model_max_length - tokenizer.max_len_single_sentence + 1 100 | if tokenizer_type in MULTI_SEP_TOKENS_TOKENIZERS_SET 101 | else tokenizer.model_max_length - tokenizer.max_len_single_sentence 102 | ) 103 | sequence_pair_added_tokens = tokenizer.model_max_length - tokenizer.max_len_sentences_pair 104 | 105 | span_doc_tokens = all_doc_tokens 106 | while len(spans) * doc_stride < len(all_doc_tokens): 107 | 108 | # Define the side we want to truncate / pad and the text/pair sorting 109 | if tokenizer.padding_side == "right": 110 | texts = truncated_query 111 | pairs = span_doc_tokens 112 | truncation = TruncationStrategy.ONLY_SECOND.value 113 | else: 114 | texts = span_doc_tokens 115 | pairs = truncated_query 116 | truncation = TruncationStrategy.ONLY_FIRST.value 117 | 118 | encoded_dict = tokenizer.encode_plus( # TODO(thom) update this logic 119 | texts, 120 | pairs, 121 | truncation=truncation, 122 | padding=padding_strategy, 123 | max_length=max_seq_length, 124 | return_overflowing_tokens=True, 125 | stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens, 126 | return_token_type_ids=True, 127 | ) 128 | 129 | paragraph_len = min( 130 | len(all_doc_tokens) - len(spans) * doc_stride, 131 | max_seq_length - len(truncated_query) - sequence_pair_added_tokens, 132 | ) 133 | 134 | if tokenizer.pad_token_id in encoded_dict["input_ids"]: 135 | if tokenizer.padding_side == "right": 136 | non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)] 137 | else: 138 | last_padding_id_position = ( 139 | len(encoded_dict["input_ids"]) - 1 - encoded_dict["input_ids"][::-1].index(tokenizer.pad_token_id) 140 | ) 141 | non_padded_ids = encoded_dict["input_ids"][last_padding_id_position + 1 :] 142 | 143 | else: 144 | non_padded_ids = encoded_dict["input_ids"] 145 | 146 | tokens = tokenizer.convert_ids_to_tokens(non_padded_ids) 147 | 148 | token_to_orig_map = {} 149 | for i in range(paragraph_len): 150 | index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i 151 | token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i] 152 | 153 | encoded_dict["paragraph_len"] = paragraph_len 154 | encoded_dict["tokens"] = tokens 155 | encoded_dict["token_to_orig_map"] = token_to_orig_map 156 | encoded_dict["truncated_query_with_special_tokens_length"] = len(truncated_query) + sequence_added_tokens 157 | encoded_dict["token_is_max_context"] = {} 158 | encoded_dict["start"] = len(spans) * doc_stride 159 | encoded_dict["length"] = paragraph_len 160 | 161 | spans.append(encoded_dict) 162 | 163 | if "overflowing_tokens" not in encoded_dict or ( 164 | "overflowing_tokens" in encoded_dict and len(encoded_dict["overflowing_tokens"]) == 0 165 | ): 166 | break 167 | span_doc_tokens = encoded_dict["overflowing_tokens"] 168 | 169 | for doc_span_index in range(len(spans)): 170 | for j in range(spans[doc_span_index]["paragraph_len"]): 171 | is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j) 172 | index = ( 173 | j 174 | if tokenizer.padding_side == "left" 175 | else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j 176 | ) 177 | spans[doc_span_index]["token_is_max_context"][index] = is_max_context 178 | 179 | for span in spans: 180 | # Identify the position of the CLS token 181 | cls_index = 
span["input_ids"].index(tokenizer.cls_token_id) 182 | 183 | # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) 184 | # Original TF implem also keep the classification token (set to 0) 185 | p_mask = np.ones_like(span["token_type_ids"]) 186 | if tokenizer.padding_side == "right": 187 | p_mask[len(truncated_query) + sequence_added_tokens :] = 0 188 | else: 189 | p_mask[-len(span["tokens"]) : -(len(truncated_query) + sequence_added_tokens)] = 0 190 | 191 | pad_token_indices = np.where(span["input_ids"] == tokenizer.pad_token_id) 192 | special_token_indices = np.asarray( 193 | tokenizer.get_special_tokens_mask(span["input_ids"], already_has_special_tokens=True) 194 | ).nonzero() 195 | 196 | p_mask[pad_token_indices] = 1 197 | p_mask[special_token_indices] = 1 198 | 199 | # Set the cls index to 0: the CLS index can be used for impossible answers 200 | p_mask[cls_index] = 0 201 | 202 | span_is_impossible = example.is_impossible 203 | start_position = 0 204 | end_position = 0 205 | if is_training and not span_is_impossible: 206 | # For training, if our document chunk does not contain an annotation 207 | # we throw it out, since there is nothing to predict. 208 | doc_start = span["start"] 209 | doc_end = span["start"] + span["length"] - 1 210 | out_of_span = False 211 | 212 | if not (tok_start_position >= doc_start and tok_end_position <= doc_end): 213 | out_of_span = True 214 | 215 | if out_of_span: 216 | start_position = cls_index 217 | end_position = cls_index 218 | span_is_impossible = True 219 | else: 220 | if tokenizer.padding_side == "left": 221 | doc_offset = 0 222 | else: 223 | doc_offset = len(truncated_query) + sequence_added_tokens 224 | 225 | start_position = tok_start_position - doc_start + doc_offset 226 | end_position = tok_end_position - doc_start + doc_offset 227 | 228 | features.append( 229 | SquadFeatures( 230 | span["input_ids"], 231 | span["attention_mask"], 232 | span["token_type_ids"], 233 | cls_index, 234 | p_mask.tolist(), 235 | example_index=0, # Can not set unique_id and example_index here. They will be set after multiple processing. 236 | unique_id=0, 237 | paragraph_len=span["paragraph_len"], 238 | token_is_max_context=span["token_is_max_context"], 239 | tokens=span["tokens"], 240 | token_to_orig_map=span["token_to_orig_map"], 241 | start_position=start_position, 242 | end_position=end_position, 243 | is_impossible=span_is_impossible, 244 | qas_id=example.qas_id, 245 | ) 246 | ) 247 | return features 248 | 249 | def squad_convert_example_to_features_init(tokenizer_for_convert: PreTrainedTokenizerBase): 250 | global tokenizer 251 | tokenizer = tokenizer_for_convert 252 | 253 | 254 | def squad_convert_examples_to_features( 255 | examples, 256 | tokenizer, 257 | max_seq_length, 258 | doc_stride, 259 | max_query_length, 260 | is_training, 261 | padding_strategy="max_length", 262 | return_dataset=False, 263 | threads=1, 264 | tqdm_enabled=True, 265 | ): 266 | """ 267 | Converts a list of examples into a list of features that can be directly given as input to a model. It is 268 | model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs. 269 | 270 | Args: 271 | examples: list of :class:`~transformers.data.processors.squad.SquadExample` 272 | tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer` 273 | max_seq_length: The maximum sequence length of the inputs. 
274 | doc_stride: The stride used when the context is too large and is split across several features. 275 | max_query_length: The maximum length of the query. 276 | is_training: whether to create features for model evaluation or model training. 277 | padding_strategy: Default to "max_length". Which padding strategy to use 278 | return_dataset: Default False. Either 'pt' or 'tf'. 279 | if 'pt': returns a torch.data.TensorDataset, if 'tf': returns a tf.data.Dataset 280 | threads: multiple processing threads. 281 | 282 | 283 | Returns: 284 | list of :class:`~transformers.data.processors.squad.SquadFeatures` 285 | 286 | Example:: 287 | 288 | processor = SquadV2Processor() 289 | examples = processor.get_dev_examples(data_dir) 290 | 291 | features = squad_convert_examples_to_features( 292 | examples=examples, 293 | tokenizer=tokenizer, 294 | max_seq_length=args.max_seq_length, 295 | doc_stride=args.doc_stride, 296 | max_query_length=args.max_query_length, 297 | is_training=not evaluate, 298 | ) 299 | """ 300 | # Defining helper methods 301 | features = [] 302 | 303 | threads = min(threads, cpu_count()) 304 | with Pool(threads, initializer=squad_convert_example_to_features_init, initargs=(tokenizer,)) as p: 305 | annotate_ = partial( 306 | squad_convert_example_to_features, 307 | max_seq_length=max_seq_length, 308 | doc_stride=doc_stride, 309 | max_query_length=max_query_length, 310 | padding_strategy=padding_strategy, 311 | is_training=is_training, 312 | ) 313 | features = list( 314 | tqdm( 315 | p.imap(annotate_, examples, chunksize=32), 316 | total=len(examples), 317 | desc="convert squad examples to features", 318 | disable=not tqdm_enabled, 319 | ) 320 | ) 321 | 322 | new_features = [] 323 | unique_id = 1000000000 324 | example_index = 0 325 | for example_features in tqdm( 326 | features, total=len(features), desc="add example index and unique id", disable=not tqdm_enabled 327 | ): 328 | if not example_features: 329 | continue 330 | for example_feature in example_features: 331 | example_feature.example_index = example_index 332 | example_feature.unique_id = unique_id 333 | new_features.append(example_feature) 334 | unique_id += 1 335 | example_index += 1 336 | features = new_features 337 | del new_features 338 | 339 | if not is_torch_available(): 340 | raise RuntimeError("PyTorch must be installed to return a PyTorch dataset.") 341 | 342 | # Convert to Tensors and build dataset 343 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 344 | all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long) 345 | all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long) 346 | all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) 347 | all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) 348 | all_is_impossible = torch.tensor([f.is_impossible for f in features], dtype=torch.float) 349 | 350 | if not is_training: 351 | all_feature_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 352 | dataset = TensorDataset( 353 | all_input_ids, all_attention_masks, all_token_type_ids, all_feature_index, all_cls_index, all_p_mask 354 | ) 355 | else: 356 | all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) 357 | all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) 358 | dataset = TensorDataset( 359 | all_input_ids, 360 | all_attention_masks, 361 | all_token_type_ids, 362 | all_start_positions, 
363 | all_end_positions, 364 | all_cls_index, 365 | all_p_mask, 366 | all_is_impossible, 367 | ) 368 | 369 | return features, dataset 370 | 371 | def get_examples(data_dir, filename): 372 | 373 | with open(os.path.join(data_dir, filename), "r", encoding="utf-8") as reader: 374 | input_data = json.load(reader)["data"] 375 | return create_examples(input_data, "train") 376 | def create_examples(input_data, set_type): 377 | is_training = set_type == "train" 378 | examples = [] 379 | for entry in tqdm(input_data): 380 | title = entry["title"] 381 | for paragraph in entry["paragraphs"]: 382 | context_text = paragraph["context"] 383 | for qa in paragraph["qas"]: 384 | qas_id = qa["id"] 385 | question_text = qa["question"] 386 | start_position_character = None 387 | answer_text = None 388 | answers = [] 389 | 390 | is_impossible = qa.get("is_impossible", False) 391 | if not is_impossible: 392 | if is_training: 393 | answer = qa["answers"][0] 394 | answer_text = answer["text"] 395 | start_position_character = answer["answer_start"] 396 | else: 397 | answers = qa["answers"] 398 | 399 | example = SquadExample( 400 | qas_id=qas_id, 401 | question_text=question_text, 402 | context_text=context_text, 403 | answer_text=answer_text, 404 | start_position_character=start_position_character, 405 | title=title, 406 | is_impossible=is_impossible, 407 | answers=answers, 408 | ) 409 | examples.append(example) 410 | return examples 411 | 412 | class SquadExample: 413 | def __init__( 414 | self, 415 | qas_id, 416 | question_text, 417 | context_text, 418 | answer_text, 419 | start_position_character, 420 | title, 421 | answers=[], 422 | is_impossible=False, 423 | ): 424 | self.qas_id = qas_id 425 | self.question_text = question_text 426 | self.context_text = context_text 427 | self.answer_text = answer_text 428 | self.title = title 429 | self.is_impossible = is_impossible 430 | self.answers = answers 431 | 432 | self.start_position, self.end_position = 0, 0 433 | 434 | doc_tokens = [] 435 | char_to_word_offset = [] 436 | prev_is_whitespace = True 437 | 438 | # Split on whitespace so that different tokens may be attributed to their original position. 439 | for c in self.context_text: 440 | if _is_whitespace(c): 441 | prev_is_whitespace = True 442 | else: 443 | if prev_is_whitespace: 444 | doc_tokens.append(c) 445 | else: 446 | doc_tokens[-1] += c 447 | prev_is_whitespace = False 448 | char_to_word_offset.append(len(doc_tokens) - 1) 449 | 450 | self.doc_tokens = doc_tokens 451 | self.char_to_word_offset = char_to_word_offset 452 | 453 | # Start and end positions only has a value during evaluation. 454 | if start_position_character is not None and not is_impossible: 455 | self.start_position = char_to_word_offset[start_position_character] 456 | self.end_position = char_to_word_offset[ 457 | min(start_position_character + len(answer_text) - 1, len(char_to_word_offset) - 1) 458 | ] 459 | 460 | 461 | class SquadFeatures: 462 | """ 463 | Single squad example features to be fed to a model. Those features are model-specific and can be crafted from 464 | :class:`~transformers.data.processors.squad.SquadExample` using the 465 | :method:`~transformers.data.processors.squad.squad_convert_examples_to_features` method. 466 | 467 | Args: 468 | input_ids: Indices of input sequence tokens in the vocabulary. 469 | attention_mask: Mask to avoid performing attention on padding token indices. 470 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 471 | cls_index: the index of the CLS token. 
472 | p_mask: Mask identifying tokens that can be answers vs. tokens that cannot. 473 | Mask with 1 for tokens than cannot be in the answer and 0 for token that can be in an answer 474 | example_index: the index of the example 475 | unique_id: The unique Feature identifier 476 | paragraph_len: The length of the context 477 | token_is_max_context: List of booleans identifying which tokens have their maximum context in this feature object. 478 | If a token does not have their maximum context in this feature object, it means that another feature object 479 | has more information related to that token and should be prioritized over this feature for that token. 480 | tokens: list of tokens corresponding to the input ids 481 | token_to_orig_map: mapping between the tokens and the original text, needed in order to identify the answer. 482 | start_position: start of the answer token index 483 | end_position: end of the answer token index 484 | encoding: optionally store the BatchEncoding with the fast-tokenizer alignement methods. 485 | """ 486 | 487 | def __init__( 488 | self, 489 | input_ids, 490 | attention_mask, 491 | token_type_ids, 492 | cls_index, 493 | p_mask, 494 | example_index, 495 | unique_id, 496 | paragraph_len, 497 | token_is_max_context, 498 | tokens, 499 | token_to_orig_map, 500 | start_position, 501 | end_position, 502 | is_impossible, 503 | qas_id: str = None, 504 | encoding: BatchEncoding = None, 505 | ): 506 | self.input_ids = input_ids 507 | self.attention_mask = attention_mask 508 | self.token_type_ids = token_type_ids 509 | self.cls_index = cls_index 510 | self.p_mask = p_mask 511 | 512 | self.example_index = example_index 513 | self.unique_id = unique_id 514 | self.paragraph_len = paragraph_len 515 | self.token_is_max_context = token_is_max_context 516 | self.tokens = tokens 517 | self.token_to_orig_map = token_to_orig_map 518 | 519 | self.start_position = start_position 520 | self.end_position = end_position 521 | self.is_impossible = is_impossible 522 | self.qas_id = qas_id 523 | 524 | self.encoding = encoding 525 | -------------------------------------------------------------------------------- /src/functions/processor.py: -------------------------------------------------------------------------------- 1 | from transformers import (BartForConditionalGeneration, 2 | PreTrainedTokenizerFast) 3 | import torch 4 | from tqdm import tqdm 5 | 6 | class ChatDataset(): 7 | def __init__(self, filepath, tok_vocab, enc_seq_len=512, dec_seq_len=64): 8 | self.filepath = filepath 9 | self.data = open(self.filepath, 'r', encoding='utf8').readlines() 10 | # self.data = pd.read_csv(filepath) 11 | self.bos_token = '' 12 | self.eos_token = '' 13 | self.enc_seq_len = enc_seq_len 14 | self.dec_seq_len = dec_seq_len 15 | self.tokenizer = PreTrainedTokenizerFast( 16 | tokenizer_file=tok_vocab, 17 | bos_token=self.bos_token, eos_token=self.eos_token, unk_token='', pad_token='', mask_token='') 18 | 19 | def make_input_id_mask(self, tokens, max_seq_len): 20 | input_id = self.tokenizer.convert_tokens_to_ids(tokens) 21 | attention_mask = [1] * len(input_id) 22 | if len(input_id) < max_seq_len: 23 | while len(input_id) < max_seq_len: 24 | input_id += [self.tokenizer.pad_token_id] 25 | attention_mask += [0] 26 | else: 27 | # logging.warning(f'exceed max_seq_len for given article : {index}') 28 | input_id = input_id[:max_seq_len - 1] + [ 29 | self.tokenizer.eos_token_id] 30 | attention_mask = attention_mask[:max_seq_len] 31 | return input_id, attention_mask 32 | 33 | def load_dataset(self): 34 | 
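        # Builds training tensors from a tab-separated file with one
        # "title \t answer \t context \t question" example per line.
        # Encoder input : bos + title + eos + bos + answer + eos + bos + context + eos,
        #                 padded / truncated to enc_seq_len (the bos/eos strings are
        #                 presumably KoBART's '<s>' and '</s>').
        # Decoder input : bos + question + eos, padded to dec_seq_len.
        # Labels        : the question token ids shifted one position to the left of the
        #                 decoder input and padded with -100 so that padding is ignored
        #                 by the cross-entropy loss.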
input_ids = [] 35 | attention_masks = [] 36 | decoder_input_ids = [] 37 | decoder_attention_masks = [] 38 | decoder_labels = [] 39 | for index in tqdm(range(len(self.data))): 40 | title, answer, context, question = self.data[index].strip().split('\t') 41 | title_tokens = [self.bos_token] + \ 42 | self.tokenizer.tokenize(title) + [self.eos_token] 43 | answer_tokens = [self.bos_token] + \ 44 | self.tokenizer.tokenize(answer) + [self.eos_token] 45 | context_tokens = [self.bos_token] + \ 46 | self.tokenizer.tokenize(context) + [self.eos_token] 47 | question_tokens = [self.bos_token] + \ 48 | self.tokenizer.tokenize(question) + [self.eos_token] 49 | encoder_input_id, encoder_attention_mask = self.make_input_id_mask( 50 | title_tokens + answer_tokens + context_tokens, self.enc_seq_len) 51 | decoder_input_id, decoder_attention_mask = self.make_input_id_mask( 52 | question_tokens, self.dec_seq_len) 53 | labels = self.tokenizer.convert_tokens_to_ids( 54 | question_tokens[1:(self.dec_seq_len + 1)]) 55 | if len(labels) < self.dec_seq_len: 56 | while len(labels) < self.dec_seq_len: 57 | # for cross entropy loss masking 58 | labels += [-100] 59 | input_ids.append(encoder_input_id) 60 | attention_masks.append(encoder_attention_mask) 61 | decoder_input_ids.append(decoder_input_id) 62 | decoder_attention_masks.append(decoder_attention_mask) 63 | decoder_labels.append(labels) 64 | input_ids = torch.tensor(input_ids, dtype=torch.long) 65 | attention_masks = torch.tensor(attention_masks, dtype=torch.long) 66 | decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long) 67 | decoder_attention_masks = torch.tensor(decoder_attention_masks, dtype=torch.long) 68 | decoder_labels = torch.tensor(decoder_labels, dtype=torch.long) 69 | 70 | return input_ids, attention_masks, decoder_input_ids, decoder_attention_masks, decoder_labels -------------------------------------------------------------------------------- /src/functions/rouge.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from itertools import chain 3 | 4 | 5 | def get_unigram_count(tokens): 6 | count_dict = dict() 7 | for t in tokens: 8 | if t in count_dict: 9 | count_dict[t] += 1 10 | else: 11 | count_dict[t] = 1 12 | 13 | return count_dict 14 | 15 | 16 | class Rouge(object): 17 | beta = 1 18 | 19 | @staticmethod 20 | def my_lcs_grid(x, y): 21 | n = len(x) 22 | m = len(y) 23 | 24 | table = [[0 for i in range(m + 1)] for j in range(n + 1)] 25 | 26 | for j in range(m + 1): 27 | for i in range(n + 1): 28 | if i == 0 or j == 0: 29 | cell = (0, 'e') 30 | elif x[i - 1] == y[j - 1]: 31 | cell = (table[i - 1][j - 1][0] + 1, '\\') 32 | else: 33 | over = table[i - 1][j][0] 34 | left = table[i][j - 1][0] 35 | 36 | if left < over: 37 | cell = (over, '^') 38 | else: 39 | cell = (left, '<') 40 | 41 | table[i][j] = cell 42 | 43 | return table 44 | 45 | @staticmethod 46 | def my_lcs(x, y, mask_x): 47 | table = Rouge.my_lcs_grid(x, y) 48 | i = len(x) 49 | j = len(y) 50 | 51 | while i > 0 and j > 0: 52 | move = table[i][j][1] 53 | if move == '\\': 54 | mask_x[i - 1] = 1 55 | i -= 1 56 | j -= 1 57 | elif move == '^': 58 | i -= 1 59 | elif move == '<': 60 | j -= 1 61 | 62 | return mask_x 63 | 64 | @staticmethod 65 | def rouge_l(cand_sents, ref_sents): 66 | lcs_scores = 0.0 67 | cand_unigrams = get_unigram_count(chain(*cand_sents)) 68 | ref_unigrams = get_unigram_count(chain(*ref_sents)) 69 | for cand_sent in cand_sents: 70 | cand_token_mask = [0 for t in cand_sent] 71 | cand_len = 
len(cand_sent) 72 | for ref_sent in ref_sents: 73 | # aligns = [] 74 | # Rouge.lcs(ref_sent, cand_sent, aligns) 75 | Rouge.my_lcs(cand_sent, ref_sent, cand_token_mask) 76 | 77 | # for i in aligns: 78 | # ref_token_mask[i] = 1 79 | # lcs = [] 80 | cur_lcs_score = 0.0 81 | for i in range(cand_len): 82 | if cand_token_mask[i]: 83 | token = cand_sent[i] 84 | if cand_unigrams[token] > 0 and ref_unigrams[token] > 0: 85 | cand_unigrams[token] -= 1 86 | ref_unigrams[token] -= 1 87 | cur_lcs_score += 1 88 | 89 | # lcs.append(token) 90 | 91 | # print ' '.join(lcs) 92 | 93 | lcs_scores += cur_lcs_score 94 | 95 | # print "lcs_scores: %d" % lcs_scores 96 | ref_words_count = sum(len(s) for s in ref_sents) 97 | # print "ref_words_count: %d" % ref_words_count 98 | cand_words_count = sum(len(s) for s in cand_sents) 99 | # print "cand_words_count: %d" % cand_words_count 100 | 101 | precision = lcs_scores / cand_words_count 102 | recall = lcs_scores / ref_words_count 103 | f_score = (1 + Rouge.beta ** 2) * precision * recall / (recall + 104 | Rouge.beta ** 2 * precision + 1e-7) + 1e-6 # prevent underflow 105 | return precision, recall, f_score 106 | 107 | # @staticmethod 108 | # def rouge_2(cand_sents, ref_sents): 109 | # cand_bigram_counts = get_bigram_counts(cand_sents) 110 | # ref_bigram_counts = get_bigram_counts(ref_sents) 111 | 112 | 113 | if __name__ == '__main__': 114 | r = Rouge() 115 | # A simple eample of how rouge can be calculated 116 | print(r.rouge_l([[1, 7, 6, 7, 5], [0, 2, 8, 3, 5]], 117 | [[1, 2, 3, 4, 5], [3, 9, 5]])) 118 | 119 | # A more practical example of how it can be used for summary evaluation 120 | system_generated_summary = " The Kyrgyz President pushed through the law requiring the use of ink during the upcoming Parliamentary and Presidential elections In an effort to live up to its reputation in the 1990s as an island of democracy. The use of ink is one part of a general effort to show commitment towards more open elections. improper use of this type of ink can cause additional problems as the elections in Afghanistan showed. The use of ink and readers by itself is not a panacea for election ills." 121 | manual_summmary = " The use of invisible ink and ultraviolet readers in the elections of the Kyrgyz Republic which is a small, mountainous state of the former Soviet republic, causing both worries and guarded optimism among different sectors of the population. Though the actual technology behind the ink is not complicated, the presence of ultraviolet light (of the kind used to verify money) causes the ink to glow with a neon yellow light. But, this use of the new technology has caused a lot of problems. 
" 122 | 123 | print(r.rouge_l([system_generated_summary], [manual_summmary])) -------------------------------------------------------------------------------- /src/functions/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import (BartForConditionalGeneration, 2 | PreTrainedTokenizerFast) 3 | import torch 4 | from tqdm import tqdm 5 | import nltk 6 | from torch.utils.data import TensorDataset 7 | import random 8 | import numpy as np 9 | from nltk.translate.bleu_score import sentence_bleu 10 | from src.functions.rouge import Rouge 11 | 12 | def set_seed(args): 13 | random.seed(args.seed) 14 | np.random.seed(args.seed) 15 | torch.manual_seed(args.seed) 16 | if torch.cuda.is_available(): 17 | torch.cuda.manual_seed_all(args.seed) 18 | 19 | def measure(preds, refs): 20 | b1, b2, b3, b4, r_l = 0, 0, 0, 0, 0 21 | r = Rouge() 22 | count = 0 23 | for index in range(len(preds)): 24 | pred = preds[index] 25 | ref = refs[index] 26 | b1 += sentence_bleu([ref], pred, weights=(1, 0, 0, 0)) 27 | b2 += sentence_bleu([ref], pred, weights=(0.5, 0.5, 0, 0)) 28 | b3 += sentence_bleu([ref], pred, weights=(0.33, 0.33, 0.33, 0)) 29 | b4 += sentence_bleu([ref], pred, weights=(0.25, 0.25, 0.25, 0.25)) 30 | r_l += r.rouge_l(["".join(pred).replace("▁", " ")], ["".join(ref).replace("▁", " ")])[2] 31 | # print(decode) 32 | count += 1 33 | print("BLEU-1 = ", b1 / count) 34 | print("BLEU-2 = ", b2 / count) 35 | print("BLEU-3 = ", b3 / count) 36 | print("BLEU-4 = ", b4 / count) 37 | print("ROUGE-L = ", r_l / count) 38 | 39 | class ChatDataset(): 40 | def __init__(self, filepath, tokenizer, enc_seq_len=300, dec_seq_len=30): 41 | self.filepath = filepath 42 | self.data = open(self.filepath, 'r', encoding='utf8').readlines() 43 | # self.data = pd.read_csv(filepath) 44 | self.bos_token = '' 45 | self.eos_token = '' 46 | self.enc_seq_len = enc_seq_len 47 | self.dec_seq_len = dec_seq_len 48 | self.tokenizer = tokenizer 49 | def make_input_id_mask(self, tokens, max_seq_len, passage_ids=None): 50 | input_id = self.tokenizer.convert_tokens_to_ids(tokens) 51 | attention_mask = [1] * len(input_id) 52 | if len(input_id) < max_seq_len: 53 | while len(input_id) < max_seq_len: 54 | input_id += [self.tokenizer.pad_token_id] 55 | attention_mask += [0] 56 | if passage_ids is not None: 57 | passage_ids += [0] 58 | else: 59 | # logging.warning(f'exceed max_seq_len for given article : {index}') 60 | input_id = input_id[:max_seq_len - 1] + [ 61 | self.tokenizer.eos_token_id] 62 | attention_mask = attention_mask[:max_seq_len] 63 | if passage_ids is not None: 64 | passage_ids = passage_ids[:max_seq_len] 65 | if passage_ids is not None: 66 | return input_id, attention_mask, passage_ids 67 | return input_id, attention_mask 68 | def gold_passage_mask(self, context): 69 | split_context = nltk.sent_tokenize(context) 70 | passage_label = [1 if '' not in e else 2 for e in split_context] 71 | tokenized_split_context = [self.tokenizer.tokenize(e) for e in split_context] 72 | tokenized_context = [] 73 | passage_ids = [] 74 | for e in range(len(passage_label)): 75 | tokenized_context+=tokenized_split_context[e] 76 | passage_ids+=[passage_label[e]]*len(tokenized_split_context[e]) 77 | 78 | return tokenized_context, passage_ids 79 | def gold_passage_mask_v2(self, context): 80 | split_context = nltk.sent_tokenize(context) 81 | passage_label = [0 if '' not in e else 1 for e in split_context] 82 | tokenized_split_context = [self.tokenizer.tokenize(e) for e in split_context] 83 | tokenized_context 
= [] 84 | passage_ids = [] 85 | for e in range(len(passage_label)): 86 | if passage_label[e] == 0: 87 | continue 88 | tokenized_context+=tokenized_split_context[e] 89 | passage_ids+=[passage_label[e]]*len(tokenized_split_context[e]) 90 | 91 | return tokenized_context, passage_ids 92 | def make_dataset(self, title, context): 93 | input_ids = [] 94 | attention_masks = [] 95 | 96 | 97 | answer = input("Enter The Answer To The Question : ") 98 | if '-1' in [title, context, answer]: 99 | exit(1) 100 | title_tokens = [self.bos_token] + \ 101 | self.tokenizer.tokenize(title) + [self.eos_token] 102 | answer_tokens = [self.bos_token] + \ 103 | self.tokenizer.tokenize(answer) + [self.eos_token] 104 | context_tokens = [self.bos_token] + \ 105 | self.tokenizer.tokenize(context) + [self.eos_token] 106 | encoder_input_id, encoder_attention_mask = self.make_input_id_mask( 107 | title_tokens + answer_tokens + context_tokens, self.enc_seq_len) 108 | 109 | input_ids.append(encoder_input_id) 110 | attention_masks.append(encoder_attention_mask) 111 | 112 | input_ids = torch.tensor(input_ids, dtype=torch.long) 113 | attention_masks = torch.tensor(attention_masks, dtype=torch.long) 114 | 115 | return TensorDataset(input_ids, attention_masks) 116 | def load_dataset_with_passage_ids(self): 117 | input_ids = [] 118 | attention_masks = [] 119 | gold_passage_ids = [] 120 | decoder_input_ids = [] 121 | decoder_attention_masks = [] 122 | decoder_labels = [] 123 | ids, levels = [], [] 124 | for index in tqdm(range(len(self.data))): 125 | # title, answer, context, question = self.data[index].strip().split('\t') 126 | datas = self.data[index].strip().split('\t') 127 | if len(datas) != 7: 128 | print(index, "!!!") 129 | continue 130 | id, level, title, answer, evidence_sent, context, question = datas 131 | if level == '하' or level == '중': 132 | continue 133 | ids.append(id) 134 | levels.append(level) 135 | tokenized_context, passage_ids = self.gold_passage_mask_v2(context) 136 | title_tokens = [self.bos_token] + \ 137 | self.tokenizer.tokenize(title) + [self.eos_token] 138 | answer_tokens = [self.bos_token] + \ 139 | self.tokenizer.tokenize(answer) + [self.eos_token] 140 | evidence_tokens = [self.bos_token] + \ 141 | self.tokenizer.tokenize(evidence_sent) + [self.eos_token] 142 | context_tokens = [self.bos_token] + \ 143 | tokenized_context + [self.eos_token] 144 | passage_ids = [0] + passage_ids + [0] 145 | question_tokens = [self.bos_token] + \ 146 | self.tokenizer.tokenize(question) + [self.eos_token] 147 | encoder_input_id, encoder_attention_mask, encoder_gold_passage_id = self.make_input_id_mask( 148 | title_tokens + answer_tokens + evidence_tokens + context_tokens, self.enc_seq_len, [0]*len(title_tokens + answer_tokens + evidence_tokens)+passage_ids) 149 | # encoder_input_id, encoder_attention_mask, encoder_gold_passage_id = self.make_input_id_mask( 150 | # title_tokens + answer_tokens + context_tokens, self.enc_seq_len, 151 | # [0] * len(title_tokens + answer_tokens) + passage_ids) 152 | decoder_input_id, decoder_attention_mask = self.make_input_id_mask( 153 | question_tokens, self.dec_seq_len) 154 | labels = self.tokenizer.convert_tokens_to_ids( 155 | question_tokens[1:(self.dec_seq_len + 1)]) 156 | if len(labels) < self.dec_seq_len: 157 | while len(labels) < self.dec_seq_len: 158 | # for cross entropy loss masking 159 | labels += [-100] 160 | # if len(input_ids) > 100: 161 | # break 162 | # print(context_tokens) 163 | # print(question_tokens) 164 | # print(decoder_labels) 165 | 166 | 
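            # Examples labelled '하' (easy) or '중' (medium) were skipped above, so only
            # the harder questions are encoded. Each kept example is encoded as
            #     bos title eos  bos answer eos  bos evidence_sent eos  bos context eos
            # and gold_passage_ids carries the per-token sentence labels produced by
            # gold_passage_mask_v2 (0 for the title/answer/evidence segments and padding).
            # As in load_dataset(), labels are the question ids shifted one position to
            # the left of decoder_input_ids and padded with -100.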
input_ids.append(encoder_input_id) 167 | attention_masks.append(encoder_attention_mask) 168 | gold_passage_ids.append(encoder_gold_passage_id) 169 | decoder_input_ids.append(decoder_input_id) 170 | decoder_attention_masks.append(decoder_attention_mask) 171 | decoder_labels.append(labels) 172 | 173 | input_ids = torch.tensor(input_ids, dtype=torch.long) 174 | attention_masks = torch.tensor(attention_masks, dtype=torch.long) 175 | gold_passage_ids = torch.tensor(gold_passage_ids, dtype=torch.long) 176 | decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long) 177 | decoder_attention_masks = torch.tensor(decoder_attention_masks, dtype=torch.long) 178 | decoder_labels = torch.tensor(decoder_labels, dtype=torch.long) 179 | 180 | return TensorDataset(input_ids, attention_masks, gold_passage_ids, decoder_input_ids, decoder_attention_masks, decoder_labels), ids, levels 181 | def load_dataset(self): 182 | input_ids = [] 183 | attention_masks = [] 184 | decoder_input_ids = [] 185 | decoder_attention_masks = [] 186 | decoder_labels = [] 187 | for index in tqdm(range(len(self.data))): 188 | title, answer, context, question = self.data[index].strip().split('\t') 189 | 190 | title_tokens = [self.bos_token] + \ 191 | self.tokenizer.tokenize(title) + [self.eos_token] 192 | answer_tokens = [self.bos_token] + \ 193 | self.tokenizer.tokenize(answer) + [self.eos_token] 194 | context_tokens = [self.bos_token] + \ 195 | self.tokenizer.tokenize(context) + [self.eos_token] 196 | question_tokens = [self.bos_token] + \ 197 | self.tokenizer.tokenize(question) + [self.eos_token] 198 | encoder_input_id, encoder_attention_mask = self.make_input_id_mask( 199 | title_tokens + answer_tokens + context_tokens, self.enc_seq_len) 200 | decoder_input_id, decoder_attention_mask = self.make_input_id_mask( 201 | question_tokens, self.dec_seq_len) 202 | labels = self.tokenizer.convert_tokens_to_ids( 203 | question_tokens[1:(self.dec_seq_len + 1)]) 204 | if len(labels) < self.dec_seq_len: 205 | while len(labels) < self.dec_seq_len: 206 | # for cross entropy loss masking 207 | labels += [-100] 208 | # if len(input_ids) > 100: 209 | # break 210 | # print(context_tokens) 211 | # print(question_tokens) 212 | # print(decoder_labels) 213 | 214 | input_ids.append(encoder_input_id) 215 | attention_masks.append(encoder_attention_mask) 216 | decoder_input_ids.append(decoder_input_id) 217 | decoder_attention_masks.append(decoder_attention_mask) 218 | decoder_labels.append(labels) 219 | 220 | input_ids = torch.tensor(input_ids, dtype=torch.long) 221 | attention_masks = torch.tensor(attention_masks, dtype=torch.long) 222 | decoder_input_ids = torch.tensor(decoder_input_ids, dtype=torch.long) 223 | decoder_attention_masks = torch.tensor(decoder_attention_masks, dtype=torch.long) 224 | decoder_labels = torch.tensor(decoder_labels, dtype=torch.long) 225 | 226 | return TensorDataset(input_ids, attention_masks, decoder_input_ids, decoder_attention_masks, decoder_labels) -------------------------------------------------------------------------------- /src/model/__pycache__/custom_bart.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/Question_Generation/5ccd2388aa679d3ccff908efedf45b1196fcad90/src/model/__pycache__/custom_bart.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/main_functions.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/Question_Generation/5ccd2388aa679d3ccff908efedf45b1196fcad90/src/model/__pycache__/main_functions.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KUNLP/Question_Generation/5ccd2388aa679d3ccff908efedf45b1196fcad90/src/model/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /src/model/custom_bart.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch BART model. """ 16 | import copy 17 | import math 18 | import random 19 | import warnings 20 | from typing import Optional, Tuple 21 | 22 | import torch 23 | import torch.utils.checkpoint 24 | from torch import nn 25 | from torch.nn import CrossEntropyLoss, MSELoss 26 | 27 | from transformers.activations import ACT2FN 28 | from transformers.file_utils import ( 29 | add_code_sample_docstrings, 30 | add_end_docstrings, 31 | add_start_docstrings, 32 | add_start_docstrings_to_model_forward, 33 | replace_return_docstrings, 34 | ) 35 | from transformers.modeling_outputs import ( 36 | BaseModelOutput, 37 | BaseModelOutputWithPastAndCrossAttentions, 38 | CausalLMOutputWithCrossAttentions, 39 | Seq2SeqLMOutput, 40 | Seq2SeqModelOutput, 41 | Seq2SeqQuestionAnsweringModelOutput, 42 | Seq2SeqSequenceClassifierOutput, 43 | ) 44 | from transformers.modeling_utils import PreTrainedModel 45 | from transformers.utils import logging 46 | from transformers.models.bart.configuration_bart import BartConfig 47 | 48 | 49 | logger = logging.get_logger(__name__) 50 | 51 | _CHECKPOINT_FOR_DOC = "facebook/bart-large" 52 | _CONFIG_FOR_DOC = "BartConfig" 53 | _TOKENIZER_FOR_DOC = "BartTokenizer" 54 | 55 | 56 | BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ 57 | "facebook/bart-large", 58 | # See all BART models at https://huggingface.co/models?filter=bart 59 | ] 60 | 61 | 62 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 63 | """ 64 | Shift input ids one token to the right. 65 | """ 66 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 67 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 68 | shifted_input_ids[:, 0] = decoder_start_token_id 69 | 70 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 
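    # Illustrative example (made-up ids): with decoder_start_token_id=0 and pad_token_id=3,
    # labels [[21, 45, -100, -100]] are shifted to [[0, 21, 45, -100]], and the copied -100
    # is then replaced by the pad id below, giving [[0, 21, 45, 3]].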
71 | # replace possible -100 values in labels by `pad_token_id` 72 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 73 | 74 | return shifted_input_ids 75 | 76 | 77 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 78 | """ 79 | Make causal mask used for bi-directional self-attention. 80 | """ 81 | bsz, tgt_len = input_ids_shape 82 | mask = torch.full((tgt_len, tgt_len), float("-inf")) 83 | mask_cond = torch.arange(mask.size(-1)) 84 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 85 | mask = mask.to(dtype) 86 | 87 | if past_key_values_length > 0: 88 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 89 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 90 | 91 | 92 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 93 | """ 94 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 95 | """ 96 | bsz, src_len = mask.size() 97 | tgt_len = tgt_len if tgt_len is not None else src_len 98 | 99 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 100 | 101 | inverted_mask = 1.0 - expanded_mask 102 | 103 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 104 | 105 | 106 | class BartLearnedPositionalEmbedding(nn.Embedding): 107 | """ 108 | This module learns positional embeddings up to a fixed maximum size. 109 | """ 110 | 111 | def __init__(self, num_embeddings: int, embedding_dim: int): 112 | # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 113 | # and adjust num_embeddings appropriately. Other models don't have this hack 114 | self.offset = 2 115 | super().__init__(num_embeddings + self.offset, embedding_dim) 116 | 117 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): 118 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 119 | bsz, seq_len = input_ids_shape[:2] 120 | positions = torch.arange( 121 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device 122 | ) 123 | return super().forward(positions + self.offset) 124 | 125 | 126 | class BartAttention(nn.Module): 127 | """Multi-headed attention from 'Attention Is All You Need' paper""" 128 | 129 | def __init__( 130 | self, 131 | embed_dim: int, 132 | num_heads: int, 133 | dropout: float = 0.0, 134 | is_decoder: bool = False, 135 | bias: bool = True, 136 | ): 137 | super().__init__() 138 | self.embed_dim = embed_dim 139 | self.num_heads = num_heads 140 | self.dropout = dropout 141 | self.head_dim = embed_dim // num_heads 142 | 143 | if (self.head_dim * num_heads) != self.embed_dim: 144 | raise ValueError( 145 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 146 | f" and `num_heads`: {num_heads})." 
147 | ) 148 | self.scaling = self.head_dim ** -0.5 149 | self.is_decoder = is_decoder 150 | 151 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 152 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 153 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 154 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 155 | 156 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 157 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 158 | 159 | def forward( 160 | self, 161 | hidden_states: torch.Tensor, 162 | key_value_states: Optional[torch.Tensor] = None, 163 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 164 | attention_mask: Optional[torch.Tensor] = None, 165 | layer_head_mask: Optional[torch.Tensor] = None, 166 | output_attentions: bool = False, 167 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 168 | """Input shape: Batch x Time x Channel""" 169 | 170 | # if key_value_states are provided this layer is used as a cross-attention layer 171 | # for the decoder 172 | is_cross_attention = key_value_states is not None 173 | bsz, tgt_len, embed_dim = hidden_states.size() 174 | 175 | # get query proj 176 | query_states = self.q_proj(hidden_states) * self.scaling 177 | # get key, value proj 178 | if is_cross_attention and past_key_value is not None: 179 | # reuse k,v, cross_attentions 180 | key_states = past_key_value[0] 181 | value_states = past_key_value[1] 182 | elif is_cross_attention: 183 | # cross_attentions 184 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 185 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 186 | elif past_key_value is not None: 187 | # reuse k, v, self_attention 188 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 189 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 190 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 191 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 192 | else: 193 | # self_attention 194 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 195 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 196 | 197 | if self.is_decoder: 198 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 199 | # Further calls to cross_attention layer can then reuse all cross-attention 200 | # key/value_states (first "if" case) 201 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 202 | # all previous decoder key/value_states. 
Further calls to uni-directional self-attention 203 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 204 | # if encoder bi-directional self-attention `past_key_value` is always `None` 205 | past_key_value = (key_states, value_states) 206 | 207 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 208 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 209 | key_states = key_states.view(*proj_shape) 210 | value_states = value_states.view(*proj_shape) 211 | 212 | src_len = key_states.size(1) 213 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 214 | 215 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 216 | raise ValueError( 217 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 218 | ) 219 | 220 | if attention_mask is not None: 221 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 222 | raise ValueError( 223 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 224 | ) 225 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 226 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 227 | 228 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 229 | 230 | if layer_head_mask is not None: 231 | if layer_head_mask.size() != (self.num_heads,): 232 | raise ValueError( 233 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 234 | ) 235 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 236 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 237 | 238 | if output_attentions: 239 | # this operation is a bit awkward, but it's required to 240 | # make sure that attn_weights keeps its gradient. 
241 | # In order to do so, attn_weights have to be reshaped 242 | # twice and have to be reused in the following 243 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 244 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 245 | else: 246 | attn_weights_reshaped = None 247 | 248 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 249 | 250 | attn_output = torch.bmm(attn_probs, value_states) 251 | 252 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 253 | raise ValueError( 254 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 255 | ) 256 | 257 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 258 | attn_output = attn_output.transpose(1, 2) 259 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) 260 | 261 | attn_output = self.out_proj(attn_output) 262 | 263 | return attn_output, attn_weights_reshaped, past_key_value 264 | 265 | 266 | class BartEncoderLayer(nn.Module): 267 | def __init__(self, config: BartConfig): 268 | super().__init__() 269 | self.embed_dim = config.d_model 270 | self.self_attn = BartAttention( 271 | embed_dim=self.embed_dim, 272 | num_heads=config.encoder_attention_heads, 273 | dropout=config.attention_dropout, 274 | ) 275 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 276 | self.dropout = config.dropout 277 | self.activation_fn = ACT2FN[config.activation_function] 278 | self.activation_dropout = config.activation_dropout 279 | self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) 280 | self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) 281 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 282 | 283 | def forward( 284 | self, 285 | hidden_states: torch.Tensor, 286 | attention_mask: torch.Tensor, 287 | layer_head_mask: torch.Tensor, 288 | output_attentions: bool = False, 289 | ): 290 | """ 291 | Args: 292 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` 293 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size 294 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 295 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size 296 | `(encoder_attention_heads,)`. 297 | output_attentions (:obj:`bool`, `optional`): 298 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 299 | returned tensors for more detail. 
300 | """ 301 | residual = hidden_states 302 | hidden_states, attn_weights, _ = self.self_attn( 303 | hidden_states=hidden_states, 304 | attention_mask=attention_mask, 305 | layer_head_mask=layer_head_mask, 306 | output_attentions=output_attentions, 307 | ) 308 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 309 | hidden_states = residual + hidden_states 310 | hidden_states = self.self_attn_layer_norm(hidden_states) 311 | 312 | residual = hidden_states 313 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 314 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 315 | hidden_states = self.fc2(hidden_states) 316 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 317 | hidden_states = residual + hidden_states 318 | hidden_states = self.final_layer_norm(hidden_states) 319 | 320 | if hidden_states.dtype == torch.float16 and ( 321 | torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() 322 | ): 323 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 324 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 325 | 326 | outputs = (hidden_states,) 327 | 328 | if output_attentions: 329 | outputs += (attn_weights,) 330 | 331 | return outputs 332 | 333 | 334 | class BartDecoderLayer(nn.Module): 335 | def __init__(self, config: BartConfig): 336 | super().__init__() 337 | self.embed_dim = config.d_model 338 | 339 | self.self_attn = BartAttention( 340 | embed_dim=self.embed_dim, 341 | num_heads=config.decoder_attention_heads, 342 | dropout=config.attention_dropout, 343 | is_decoder=True, 344 | ) 345 | self.dropout = config.dropout 346 | self.activation_fn = ACT2FN[config.activation_function] 347 | self.activation_dropout = config.activation_dropout 348 | 349 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 350 | self.encoder_attn = BartAttention( 351 | self.embed_dim, 352 | config.decoder_attention_heads, 353 | dropout=config.attention_dropout, 354 | is_decoder=True, 355 | ) 356 | self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) 357 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 358 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 359 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 360 | 361 | def forward( 362 | self, 363 | hidden_states: torch.Tensor, 364 | attention_mask: Optional[torch.Tensor] = None, 365 | encoder_hidden_states: Optional[torch.Tensor] = None, 366 | encoder_attention_mask: Optional[torch.Tensor] = None, 367 | layer_head_mask: Optional[torch.Tensor] = None, 368 | cross_attn_layer_head_mask: Optional[torch.Tensor] = None, 369 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 370 | output_attentions: Optional[bool] = False, 371 | use_cache: Optional[bool] = True, 372 | ): 373 | """ 374 | Args: 375 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 376 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size 377 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 378 | encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(batch, seq_len, embed_dim)` 379 | encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size 380 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
381 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size 382 | `(encoder_attention_heads,)`. 383 | cross_attn_layer_head_mask (:obj:`torch.FloatTensor`): mask for cross-attention heads in a given layer of 384 | size `(decoder_attention_heads,)`. 385 | past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states 386 | output_attentions (:obj:`bool`, `optional`): 387 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 388 | returned tensors for more detail. 389 | """ 390 | residual = hidden_states 391 | 392 | # Self Attention 393 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 394 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 395 | # add present self-attn cache to positions 1,2 of present_key_value tuple 396 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 397 | hidden_states=hidden_states, 398 | past_key_value=self_attn_past_key_value, 399 | attention_mask=attention_mask, 400 | layer_head_mask=layer_head_mask, 401 | output_attentions=output_attentions, 402 | ) 403 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 404 | hidden_states = residual + hidden_states 405 | hidden_states = self.self_attn_layer_norm(hidden_states) 406 | 407 | # Cross-Attention Block 408 | cross_attn_present_key_value = None 409 | cross_attn_weights = None 410 | if encoder_hidden_states is not None: 411 | residual = hidden_states 412 | 413 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 414 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 415 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 416 | hidden_states=hidden_states, 417 | key_value_states=encoder_hidden_states, 418 | attention_mask=encoder_attention_mask, 419 | layer_head_mask=cross_attn_layer_head_mask, 420 | past_key_value=cross_attn_past_key_value, 421 | output_attentions=output_attentions, 422 | ) 423 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 424 | hidden_states = residual + hidden_states 425 | hidden_states = self.encoder_attn_layer_norm(hidden_states) 426 | 427 | # add cross-attn to positions 3,4 of present_key_value tuple 428 | present_key_value = present_key_value + cross_attn_present_key_value 429 | 430 | # Fully Connected 431 | residual = hidden_states 432 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 433 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 434 | hidden_states = self.fc2(hidden_states) 435 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 436 | hidden_states = residual + hidden_states 437 | hidden_states = self.final_layer_norm(hidden_states) 438 | 439 | outputs = (hidden_states,) 440 | 441 | if output_attentions: 442 | outputs += (self_attn_weights, cross_attn_weights) 443 | 444 | if use_cache: 445 | outputs += (present_key_value,) 446 | 447 | return outputs 448 | 449 | 450 | class BartClassificationHead(nn.Module): 451 | """Head for sentence-level classification tasks.""" 452 | 453 | def __init__( 454 | self, 455 | input_dim: int, 456 | inner_dim: int, 457 | num_classes: int, 458 | pooler_dropout: float, 459 | ): 460 | super().__init__() 461 | self.dense = nn.Linear(input_dim, 
inner_dim) 462 | self.dropout = nn.Dropout(p=pooler_dropout) 463 | self.out_proj = nn.Linear(inner_dim, num_classes) 464 | 465 | def forward(self, hidden_states: torch.Tensor): 466 | hidden_states = self.dropout(hidden_states) 467 | hidden_states = self.dense(hidden_states) 468 | hidden_states = torch.tanh(hidden_states) 469 | hidden_states = self.dropout(hidden_states) 470 | hidden_states = self.out_proj(hidden_states) 471 | return hidden_states 472 | 473 | 474 | class BartPretrainedModel(PreTrainedModel): 475 | config_class = BartConfig 476 | base_model_prefix = "model" 477 | supports_gradient_checkpointing = True 478 | _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] 479 | 480 | def _init_weights(self, module): 481 | std = self.config.init_std 482 | if isinstance(module, nn.Linear): 483 | module.weight.data.normal_(mean=0.0, std=std) 484 | if module.bias is not None: 485 | module.bias.data.zero_() 486 | elif isinstance(module, nn.Embedding): 487 | module.weight.data.normal_(mean=0.0, std=std) 488 | if module.padding_idx is not None: 489 | module.weight.data[module.padding_idx].zero_() 490 | 491 | def _set_gradient_checkpointing(self, module, value=False): 492 | if isinstance(module, (BartDecoder, BartEncoder)): 493 | module.gradient_checkpointing = value 494 | 495 | @property 496 | def dummy_inputs(self): 497 | pad_token = self.config.pad_token_id 498 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 499 | dummy_inputs = { 500 | "attention_mask": input_ids.ne(pad_token), 501 | "input_ids": input_ids, 502 | } 503 | return dummy_inputs 504 | 505 | 506 | class PretrainedBartModel(BartPretrainedModel): 507 | def __init_subclass__(self): 508 | warnings.warn( 509 | "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.", 510 | FutureWarning, 511 | ) 512 | 513 | 514 | BART_START_DOCSTRING = r""" 515 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 516 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 517 | pruning heads etc.) 518 | 519 | This model is also a PyTorch `torch.nn.Module `__ 520 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 521 | general usage and behavior. 522 | 523 | Parameters: 524 | config (:class:`~transformers.BartConfig`): 525 | Model configuration class with all the parameters of the model. Initializing with a config file does not 526 | load the weights associated with the model, only the configuration. Check out the 527 | :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. 528 | """ 529 | 530 | BART_GENERATION_EXAMPLE = r""" 531 | Summarization example:: 532 | 533 | >>> from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig 534 | 535 | >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn') 536 | >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn') 537 | 538 | >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." 
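    >>> # tokenize the article, then generate a short beam-search summary (num_beams=4, max_length=5 below)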
539 | >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') 540 | 541 | >>> # Generate Summary 542 | >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) 543 | >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) 544 | 545 | Mask filling example:: 546 | 547 | >>> from transformers import BartTokenizer, BartForConditionalGeneration 548 | >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') 549 | >>> TXT = "My friends are but they eat too many carbs." 550 | 551 | >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large') 552 | >>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids'] 553 | >>> logits = model(input_ids).logits 554 | 555 | >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() 556 | >>> probs = logits[0, masked_index].softmax(dim=0) 557 | >>> values, predictions = probs.topk(5) 558 | 559 | >>> tokenizer.decode(predictions).split() 560 | """ 561 | 562 | BART_INPUTS_DOCSTRING = r""" 563 | Args: 564 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 565 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 566 | it. 567 | 568 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 569 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 570 | details. 571 | 572 | `What are input IDs? <../glossary.html#input-ids>`__ 573 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 574 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 575 | 576 | - 1 for tokens that are **not masked**, 577 | - 0 for tokens that are **masked**. 578 | 579 | `What are attention masks? <../glossary.html#attention-mask>`__ 580 | decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 581 | Indices of decoder input sequence tokens in the vocabulary. 582 | 583 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 584 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 585 | details. 586 | 587 | `What are decoder input IDs? <../glossary.html#decoder-input-ids>`__ 588 | 589 | Bart uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If 590 | :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see 591 | :obj:`past_key_values`). 592 | 593 | For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no 594 | :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to 595 | the right for denoising pre-training following the paper. 596 | decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 597 | Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will 598 | also be used by default. 599 | 600 | If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and 601 | modify to your needs. See diagram 1 in `the paper `__ for more 602 | information on the default strategy. 
603 | head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): 604 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: 605 | 606 | - 1 indicates the head is **not masked**, 607 | - 0 indicates the head is **masked**. 608 | 609 | decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): 610 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: 611 | 612 | - 1 indicates the head is **not masked**, 613 | - 0 indicates the head is **masked**. 614 | 615 | cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): 616 | Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in ``[0, 617 | 1]``: 618 | 619 | - 1 indicates the head is **not masked**, 620 | - 0 indicates the head is **masked**. 621 | 622 | encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): 623 | Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: 624 | :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, 625 | `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the 626 | cross-attention of the decoder. 627 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): 628 | Tuple of :obj:`tuple(torch.FloatTensor)` of length :obj:`config.n_layers`, with each tuple having 2 tensors 629 | of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 630 | shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 631 | 632 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 633 | blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential decoding. 634 | 635 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 636 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 637 | instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. 638 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 639 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 640 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 641 | vectors than the model's internal embedding lookup matrix. 642 | decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): 643 | Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded 644 | representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` 645 | have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert 646 | :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 
647 | 648 | If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` 649 | takes the value of :obj:`inputs_embeds`. 650 | use_cache (:obj:`bool`, `optional`): 651 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 652 | decoding (see :obj:`past_key_values`). 653 | output_attentions (:obj:`bool`, `optional`): 654 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 655 | tensors for more detail. 656 | output_hidden_states (:obj:`bool`, `optional`): 657 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 658 | more detail. 659 | return_dict (:obj:`bool`, `optional`): 660 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 661 | """ 662 | 663 | 664 | class BartEncoder(BartPretrainedModel): 665 | """ 666 | Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a 667 | :class:`BartEncoderLayer`. 668 | 669 | Args: 670 | config: BartConfig 671 | embed_tokens (nn.Embedding): output embedding 672 | """ 673 | 674 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 675 | super().__init__(config) 676 | 677 | self.dropout = config.dropout 678 | self.layerdrop = config.encoder_layerdrop 679 | 680 | embed_dim = config.d_model 681 | self.padding_idx = config.pad_token_id 682 | self.max_source_positions = config.max_position_embeddings 683 | self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 684 | 685 | if embed_tokens is not None: 686 | self.embed_tokens = embed_tokens 687 | else: 688 | self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) 689 | 690 | self.embed_positions = BartLearnedPositionalEmbedding( 691 | config.max_position_embeddings, 692 | embed_dim, 693 | ) 694 | self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) 695 | self.layernorm_embedding = nn.LayerNorm(embed_dim) 696 | 697 | self.init_weights() 698 | self.gradient_checkpointing = False 699 | 700 | def get_input_embeddings(self): 701 | return self.embed_tokens 702 | 703 | def set_input_embeddings(self, value): 704 | self.embed_tokens = value 705 | 706 | def forward( 707 | self, 708 | input_ids=None, 709 | attention_mask=None, 710 | head_mask=None, 711 | inputs_embeds=None, 712 | output_attentions=None, 713 | output_hidden_states=None, 714 | return_dict=None, 715 | ): 716 | r""" 717 | Args: 718 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 719 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 720 | provide it. 721 | 722 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 723 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` 724 | for details. 725 | 726 | `What are input IDs? <../glossary.html#input-ids>`__ 727 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 728 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 729 | 730 | - 1 for tokens that are **not masked**, 731 | - 0 for tokens that are **masked**. 732 | 733 | `What are attention masks? 
<../glossary.html#attention-mask>`__ 734 | head_mask (:obj:`torch.Tensor` of shape :obj:`(encoder_layers, encoder_attention_heads)`, `optional`): 735 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: 736 | 737 | - 1 indicates the head is **not masked**, 738 | - 0 indicates the head is **masked**. 739 | 740 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 741 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded 742 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices 743 | into associated vectors than the model's internal embedding lookup matrix. 744 | output_attentions (:obj:`bool`, `optional`): 745 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 746 | returned tensors for more detail. 747 | output_hidden_states (:obj:`bool`, `optional`): 748 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 749 | for more detail. 750 | return_dict (:obj:`bool`, `optional`): 751 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 752 | """ 753 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 754 | output_hidden_states = ( 755 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 756 | ) 757 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 758 | 759 | # retrieve input_ids and inputs_embeds 760 | if input_ids is not None and inputs_embeds is not None: 761 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 762 | elif input_ids is not None: 763 | input_shape = input_ids.size() 764 | input_ids = input_ids.view(-1, input_shape[-1]) 765 | elif inputs_embeds is not None: 766 | input_shape = inputs_embeds.size()[:-1] 767 | else: 768 | raise ValueError("You have to specify either input_ids or inputs_embeds") 769 | 770 | if inputs_embeds is None: 771 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 772 | 773 | embed_pos = self.embed_positions(input_shape) 774 | 775 | hidden_states = inputs_embeds + embed_pos 776 | hidden_states = self.layernorm_embedding(hidden_states) 777 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 778 | 779 | # expand attention_mask 780 | if attention_mask is not None: 781 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 782 | attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) 783 | 784 | encoder_states = () if output_hidden_states else None 785 | all_attentions = () if output_attentions else None 786 | 787 | # check if head_mask has a correct number of layers specified if desired 788 | if head_mask is not None: 789 | assert head_mask.size()[0] == ( 790 | len(self.layers) 791 | ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
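        # The loop below runs the input through each BartEncoderLayer in turn. During training,
        # LayerDrop (https://arxiv.org/abs/1909.11556) skips an entire layer with probability
        # `config.encoder_layerdrop`; when gradient checkpointing is enabled (and the model is
        # training), layer activations are recomputed in the backward pass instead of being
        # stored, trading extra compute for lower memory use.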
792 | for idx, encoder_layer in enumerate(self.layers): 793 | if output_hidden_states: 794 | encoder_states = encoder_states + (hidden_states,) 795 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 796 | dropout_probability = random.uniform(0, 1) 797 | if self.training and (dropout_probability < self.layerdrop): # skip the layer 798 | layer_outputs = (None, None) 799 | else: 800 | if self.gradient_checkpointing and self.training: 801 | 802 | def create_custom_forward(module): 803 | def custom_forward(*inputs): 804 | return module(*inputs, output_attentions) 805 | 806 | return custom_forward 807 | 808 | layer_outputs = torch.utils.checkpoint.checkpoint( 809 | create_custom_forward(encoder_layer), 810 | hidden_states, 811 | attention_mask, 812 | (head_mask[idx] if head_mask is not None else None), 813 | ) 814 | else: 815 | layer_outputs = encoder_layer( 816 | hidden_states, 817 | attention_mask, 818 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 819 | output_attentions=output_attentions, 820 | ) 821 | 822 | hidden_states = layer_outputs[0] 823 | 824 | if output_attentions: 825 | all_attentions = all_attentions + (layer_outputs[1],) 826 | 827 | if output_hidden_states: 828 | encoder_states = encoder_states + (hidden_states,) 829 | 830 | if not return_dict: 831 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) 832 | return BaseModelOutput( 833 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions 834 | ) 835 | 836 | 837 | class BartDecoder(BartPretrainedModel): 838 | """ 839 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`BartDecoderLayer` 840 | 841 | Args: 842 | config: BartConfig 843 | embed_tokens (nn.Embedding): output embedding 844 | """ 845 | 846 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 847 | super().__init__(config) 848 | self.dropout = config.dropout 849 | self.layerdrop = config.decoder_layerdrop 850 | self.padding_idx = config.pad_token_id 851 | self.max_target_positions = config.max_position_embeddings 852 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 853 | 854 | if embed_tokens is not None: 855 | self.embed_tokens = embed_tokens 856 | else: 857 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 858 | 859 | self.embed_positions = BartLearnedPositionalEmbedding( 860 | config.max_position_embeddings, 861 | config.d_model, 862 | ) 863 | self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) 864 | self.layernorm_embedding = nn.LayerNorm(config.d_model) 865 | 866 | self.init_weights() 867 | self.gradient_checkpointing = False 868 | 869 | def get_input_embeddings(self): 870 | return self.embed_tokens 871 | 872 | def set_input_embeddings(self, value): 873 | self.embed_tokens = value 874 | 875 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 876 | # create causal mask 877 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 878 | combined_attention_mask = None 879 | if input_shape[-1] > 1: 880 | combined_attention_mask = _make_causal_mask( 881 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 882 | ).to(self.device) 883 | 884 | if attention_mask is not None: 885 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 886 | expanded_attn_mask = 
_expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 887 | combined_attention_mask = ( 888 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 889 | ) 890 | 891 | return combined_attention_mask 892 | 893 | def forward( 894 | self, 895 | input_ids=None, 896 | attention_mask=None, 897 | encoder_hidden_states=None, 898 | encoder_attention_mask=None, 899 | head_mask=None, 900 | cross_attn_head_mask=None, 901 | past_key_values=None, 902 | inputs_embeds=None, 903 | use_cache=None, 904 | output_attentions=None, 905 | output_hidden_states=None, 906 | return_dict=None, 907 | ): 908 | r""" 909 | Args: 910 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 911 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 912 | provide it. 913 | 914 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 915 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` 916 | for details. 917 | 918 | `What are input IDs? <../glossary.html#input-ids>`__ 919 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 920 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 921 | 922 | - 1 for tokens that are **not masked**, 923 | - 0 for tokens that are **masked**. 924 | 925 | `What are attention masks? <../glossary.html#attention-mask>`__ 926 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): 927 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 928 | of the decoder. 929 | encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): 930 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 931 | selected in ``[0, 1]``: 932 | 933 | - 1 for tokens that are **not masked**, 934 | - 0 for tokens that are **masked**. 935 | 936 | `What are attention masks? <../glossary.html#attention-mask>`__ 937 | head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): 938 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: 939 | 940 | - 1 indicates the head is **not masked**, 941 | - 0 indicates the head is **masked**. 942 | 943 | cross_attn_head_mask (:obj:`torch.Tensor` of shape :obj:`(decoder_layers, decoder_attention_heads)`, `optional`): 944 | Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing 945 | cross-attention on hidden heads. Mask values selected in ``[0, 1]``: 946 | 947 | - 1 indicates the head is **not masked**, 948 | - 0 indicates the head is **masked**. 949 | 950 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): 951 | Tuple of :obj:`tuple(torch.FloatTensor)` of length :obj:`config.n_layers`, with each tuple having 2 952 | tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional 953 | tensors of shape :obj:`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
954 | 955 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 956 | cross-attention blocks) that can be used (see :obj:`past_key_values` input) to speed up sequential 957 | decoding. 958 | 959 | If :obj:`past_key_values` are used, the user can optionally input only the last 960 | :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of 961 | shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, 962 | sequence_length)`. 963 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 964 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded 965 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices 966 | into associated vectors than the model's internal embedding lookup matrix. 967 | output_attentions (:obj:`bool`, `optional`): 968 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 969 | returned tensors for more detail. 970 | output_hidden_states (:obj:`bool`, `optional`): 971 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 972 | for more detail. 973 | return_dict (:obj:`bool`, `optional`): 974 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 975 | """ 976 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 977 | output_hidden_states = ( 978 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 979 | ) 980 | use_cache = use_cache if use_cache is not None else self.config.use_cache 981 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 982 | 983 | # retrieve input_ids and inputs_embeds 984 | if input_ids is not None and inputs_embeds is not None: 985 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 986 | elif input_ids is not None: 987 | input_shape = input_ids.size() 988 | input_ids = input_ids.view(-1, input_shape[-1]) 989 | elif inputs_embeds is not None: 990 | input_shape = inputs_embeds.size()[:-1] 991 | else: 992 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 993 | 994 | # past_key_values_length 995 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 996 | 997 | if inputs_embeds is None: 998 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 999 | 1000 | attention_mask = self._prepare_decoder_attention_mask( 1001 | attention_mask, input_shape, inputs_embeds, past_key_values_length 1002 | ) 1003 | 1004 | # expand encoder attention mask 1005 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 1006 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 1007 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 1008 | 1009 | # embed positions 1010 | positions = self.embed_positions(input_shape, past_key_values_length) 1011 | 1012 | hidden_states = inputs_embeds + positions 1013 | hidden_states = self.layernorm_embedding(hidden_states) 1014 | 1015 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1016 | 1017 | # decoder layers 1018 | 
all_hidden_states = () if output_hidden_states else None 1019 | all_self_attns = () if output_attentions else None 1020 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 1021 | next_decoder_cache = () if use_cache else None 1022 | 1023 | # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired 1024 | for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): 1025 | if attn_mask is not None: 1026 | assert attn_mask.size()[0] == ( 1027 | len(self.layers) 1028 | ), f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 1029 | for idx, decoder_layer in enumerate(self.layers): 1030 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 1031 | if output_hidden_states: 1032 | all_hidden_states += (hidden_states,) 1033 | dropout_probability = random.uniform(0, 1) 1034 | if self.training and (dropout_probability < self.layerdrop): 1035 | continue 1036 | 1037 | past_key_value = past_key_values[idx] if past_key_values is not None else None 1038 | 1039 | if self.gradient_checkpointing and self.training: 1040 | 1041 | if use_cache: 1042 | logger.warning( 1043 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 1044 | ) 1045 | use_cache = False 1046 | 1047 | def create_custom_forward(module): 1048 | def custom_forward(*inputs): 1049 | # None for past_key_value 1050 | return module(*inputs, output_attentions, use_cache) 1051 | 1052 | return custom_forward 1053 | 1054 | layer_outputs = torch.utils.checkpoint.checkpoint( 1055 | create_custom_forward(decoder_layer), 1056 | hidden_states, 1057 | attention_mask, 1058 | encoder_hidden_states, 1059 | encoder_attention_mask, 1060 | head_mask[idx] if head_mask is not None else None, 1061 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, 1062 | None, 1063 | ) 1064 | else: 1065 | 1066 | layer_outputs = decoder_layer( 1067 | hidden_states, 1068 | attention_mask=attention_mask, 1069 | encoder_hidden_states=encoder_hidden_states, 1070 | encoder_attention_mask=encoder_attention_mask, 1071 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 1072 | cross_attn_layer_head_mask=( 1073 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None 1074 | ), 1075 | past_key_value=past_key_value, 1076 | output_attentions=output_attentions, 1077 | use_cache=use_cache, 1078 | ) 1079 | hidden_states = layer_outputs[0] 1080 | 1081 | if use_cache: 1082 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 1083 | 1084 | if output_attentions: 1085 | all_self_attns += (layer_outputs[1],) 1086 | 1087 | if encoder_hidden_states is not None: 1088 | all_cross_attentions += (layer_outputs[2],) 1089 | 1090 | # add hidden states from the last decoder layer 1091 | if output_hidden_states: 1092 | all_hidden_states += (hidden_states,) 1093 | 1094 | next_cache = next_decoder_cache if use_cache else None 1095 | if not return_dict: 1096 | return tuple( 1097 | v 1098 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 1099 | if v is not None 1100 | ) 1101 | return BaseModelOutputWithPastAndCrossAttentions( 1102 | last_hidden_state=hidden_states, 1103 | past_key_values=next_cache, 1104 | hidden_states=all_hidden_states, 1105 | attentions=all_self_attns, 1106 | cross_attentions=all_cross_attentions, 1107 | ) 1108 | 1109 | 1110 | 
@add_start_docstrings( 1111 | "The bare BART Model outputting raw hidden-states without any specific head on top.", 1112 | BART_START_DOCSTRING, 1113 | ) 1114 | class BartModel(BartPretrainedModel): 1115 | def __init__(self, config: BartConfig): 1116 | super().__init__(config) 1117 | 1118 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 1119 | self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) 1120 | 1121 | self.encoder = BartEncoder(config, self.shared) 1122 | self.decoder = BartDecoder(config, self.shared) 1123 | 1124 | self.init_weights() 1125 | 1126 | def get_input_embeddings(self): 1127 | return self.shared 1128 | 1129 | def set_input_embeddings(self, value): 1130 | self.shared = value 1131 | self.encoder.embed_tokens = self.shared 1132 | self.decoder.embed_tokens = self.shared 1133 | 1134 | def get_encoder(self): 1135 | return self.encoder 1136 | 1137 | def get_decoder(self): 1138 | return self.decoder 1139 | 1140 | # @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1141 | # @add_code_sample_docstrings( 1142 | # processor_class=_TOKENIZER_FOR_DOC, 1143 | # checkpoint=_CHECKPOINT_FOR_DOC, 1144 | # output_type=Seq2SeqModelOutput, 1145 | # config_class=_CONFIG_FOR_DOC, 1146 | # ) 1147 | def forward( 1148 | self, 1149 | input_ids=None, 1150 | attention_mask=None, 1151 | decoder_input_ids=None, 1152 | decoder_attention_mask=None, 1153 | head_mask=None, 1154 | decoder_head_mask=None, 1155 | cross_attn_head_mask=None, 1156 | encoder_outputs=None, 1157 | past_key_values=None, 1158 | inputs_embeds=None, 1159 | decoder_inputs_embeds=None, 1160 | use_cache=None, 1161 | output_attentions=None, 1162 | output_hidden_states=None, 1163 | return_dict=None, 1164 | ): 1165 | 1166 | # different to other models, Bart automatically creates decoder_input_ids from 1167 | # input_ids if no decoder_input_ids are provided 1168 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1169 | decoder_input_ids = shift_tokens_right( 1170 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 1171 | ) 1172 | 1173 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1174 | output_hidden_states = ( 1175 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1176 | ) 1177 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1178 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1179 | 1180 | if encoder_outputs is None: 1181 | encoder_outputs = self.encoder( 1182 | input_ids=input_ids, 1183 | attention_mask=attention_mask, 1184 | head_mask=head_mask, 1185 | inputs_embeds=inputs_embeds, 1186 | output_attentions=output_attentions, 1187 | output_hidden_states=output_hidden_states, 1188 | return_dict=return_dict, 1189 | ) 1190 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 1191 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 1192 | encoder_outputs = BaseModelOutput( 1193 | last_hidden_state=encoder_outputs[0], 1194 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 1195 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 1196 | ) 1197 | 1198 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 1199 | decoder_outputs = self.decoder( 1200 | input_ids=decoder_input_ids, 1201 | attention_mask=decoder_attention_mask, 1202 | 
encoder_hidden_states=encoder_outputs[0], 1203 | encoder_attention_mask=attention_mask, 1204 | head_mask=decoder_head_mask, 1205 | cross_attn_head_mask=cross_attn_head_mask, 1206 | past_key_values=past_key_values, 1207 | inputs_embeds=decoder_inputs_embeds, 1208 | use_cache=use_cache, 1209 | output_attentions=output_attentions, 1210 | output_hidden_states=output_hidden_states, 1211 | return_dict=return_dict, 1212 | ) 1213 | 1214 | if not return_dict: 1215 | return decoder_outputs + encoder_outputs 1216 | 1217 | return Seq2SeqModelOutput( 1218 | last_hidden_state=decoder_outputs.last_hidden_state, 1219 | past_key_values=decoder_outputs.past_key_values, 1220 | decoder_hidden_states=decoder_outputs.hidden_states, 1221 | decoder_attentions=decoder_outputs.attentions, 1222 | cross_attentions=decoder_outputs.cross_attentions, 1223 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 1224 | encoder_hidden_states=encoder_outputs.hidden_states, 1225 | encoder_attentions=encoder_outputs.attentions, 1226 | ) 1227 | 1228 | 1229 | 1230 | @add_start_docstrings( 1231 | "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING 1232 | ) 1233 | class BartForConditionalGeneration(BartPretrainedModel): 1234 | base_model_prefix = "model" 1235 | _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"] 1236 | 1237 | def __init__(self, config: BartConfig): 1238 | super().__init__(config) 1239 | self.model = BartModel(config) 1240 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 1241 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 1242 | 1243 | self.init_weights() 1244 | 1245 | def get_encoder(self): 1246 | return self.model.get_encoder() 1247 | 1248 | def get_decoder(self): 1249 | return self.model.get_decoder() 1250 | 1251 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 1252 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 1253 | self._resize_final_logits_bias(new_num_tokens) 1254 | return new_embeddings 1255 | 1256 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 1257 | old_num_tokens = self.final_logits_bias.shape[-1] 1258 | if new_num_tokens <= old_num_tokens: 1259 | new_bias = self.final_logits_bias[:, :new_num_tokens] 1260 | else: 1261 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 1262 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 1263 | self.register_buffer("final_logits_bias", new_bias) 1264 | 1265 | def get_output_embeddings(self): 1266 | return self.lm_head 1267 | 1268 | def set_output_embeddings(self, new_embeddings): 1269 | self.lm_head = new_embeddings 1270 | 1271 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1272 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 1273 | @add_end_docstrings(BART_GENERATION_EXAMPLE) 1274 | def forward( 1275 | self, 1276 | input_ids=None, 1277 | attention_mask=None, 1278 | decoder_input_ids=None, 1279 | decoder_attention_mask=None, 1280 | head_mask=None, 1281 | decoder_head_mask=None, 1282 | cross_attn_head_mask=None, 1283 | encoder_outputs=None, 1284 | past_key_values=None, 1285 | inputs_embeds=None, 1286 | decoder_inputs_embeds=None, 1287 | labels=None, 1288 | use_cache=None, 1289 | output_attentions=None, 1290 | output_hidden_states=None, 1291 | return_dict=None, 1292 | ): 1293 | r""" 1294 | labels 
(:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1295 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., 1296 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are ignored 1297 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. 1298 | 1299 | Returns: 1300 | """ 1301 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1302 | 1303 | if labels is not None: 1304 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1305 | decoder_input_ids = shift_tokens_right( 1306 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 1307 | ) 1308 | 1309 | outputs = self.model( 1310 | input_ids, 1311 | attention_mask=attention_mask, 1312 | decoder_input_ids=decoder_input_ids, 1313 | encoder_outputs=encoder_outputs, 1314 | decoder_attention_mask=decoder_attention_mask, 1315 | head_mask=head_mask, 1316 | decoder_head_mask=decoder_head_mask, 1317 | cross_attn_head_mask=cross_attn_head_mask, 1318 | past_key_values=past_key_values, 1319 | inputs_embeds=inputs_embeds, 1320 | decoder_inputs_embeds=decoder_inputs_embeds, 1321 | use_cache=use_cache, 1322 | output_attentions=output_attentions, 1323 | output_hidden_states=output_hidden_states, 1324 | return_dict=return_dict, 1325 | ) 1326 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 1327 | 1328 | masked_lm_loss = None 1329 | if labels is not None: 1330 | loss_fct = CrossEntropyLoss() 1331 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1332 | 1333 | if not return_dict: 1334 | output = (lm_logits,) + outputs[1:] 1335 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1336 | 1337 | return Seq2SeqLMOutput( 1338 | loss=masked_lm_loss, 1339 | logits=lm_logits, 1340 | past_key_values=outputs.past_key_values, 1341 | decoder_hidden_states=outputs.decoder_hidden_states, 1342 | decoder_attentions=outputs.decoder_attentions, 1343 | cross_attentions=outputs.cross_attentions, 1344 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1345 | encoder_hidden_states=outputs.encoder_hidden_states, 1346 | encoder_attentions=outputs.encoder_attentions, 1347 | ) 1348 | 1349 | 1350 | def prepare_inputs_for_generation( 1351 | self, 1352 | decoder_input_ids, 1353 | past=None, 1354 | attention_mask=None, 1355 | head_mask=None, 1356 | decoder_head_mask=None, 1357 | cross_attn_head_mask=None, 1358 | use_cache=None, 1359 | encoder_outputs=None, 1360 | **kwargs 1361 | ): 1362 | # cut decoder_input_ids if past is used 1363 | if past is not None: 1364 | decoder_input_ids = decoder_input_ids[:, -1:] 1365 | 1366 | return { 1367 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 1368 | "encoder_outputs": encoder_outputs, 1369 | "past_key_values": past, 1370 | "decoder_input_ids": decoder_input_ids, 1371 | "attention_mask": attention_mask, 1372 | "head_mask": head_mask, 1373 | "decoder_head_mask": decoder_head_mask, 1374 | "cross_attn_head_mask": cross_attn_head_mask, 1375 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 1376 | } 1377 | 1378 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 1379 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 1380 | 1381 | @staticmethod 1382 | def _reorder_cache(past, beam_idx): 1383 | reordered_past = () 1384 | for layer_past in past: 1385 | # cached cross_attention states don't have to be reordered -> they are always the same 1386 | reordered_past += ( 1387 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], 1388 | ) 1389 | return reordered_past 1390 | 1391 | 1392 | 1393 | -------------------------------------------------------------------------------- /src/model/main_functions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import logging 4 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 5 | from src.functions.utils import ChatDataset, set_seed 6 | from transformers import AdamW 7 | from src.functions.utils import measure 8 | from src.functions.rouge import Rouge 9 | import timeit 10 | 11 | def to_list(tensor): 12 | return tensor.detach().cpu().tolist() 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | def train(args, model, tokenizer): 17 | """ Train the model """ 18 | dataset = ChatDataset(args.train_file, tokenizer) 19 | # train_dataset = dataset.load_dataset() 20 | train_dataset, _, _ = dataset.load_dataset_with_passage_ids() 21 | train_sampler = RandomSampler(train_dataset) 22 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.batch_size) 23 | no_decay = ["bias", "LayerNorm.weight"] 24 | optimizer_grouped_parameters = [ 25 | { 26 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 27 | "weight_decay": args.weight_decay, 28 | }, 29 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}, 30 | ] 31 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 32 | 33 | global_step = 1 34 | steps_trained_in_current_epoch = 0 35 | # Check if continuing training from a checkpoint 36 | 37 | tr_loss, logging_loss = 0.0, 0.0 38 | model.zero_grad() 39 | 40 | # Added here for reproductibility 41 | set_seed(args) 42 | 43 | for epoch in range(args.train_epochs): 44 | for step, batch in enumerate(train_dataloader): 45 | # Skip past any already trained steps if resuming training 46 | model.train() 47 | batch = tuple(t.to(args.device) for t in batch) 48 | 49 | inputs = { 50 | "input_ids": batch[0], 51 | "attention_mask": batch[1], 52 | "passage_ids": batch[2], 53 | "decoder_input_ids": batch[3], 54 | "decoder_attention_mask": batch[4], 55 | "decoder_labels":batch[5], 56 | } 57 | # batch_size = batch[0].size()[0] 58 | # ee = batch[0].tolist() 59 | # d = batch[2].tolist() 60 | # dd = batch[4].tolist() 61 | # for e in range(batch_size): 62 | # enc_input = dataset.tokenizer.batch_decode(ee[e]) 63 | # dec_input = dataset.tokenizer.batch_decode(d[e][:d[e].index(3)]) 64 | # dec_labels = 
dataset.tokenizer.batch_decode(dd[e][:dd[e].index(-100)])
65 |             # print(33333)
66 | 
67 |             outputs = model(input_ids=batch[0],
68 |                             attention_mask=batch[1],
69 |                             passage_ids = batch[2],
70 |                             decoder_input_ids=batch[3],
71 |                             decoder_attention_mask=batch[4],
72 |                             labels=batch[5], return_dict=True)
73 |             loss = outputs["loss"]
74 |             if args.gradient_accumulation_steps > 1:
75 |                 loss = loss / args.gradient_accumulation_steps
76 | 
77 |             loss.backward()
78 |             if (global_step+1) % 50 == 0:
79 |                 print("{} Processed.. Total Loss : {}".format(global_step+1, loss.item()))
80 |             tr_loss += loss.item()
81 |             if (step + 1) % args.gradient_accumulation_steps == 0:
82 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
83 | 
84 |                 optimizer.step()
85 |                 model.zero_grad()
86 |                 global_step += 1
87 | 
88 |                 # Save model checkpoint
89 |                 if global_step % args.save_steps == 0:
90 |                     evaluate(args, model, tokenizer, global_step)
91 |                     output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
92 |                     print("Model saved in {}".format(output_dir))
93 |                     if not os.path.exists(output_dir):
94 |                         os.makedirs(output_dir)
95 |                     # Take care of distributed/parallel training
96 |                     model_to_save = model.module if hasattr(model, "module") else model
97 |                     model_to_save.save_pretrained(output_dir)
98 | 
99 | 
100 |                     torch.save(args, os.path.join(output_dir, "training_args.bin"))
101 |                     logger.info("Saving model checkpoint to %s", output_dir)
102 | 
103 |     return global_step, tr_loss / global_step
104 | from tqdm import tqdm
105 | def evaluate(args, model, tokenizer, global_step=0):
106 |     model.eval()
107 |     dataset = ChatDataset(args.test_file, tokenizer)
108 |     # test_dataset = dataset.load_dataset()
109 |     test_dataset, _, _ = dataset.load_dataset_with_passage_ids()
110 |     test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size)
111 |     output_file = open(os.path.join(args.output_dir, "result_{}_only_hard.txt".format(global_step)), 'w', encoding='utf8')
112 |     preds = []
113 |     refs = []
114 |     for batch in tqdm(test_dataloader):
115 |         # keep the model in eval mode during generation
116 |         model.eval()
117 |         batch = tuple(t.to(args.device) for t in batch)
118 | 
119 | 
120 |         dec_outputs = model.generate(input_ids = batch[0],
121 |                                      attention_mask=batch[1],
122 |                                      passage_ids=batch[2],
123 |                                      max_length=32,
124 |                                      num_beams=5,
125 |                                      eos_token_id=1,
126 |                                      bad_words_ids=[[5]])
127 |         batch_size = batch[0].size()[0]
128 | 
129 |         dec_outputs = dec_outputs.tolist()
130 |         dec_labels = batch[5].tolist()
131 | 
132 |         for index in range(batch_size):
133 |             if 1 in dec_outputs[index]:
134 |                 dec_outputs[index] = dec_outputs[index][:dec_outputs[index].index(1)]
135 |             if -100 in dec_labels[index]:
136 |                 dec_labels[index] = dec_labels[index][:dec_labels[index].index(-100)]
137 |             pred = dataset.tokenizer.convert_ids_to_tokens(dec_outputs[index][1:])
138 |             ref = dataset.tokenizer.convert_ids_to_tokens(dec_labels[index][:-1])
139 |             output_file.write("REFERENCE : {}\nDECODED : {}\n\n".format(''.join(ref), ''.join(pred)))
140 |             preds.append(pred)
141 |             refs.append(ref)
142 |     measure(preds, refs)
143 | import json
144 | def make_file(args, model, tokenizer, global_step=0):
145 |     model.eval()
146 |     dataset = ChatDataset(args.predict_file, tokenizer)
147 |     # test_dataset = dataset.load_dataset()
148 |     test_dataset, ids, levels = dataset.load_dataset_with_passage_ids()
149 |     test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size)
150 |     # output_file = open(os.path.join(args.output_dir, 
"all_silver_sent_result_{}.json".format(global_step)), 'w', encoding='utf8') 151 | output_file = open(os.path.join(args.output_dir, "all_outdomain_aug.json".format(global_step)), 'w', 152 | encoding='utf8') 153 | result_dict = {} 154 | preds = [] 155 | refs = [] 156 | step = 0 157 | for batch in tqdm(test_dataloader): 158 | # Skip past any already trained steps if resuming training 159 | model.train() 160 | batch = tuple(t.to(args.device) for t in batch) 161 | 162 | dec_outputs = model.generate(input_ids=batch[0], 163 | attention_mask=batch[1], 164 | passage_ids=batch[2], 165 | max_length=32, 166 | num_beams=5, 167 | eos_token_id=1, 168 | bad_words_ids=[[5]]) 169 | batch_size = batch[0].size()[0] 170 | 171 | dec_outputs = dec_outputs.tolist() 172 | dec_labels = batch[5].tolist() 173 | 174 | for index in range(batch_size): 175 | if 1 in dec_outputs[index]: 176 | dec_outputs[index] = dec_outputs[index][:dec_outputs[index].index(1)] 177 | if -100 in dec_labels[index]: 178 | dec_labels[index] = dec_labels[index][:dec_labels[index].index(-100)] 179 | pred = dataset.tokenizer.convert_ids_to_tokens(dec_outputs[index][1:]) 180 | ref = dataset.tokenizer.convert_ids_to_tokens(dec_labels[index][:-1]) 181 | # output_file.write("REFERENCE : {}\nDECODED : {}\n\n".format(''.join(ref), ''.join(pred))) 182 | preds.append(pred) 183 | refs.append(ref) 184 | result_dict[ids[step]] = [''.join(pred).replace("▁", " "), levels[step]] 185 | step += 1 186 | json.dump(result_dict, output_file, indent='\t', ensure_ascii=False) 187 | # def predict(args, model, tokenizer): 188 | # model.eval() 189 | # dataset = ChatDataset(args.test_file, tokenizer) 190 | # print("Enter -1 to Quit") 191 | # title = input("Enter The Document Title : ") 192 | # context = input("Enter The Document : ") 193 | # while(1): 194 | # pred_dataset = dataset.make_dataset(title, context) 195 | # pred_dataloader = DataLoader(pred_dataset, batch_size=args.batch_size) 196 | # 197 | # for step, batch in enumerate(pred_dataloader): 198 | # # Skip past any already trained steps if resuming training 199 | # model.train() 200 | # batch = tuple(t.to(args.device) for t in batch) 201 | # 202 | # inputs = { 203 | # "input_ids": batch[0], 204 | # "attention_mask" : batch[1] 205 | # } 206 | # 207 | # dec_outputs = model.generate(input_ids = batch[0], 208 | # attention_mask=batch[1], 209 | # max_length=32, 210 | # num_beams=5, 211 | # eos_token_id=1, 212 | # bad_words_ids=[[5]]) 213 | # 214 | # 215 | # dec_outputs = dec_outputs.tolist()[0] 216 | # 217 | # if 1 in dec_outputs: 218 | # dec_outputs = dec_outputs[:dec_outputs.index(1)] 219 | # 220 | # pred = dataset.tokenizer.convert_ids_to_tokens(dec_outputs[1:]) 221 | # print("Generated Question : ", ''.join(pred).replace("▁", " "), "\n\n") 222 | 223 | 224 | 225 | -------------------------------------------------------------------------------- /src/model/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ PyTorch BART model. """ 16 | import copy 17 | import math 18 | import random 19 | import warnings 20 | from typing import Optional, Tuple 21 | from torch.nn import functional as F 22 | import torch 23 | import torch.utils.checkpoint 24 | from torch import nn 25 | from torch.nn import CrossEntropyLoss, MSELoss 26 | 27 | from transformers.activations import ACT2FN 28 | from transformers.file_utils import ( 29 | add_code_sample_docstrings, 30 | add_end_docstrings, 31 | add_start_docstrings, 32 | add_start_docstrings_to_model_forward, 33 | replace_return_docstrings, 34 | ) 35 | from transformers.modeling_outputs import ( 36 | BaseModelOutput, 37 | BaseModelOutputWithPastAndCrossAttentions, 38 | CausalLMOutputWithCrossAttentions, 39 | Seq2SeqLMOutput, 40 | Seq2SeqModelOutput, 41 | Seq2SeqQuestionAnsweringModelOutput, 42 | Seq2SeqSequenceClassifierOutput, 43 | ) 44 | 45 | from transformers.models.roberta.modeling_roberta import RobertaModel 46 | from transformers.modeling_utils import PreTrainedModel 47 | from transformers.utils import logging 48 | from transformers.models.bart.configuration_bart import BartConfig 49 | 50 | 51 | logger = logging.get_logger(__name__) 52 | 53 | _CONFIG_FOR_DOC = "BartConfig" 54 | _TOKENIZER_FOR_DOC = "BartTokenizer" 55 | 56 | 57 | BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ 58 | "facebook/bart-large", 59 | # See all BART models at https://huggingface.co/models?filter=bart 60 | ] 61 | 62 | 63 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 64 | """ 65 | Shift input ids one token to the right. 66 | """ 67 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 68 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 69 | shifted_input_ids[:, 0] = decoder_start_token_id 70 | 71 | assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." 72 | # replace possible -100 values in labels by `pad_token_id` 73 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 74 | 75 | return shifted_input_ids 76 | 77 | 78 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 79 | """ 80 | Make causal mask used for bi-directional self-attention. 81 | """ 82 | bsz, tgt_len = input_ids_shape 83 | mask = torch.full((tgt_len, tgt_len), float("-inf")) 84 | mask_cond = torch.arange(mask.size(-1)) 85 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 86 | mask = mask.to(dtype) 87 | 88 | if past_key_values_length > 0: 89 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 90 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 91 | 92 | 93 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 94 | """ 95 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
96 | """ 97 | bsz, src_len = mask.size() 98 | tgt_len = tgt_len if tgt_len is not None else src_len 99 | 100 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 101 | 102 | inverted_mask = 1.0 - expanded_mask 103 | 104 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 105 | 106 | 107 | class BartLearnedPositionalEmbedding(nn.Embedding): 108 | """ 109 | This module learns positional embeddings up to a fixed maximum size. 110 | """ 111 | 112 | def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int): 113 | assert padding_idx is not None, "`padding_idx` should not be None, but of type int" 114 | # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 115 | # and adjust num_embeddings appropriately. Other models dont have this hack 116 | self.offset = 2 117 | super().__init__(num_embeddings + self.offset, embedding_dim, padding_idx=padding_idx) 118 | 119 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): 120 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 121 | bsz, seq_len = input_ids_shape[:2] 122 | positions = torch.arange( 123 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device 124 | ) 125 | return super().forward(positions + self.offset) 126 | 127 | 128 | class BartAttention(nn.Module): 129 | """Multi-headed attention from 'Attention Is All You Need' paper""" 130 | 131 | def __init__( 132 | self, 133 | embed_dim: int, 134 | num_heads: int, 135 | dropout: float = 0.0, 136 | is_decoder: bool = False, 137 | bias: bool = True, 138 | ): 139 | super().__init__() 140 | self.embed_dim = embed_dim 141 | self.num_heads = num_heads 142 | self.dropout = dropout 143 | self.head_dim = embed_dim // num_heads 144 | assert ( 145 | self.head_dim * num_heads == self.embed_dim 146 | ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {num_heads})." 
147 | self.scaling = self.head_dim ** -0.5 148 | self.is_decoder = is_decoder 149 | 150 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 151 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 152 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 153 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 154 | 155 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 156 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 157 | 158 | def forward( 159 | self, 160 | hidden_states: torch.Tensor, 161 | key_value_states: Optional[torch.Tensor] = None, 162 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 163 | attention_mask: Optional[torch.Tensor] = None, 164 | layer_head_mask: Optional[torch.Tensor] = None, 165 | output_attentions: bool = False, 166 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 167 | """Input shape: Batch x Time x Channel""" 168 | 169 | # if key_value_states are provided this layer is used as a cross-attention layer 170 | # for the decoder 171 | is_cross_attention = key_value_states is not None 172 | bsz, tgt_len, embed_dim = hidden_states.size() 173 | 174 | # get query proj 175 | query_states = self.q_proj(hidden_states) * self.scaling 176 | # get key, value proj 177 | if is_cross_attention and past_key_value is not None: 178 | # reuse k,v, cross_attentions 179 | key_states = past_key_value[0] 180 | value_states = past_key_value[1] 181 | elif is_cross_attention: 182 | # cross_attentions 183 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 184 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 185 | elif past_key_value is not None: 186 | # reuse k, v, self_attention 187 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 188 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 189 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 190 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 191 | else: 192 | # self_attention 193 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 194 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 195 | 196 | if self.is_decoder: 197 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 198 | # Further calls to cross_attention layer can then reuse all cross-attention 199 | # key/value_states (first "if" case) 200 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 201 | # all previous decoder key/value_states. 
Further calls to uni-directional self-attention 202 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 203 | # if encoder bi-directional self-attention `past_key_value` is always `None` 204 | past_key_value = (key_states, value_states) 205 | 206 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 207 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 208 | key_states = key_states.view(*proj_shape) 209 | value_states = value_states.view(*proj_shape) 210 | 211 | src_len = key_states.size(1) 212 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 213 | 214 | assert attn_weights.size() == ( 215 | bsz * self.num_heads, 216 | tgt_len, 217 | src_len, 218 | ), f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 219 | 220 | if attention_mask is not None: 221 | assert attention_mask.size() == ( 222 | bsz, 223 | 1, 224 | tgt_len, 225 | src_len, 226 | ), f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 227 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 228 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 229 | 230 | attn_weights = F.softmax(attn_weights, dim=-1) 231 | 232 | if layer_head_mask is not None: 233 | assert layer_head_mask.size() == ( 234 | self.num_heads, 235 | ), f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 236 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 237 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 238 | 239 | if output_attentions: 240 | # this operation is a bit akward, but it's required to 241 | # make sure that attn_weights keeps its gradient. 
242 | # In order to do so, attn_weights have to reshaped 243 | # twice and have to be reused in the following 244 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 245 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 246 | else: 247 | attn_weights_reshaped = None 248 | 249 | attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training) 250 | 251 | attn_output = torch.bmm(attn_probs, value_states) 252 | 253 | assert attn_output.size() == ( 254 | bsz * self.num_heads, 255 | tgt_len, 256 | self.head_dim, 257 | ), f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 258 | 259 | attn_output = ( 260 | attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 261 | .transpose(1, 2) 262 | .reshape(bsz, tgt_len, embed_dim) 263 | ) 264 | 265 | attn_output = self.out_proj(attn_output) 266 | 267 | return attn_output, attn_weights_reshaped, past_key_value 268 | 269 | 270 | class BartEncoderLayer(nn.Module): 271 | def __init__(self, config: BartConfig): 272 | super().__init__() 273 | self.embed_dim = config.d_model 274 | self.self_attn = BartAttention( 275 | embed_dim=self.embed_dim, 276 | num_heads=config.encoder_attention_heads, 277 | dropout=config.attention_dropout, 278 | ) 279 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 280 | self.dropout = config.dropout 281 | self.activation_fn = ACT2FN[config.activation_function] 282 | self.activation_dropout = config.activation_dropout 283 | self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) 284 | self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) 285 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 286 | 287 | def forward( 288 | self, 289 | hidden_states: torch.Tensor, 290 | attention_mask: torch.Tensor, 291 | layer_head_mask: torch.Tensor, 292 | output_attentions: bool = False, 293 | ): 294 | """ 295 | Args: 296 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` 297 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size 298 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 299 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size 300 | `(config.encoder_attention_heads,)`. 301 | output_attentions (:obj:`bool`, `optional`): 302 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 303 | returned tensors for more detail. 
304 | """ 305 | residual = hidden_states 306 | hidden_states, attn_weights, _ = self.self_attn( 307 | hidden_states=hidden_states, 308 | attention_mask=attention_mask, 309 | layer_head_mask=layer_head_mask, 310 | output_attentions=output_attentions, 311 | ) 312 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 313 | hidden_states = residual + hidden_states 314 | hidden_states = self.self_attn_layer_norm(hidden_states) 315 | 316 | residual = hidden_states 317 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 318 | hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) 319 | hidden_states = self.fc2(hidden_states) 320 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 321 | hidden_states = residual + hidden_states 322 | hidden_states = self.final_layer_norm(hidden_states) 323 | 324 | if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any(): 325 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 326 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 327 | 328 | outputs = (hidden_states,) 329 | 330 | if output_attentions: 331 | outputs += (attn_weights,) 332 | 333 | return outputs 334 | 335 | 336 | class BartDecoderLayer(nn.Module): 337 | def __init__(self, config: BartConfig): 338 | super().__init__() 339 | self.embed_dim = config.d_model 340 | 341 | self.self_attn = BartAttention( 342 | embed_dim=self.embed_dim, 343 | num_heads=config.decoder_attention_heads, 344 | dropout=config.attention_dropout, 345 | is_decoder=True, 346 | ) 347 | self.dropout = config.dropout 348 | self.activation_fn = ACT2FN[config.activation_function] 349 | self.activation_dropout = config.activation_dropout 350 | 351 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 352 | self.encoder_attn = BartAttention( 353 | self.embed_dim, 354 | config.decoder_attention_heads, 355 | dropout=config.attention_dropout, 356 | is_decoder=True, 357 | ) 358 | self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) 359 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 360 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 361 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 362 | 363 | def forward( 364 | self, 365 | hidden_states: torch.Tensor, 366 | attention_mask: Optional[torch.Tensor] = None, 367 | encoder_hidden_states: Optional[torch.Tensor] = None, 368 | encoder_attention_mask: Optional[torch.Tensor] = None, 369 | layer_head_mask: Optional[torch.Tensor] = None, 370 | encoder_layer_head_mask: Optional[torch.Tensor] = None, 371 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 372 | output_attentions: Optional[bool] = False, 373 | use_cache: Optional[bool] = True, 374 | ): 375 | """ 376 | Args: 377 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` 378 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size 379 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 380 | encoder_hidden_states (:obj:`torch.FloatTensor`): cross attention input to the layer of shape `(seq_len, batch, embed_dim)` 381 | encoder_attention_mask (:obj:`torch.FloatTensor`): encoder attention mask of size 382 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
383 | layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size 384 | `(config.encoder_attention_heads,)`. 385 | encoder_layer_head_mask (:obj:`torch.FloatTensor`): mask for encoder attention heads in a given layer of 386 | size `(config.encoder_attention_heads,)`. 387 | past_key_value (:obj:`Tuple(torch.FloatTensor)`): cached past key and value projection states 388 | output_attentions (:obj:`bool`, `optional`): 389 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 390 | returned tensors for more detail. 391 | """ 392 | residual = hidden_states 393 | 394 | # Self Attention 395 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 396 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 397 | # add present self-attn cache to positions 1,2 of present_key_value tuple 398 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 399 | hidden_states=hidden_states, 400 | past_key_value=self_attn_past_key_value, 401 | attention_mask=attention_mask, 402 | layer_head_mask=layer_head_mask, 403 | output_attentions=output_attentions, 404 | ) 405 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 406 | hidden_states = residual + hidden_states 407 | hidden_states = self.self_attn_layer_norm(hidden_states) 408 | 409 | # Cross-Attention Block 410 | cross_attn_present_key_value = None 411 | cross_attn_weights = None 412 | if encoder_hidden_states is not None: 413 | residual = hidden_states 414 | 415 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 416 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 417 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 418 | hidden_states=hidden_states, 419 | key_value_states=encoder_hidden_states, 420 | attention_mask=encoder_attention_mask, 421 | layer_head_mask=encoder_layer_head_mask, 422 | past_key_value=cross_attn_past_key_value, 423 | output_attentions=output_attentions, 424 | ) 425 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 426 | hidden_states = residual + hidden_states 427 | hidden_states = self.encoder_attn_layer_norm(hidden_states) 428 | 429 | # add cross-attn to positions 3,4 of present_key_value tuple 430 | present_key_value = present_key_value + cross_attn_present_key_value 431 | 432 | # Fully Connected 433 | residual = hidden_states 434 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 435 | hidden_states = F.dropout(hidden_states, p=self.activation_dropout, training=self.training) 436 | hidden_states = self.fc2(hidden_states) 437 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 438 | hidden_states = residual + hidden_states 439 | hidden_states = self.final_layer_norm(hidden_states) 440 | 441 | outputs = (hidden_states,) 442 | 443 | if output_attentions: 444 | outputs += (self_attn_weights, cross_attn_weights) 445 | 446 | if use_cache: 447 | outputs += (present_key_value,) 448 | 449 | return outputs 450 | 451 | 452 | class BartClassificationHead(nn.Module): 453 | """Head for sentence-level classification tasks.""" 454 | 455 | def __init__( 456 | self, 457 | input_dim: int, 458 | inner_dim: int, 459 | num_classes: int, 460 | pooler_dropout: float, 461 | ): 462 | super().__init__() 463 | self.dense = nn.Linear(input_dim, inner_dim) 464 | self.dropout = 
nn.Dropout(p=pooler_dropout) 465 | self.out_proj = nn.Linear(inner_dim, num_classes) 466 | 467 | def forward(self, hidden_states: torch.Tensor): 468 | hidden_states = self.dropout(hidden_states) 469 | hidden_states = self.dense(hidden_states) 470 | hidden_states = torch.tanh(hidden_states) 471 | hidden_states = self.dropout(hidden_states) 472 | hidden_states = self.out_proj(hidden_states) 473 | return hidden_states 474 | 475 | 476 | class BartPretrainedModel(PreTrainedModel): 477 | config_class = BartConfig 478 | base_model_prefix = "model" 479 | 480 | def _init_weights(self, module): 481 | std = self.config.init_std 482 | if isinstance(module, nn.Linear): 483 | module.weight.data.normal_(mean=0.0, std=std) 484 | if module.bias is not None: 485 | module.bias.data.zero_() 486 | elif isinstance(module, nn.Embedding): 487 | module.weight.data.normal_(mean=0.0, std=std) 488 | if module.padding_idx is not None: 489 | module.weight.data[module.padding_idx].zero_() 490 | 491 | @property 492 | def dummy_inputs(self): 493 | pad_token = self.config.pad_token_id 494 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 495 | dummy_inputs = { 496 | "attention_mask": input_ids.ne(pad_token), 497 | "input_ids": input_ids, 498 | } 499 | return dummy_inputs 500 | 501 | 502 | class PretrainedBartModel(BartPretrainedModel): 503 | def __init_subclass__(self): 504 | warnings.warn( 505 | "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.", 506 | FutureWarning, 507 | ) 508 | 509 | 510 | BART_START_DOCSTRING = r""" 511 | This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic 512 | methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, 513 | pruning heads etc.) 514 | 515 | This model is also a PyTorch `torch.nn.Module `__ 516 | subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to 517 | general usage and behavior. 518 | 519 | Parameters: 520 | config (:class:`~transformers.BartConfig`): 521 | Model configuration class with all the parameters of the model. Initializing with a config file does not 522 | load the weights associated with the model, only the configuration. Check out the 523 | :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. 524 | """ 525 | 526 | BART_GENERATION_EXAMPLE = r""" 527 | Summarization example:: 528 | 529 | >>> from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig 530 | 531 | >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large') 532 | >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') 533 | 534 | >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs." 535 | >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt') 536 | 537 | >>> # Generate Summary 538 | >>> summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True) 539 | >>> print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids]) 540 | 541 | Mask filling example:: 542 | 543 | >>> from transformers import BartTokenizer, BartForConditionalGeneration 544 | >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') 545 | >>> TXT = "My friends are but they eat too many carbs." 
546 | 547 | >>> model = BartForConditionalGeneration.from_pretrained('facebook/bart-large') 548 | >>> input_ids = tokenizer([TXT], return_tensors='pt')['input_ids'] 549 | >>> logits = model(input_ids).logits 550 | 551 | >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() 552 | >>> probs = logits[0, masked_index].softmax(dim=0) 553 | >>> values, predictions = probs.topk(5) 554 | 555 | >>> tokenizer.decode(predictions).split() 556 | """ 557 | 558 | BART_INPUTS_DOCSTRING = r""" 559 | Args: 560 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 561 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 562 | it. 563 | 564 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 565 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 566 | details. 567 | 568 | `What are input IDs? <../glossary.html#input-ids>`__ 569 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 570 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 571 | 572 | - 1 for tokens that are **not masked**, 573 | - 0 for tokens that are **masked**. 574 | 575 | `What are attention masks? <../glossary.html#attention-mask>`__ 576 | decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 577 | Indices of decoder input sequence tokens in the vocabulary. 578 | 579 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 580 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for 581 | details. 582 | 583 | `What are input IDs? <../glossary.html#input-ids>`__ 584 | 585 | Bart uses the :obj:`eos_token_id` as the starting token for :obj:`decoder_input_ids` generation. If 586 | :obj:`past_key_values` is used, optionally only the last :obj:`decoder_input_ids` have to be input (see 587 | :obj:`past_key_values`). 588 | 589 | For translation and summarization training, :obj:`decoder_input_ids` should be provided. If no 590 | :obj:`decoder_input_ids` is provided, the model will create this tensor by shifting the :obj:`input_ids` to 591 | the right for denoising pre-training following the paper. 592 | decoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`): 593 | Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will 594 | also be used by default. 595 | 596 | If you want to change padding behavior, you should read :func:`modeling_bart._prepare_decoder_inputs` and 597 | modify to your needs. See diagram 1 in `the paper `__ for more 598 | information on the default strategy. 599 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 600 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in ``[0, 1]``: 601 | 602 | - 1 indicates the head is **not masked**, 603 | - 0 indicates the heas is **masked**. 604 | 605 | decoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 606 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in ``[0, 1]``: 607 | 608 | - 1 indicates the head is **not masked**, 609 | - 0 indicates the head is **masked**. 
610 | 611 | encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`): 612 | Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: 613 | :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, 614 | `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the 615 | cross-attention of the decoder. 616 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 617 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding. 618 | 619 | If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` 620 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 621 | instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. 622 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 623 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. 624 | This is useful if you want more control over how to convert :obj:`input_ids` indices into associated 625 | vectors than the model's internal embedding lookup matrix. 626 | decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`): 627 | Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded 628 | representation. If :obj:`past_key_values` is used, optionally only the last :obj:`decoder_inputs_embeds` 629 | have to be input (see :obj:`past_key_values`). This is useful if you want more control over how to convert 630 | :obj:`decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 631 | 632 | If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both unset, :obj:`decoder_inputs_embeds` 633 | takes the value of :obj:`inputs_embeds`. 634 | use_cache (:obj:`bool`, `optional`): 635 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 636 | decoding (see :obj:`past_key_values`). 637 | output_attentions (:obj:`bool`, `optional`): 638 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned 639 | tensors for more detail. 640 | output_hidden_states (:obj:`bool`, `optional`): 641 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for 642 | more detail. 643 | return_dict (:obj:`bool`, `optional`): 644 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 645 | """ 646 | 647 | 648 | class BartEncoder(BartPretrainedModel): 649 | """ 650 | Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a 651 | :class:`BartEncoderLayer`. 
652 | 653 | Args: 654 | config: BartConfig 655 | embed_tokens (torch.nn.Embedding): output embedding 656 | """ 657 | 658 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 659 | super().__init__(config) 660 | 661 | self.dropout = config.dropout 662 | self.layerdrop = config.encoder_layerdrop 663 | 664 | embed_dim = config.d_model 665 | self.padding_idx = config.pad_token_id 666 | self.max_source_positions = config.max_position_embeddings 667 | self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 668 | 669 | if embed_tokens is not None: 670 | self.embed_tokens = embed_tokens 671 | else: 672 | self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) 673 | self.embed_passages = nn.Embedding(3, embed_dim, 0) 674 | self.embed_positions = BartLearnedPositionalEmbedding( 675 | config.max_position_embeddings, 676 | embed_dim, 677 | self.padding_idx, 678 | ) 679 | 680 | self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) 681 | self.layernorm_embedding = nn.LayerNorm(embed_dim) 682 | 683 | self.init_weights() 684 | 685 | def forward( 686 | self, 687 | input_ids=None, 688 | attention_mask=None, 689 | passage_ids=None, 690 | head_mask=None, 691 | inputs_embeds=None, 692 | output_attentions=None, 693 | output_hidden_states=None, 694 | return_dict=None, 695 | ): 696 | r""" 697 | Args: 698 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 699 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 700 | provide it. 701 | 702 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 703 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` 704 | for details. 705 | 706 | `What are input IDs? <../glossary.html#input-ids>`__ 707 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 708 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 709 | 710 | - 1 for tokens that are **not masked**, 711 | - 0 for tokens that are **masked**. 712 | 713 | `What are attention masks? <../glossary.html#attention-mask>`__ 714 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 715 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: 716 | 717 | - 1 indicates the head is **not masked**, 718 | - 0 indicates the heas is **masked**. 719 | 720 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 721 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded 722 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices 723 | into associated vectors than the model's internal embedding lookup matrix. 724 | output_attentions (:obj:`bool`, `optional`): 725 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 726 | returned tensors for more detail. 727 | output_hidden_states (:obj:`bool`, `optional`): 728 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 729 | for more detail. 730 | return_dict (:obj:`bool`, `optional`): 731 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
732 | """ 733 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 734 | output_hidden_states = ( 735 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 736 | ) 737 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 738 | 739 | # retrieve input_ids and inputs_embeds 740 | if input_ids is not None and inputs_embeds is not None: 741 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 742 | elif input_ids is not None: 743 | input_shape = input_ids.size() 744 | input_ids = input_ids.view(-1, input_shape[-1]) 745 | elif inputs_embeds is not None: 746 | input_shape = inputs_embeds.size()[:-1] 747 | else: 748 | raise ValueError("You have to specify either input_ids or inputs_embeds") 749 | 750 | if inputs_embeds is None: 751 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 752 | passage_embeds = self.embed_passages(passage_ids) 753 | embed_pos = self.embed_positions(input_shape) 754 | 755 | hidden_states = inputs_embeds + passage_embeds + embed_pos 756 | hidden_states = self.layernorm_embedding(hidden_states) 757 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 758 | 759 | # expand attention_mask 760 | if attention_mask is not None: 761 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 762 | attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) 763 | 764 | encoder_states = () if output_hidden_states else None 765 | all_attentions = () if output_attentions else None 766 | 767 | # check if head_mask has a correct number of layers specified if desired 768 | if head_mask is not None: 769 | assert head_mask.size()[0] == ( 770 | len(self.layers) 771 | ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
772 | for idx, encoder_layer in enumerate(self.layers): 773 | if output_hidden_states: 774 | encoder_states = encoder_states + (hidden_states,) 775 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 776 | dropout_probability = random.uniform(0, 1) 777 | if self.training and (dropout_probability < self.layerdrop): # skip the layer 778 | layer_outputs = (None, None) 779 | else: 780 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 781 | 782 | def create_custom_forward(module): 783 | def custom_forward(*inputs): 784 | return module(*inputs, output_attentions) 785 | 786 | return custom_forward 787 | 788 | layer_outputs = torch.utils.checkpoint.checkpoint( 789 | create_custom_forward(encoder_layer), 790 | hidden_states, 791 | attention_mask, 792 | (head_mask[idx] if head_mask is not None else None), 793 | ) 794 | else: 795 | layer_outputs = encoder_layer( 796 | hidden_states, 797 | attention_mask, 798 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 799 | output_attentions=output_attentions, 800 | ) 801 | 802 | hidden_states = layer_outputs[0] 803 | 804 | if output_attentions: 805 | all_attentions = all_attentions + (layer_outputs[1],) 806 | 807 | if output_hidden_states: 808 | encoder_states = encoder_states + (hidden_states,) 809 | 810 | if not return_dict: 811 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) 812 | return BaseModelOutput( 813 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions 814 | ) 815 | 816 | 817 | class BartDecoder(BartPretrainedModel): 818 | """ 819 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`BartDecoderLayer` 820 | 821 | Args: 822 | config: BartConfig 823 | embed_tokens (torch.nn.Embedding): output embedding 824 | """ 825 | 826 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 827 | super().__init__(config) 828 | self.dropout = config.dropout 829 | self.layerdrop = config.decoder_layerdrop 830 | self.padding_idx = config.pad_token_id 831 | self.max_target_positions = config.max_position_embeddings 832 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 833 | 834 | if embed_tokens is not None: 835 | self.embed_tokens = embed_tokens 836 | else: 837 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 838 | 839 | self.embed_positions = BartLearnedPositionalEmbedding( 840 | config.max_position_embeddings, 841 | config.d_model, 842 | self.padding_idx, 843 | ) 844 | self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) 845 | self.layernorm_embedding = nn.LayerNorm(config.d_model) 846 | 847 | self.init_weights() 848 | 849 | def get_input_embeddings(self): 850 | return self.embed_tokens 851 | 852 | def set_input_embeddings(self, value): 853 | self.embed_tokens = value 854 | 855 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 856 | # create causal mask 857 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 858 | combined_attention_mask = None 859 | if input_shape[-1] > 1: 860 | combined_attention_mask = _make_causal_mask( 861 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 862 | ).to(self.device) 863 | 864 | if attention_mask is not None: 865 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 866 | expanded_attn_mask = 
_expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 867 | combined_attention_mask = ( 868 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 869 | ) 870 | 871 | return combined_attention_mask 872 | 873 | def forward( 874 | self, 875 | input_ids=None, 876 | attention_mask=None, 877 | encoder_hidden_states=None, 878 | encoder_attention_mask=None, 879 | head_mask=None, 880 | encoder_head_mask=None, 881 | past_key_values=None, 882 | inputs_embeds=None, 883 | use_cache=None, 884 | output_attentions=None, 885 | output_hidden_states=None, 886 | return_dict=None, 887 | ): 888 | r""" 889 | Args: 890 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 891 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 892 | provide it. 893 | 894 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 895 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` 896 | for details. 897 | 898 | `What are input IDs? <../glossary.html#input-ids>`__ 899 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 900 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 901 | 902 | - 1 for tokens that are **not masked**, 903 | - 0 for tokens that are **masked**. 904 | 905 | `What are attention masks? <../glossary.html#attention-mask>`__ 906 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`): 907 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 908 | of the decoder. 909 | encoder_attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, encoder_sequence_length)`, `optional`): 910 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 911 | selected in ``[0, 1]``: 912 | 913 | - 1 for tokens that are **not masked**, 914 | - 0 for tokens that are **masked**. 915 | 916 | `What are attention masks? <../glossary.html#attention-mask>`__ 917 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 918 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: 919 | 920 | - 1 indicates the head is **not masked**, 921 | - 0 indicates the heas is **masked**. 922 | 923 | encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 924 | Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention 925 | on hidden heads. Mask values selected in ``[0, 1]``: 926 | 927 | - 1 indicates the head is **not masked**, 928 | - 0 indicates the heas is **masked**. 929 | 930 | past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 931 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up 932 | decoding. 
933 | 934 | If :obj:`past_key_values` are used, the user can optionally input only the last 935 | :obj:`decoder_input_ids` (those that don't have their past key value states given to this model) of 936 | shape :obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids`` of shape :obj:`(batch_size, 937 | sequence_length)`. 938 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 939 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded 940 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices 941 | into associated vectors than the model's internal embedding lookup matrix. 942 | output_attentions (:obj:`bool`, `optional`): 943 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 944 | returned tensors for more detail. 945 | output_hidden_states (:obj:`bool`, `optional`): 946 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 947 | for more detail. 948 | return_dict (:obj:`bool`, `optional`): 949 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 950 | """ 951 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 952 | output_hidden_states = ( 953 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 954 | ) 955 | use_cache = use_cache if use_cache is not None else self.config.use_cache 956 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 957 | 958 | # retrieve input_ids and inputs_embeds 959 | if input_ids is not None and inputs_embeds is not None: 960 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 961 | elif input_ids is not None: 962 | input_shape = input_ids.size() 963 | input_ids = input_ids.view(-1, input_shape[-1]) 964 | elif inputs_embeds is not None: 965 | input_shape = inputs_embeds.size()[:-1] 966 | else: 967 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 968 | 969 | # past_key_values_length 970 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 971 | 972 | if inputs_embeds is None: 973 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 974 | 975 | attention_mask = self._prepare_decoder_attention_mask( 976 | attention_mask, input_shape, inputs_embeds, past_key_values_length 977 | ) 978 | 979 | # expand encoder attention mask 980 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 981 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 982 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 983 | 984 | # embed positions 985 | positions = self.embed_positions(input_shape, past_key_values_length) 986 | 987 | hidden_states = inputs_embeds + positions 988 | hidden_states = self.layernorm_embedding(hidden_states) 989 | 990 | hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training) 991 | 992 | # decoder layers 993 | all_hidden_states = () if output_hidden_states else None 994 | all_self_attns = () if output_attentions else None 995 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 996 | next_decoder_cache = () if use_cache else None 
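# How the cache collected below is used (a sketch of the mechanism, not project-specific):
# with use_cache=True each decoder layer returns a 4-tuple of (self-attn key, self-attn value,
# cross-attn key, cross-attn value). During generation, prepare_inputs_for_generation (further
# down in this file) feeds only the last decoder token once `past` exists, so the self-attention
# keys/values grow by one time step per decoding step while the cross-attention states stay fixed
# at the encoder sequence length. Assuming a bart-large-sized config (bsz=2, 16 heads, head_dim=64)
# after t generated tokens:
#
#     past_key_values[layer][0].shape == (2, 16, t, 64)        # self-attn keys accumulated so far
#     past_key_values[layer][2].shape == (2, 16, src_len, 64)  # cross-attn keys, computed once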
997 | 998 | # check if head_mask has a correct number of layers specified if desired 999 | if head_mask is not None: 1000 | assert head_mask.size()[0] == ( 1001 | len(self.layers) 1002 | ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 1003 | for idx, decoder_layer in enumerate(self.layers): 1004 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 1005 | if output_hidden_states: 1006 | all_hidden_states += (hidden_states,) 1007 | dropout_probability = random.uniform(0, 1) 1008 | if self.training and (dropout_probability < self.layerdrop): 1009 | continue 1010 | 1011 | past_key_value = past_key_values[idx] if past_key_values is not None else None 1012 | 1013 | if getattr(self.config, "gradient_checkpointing", False) and self.training: 1014 | 1015 | if use_cache: 1016 | logger.warn( 1017 | "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " 1018 | "`use_cache=False`..." 1019 | ) 1020 | use_cache = False 1021 | 1022 | def create_custom_forward(module): 1023 | def custom_forward(*inputs): 1024 | # None for past_key_value 1025 | return module(*inputs, output_attentions, use_cache) 1026 | 1027 | return custom_forward 1028 | 1029 | layer_outputs = torch.utils.checkpoint.checkpoint( 1030 | create_custom_forward(decoder_layer), 1031 | hidden_states, 1032 | attention_mask, 1033 | encoder_hidden_states, 1034 | encoder_attention_mask, 1035 | head_mask[idx] if head_mask is not None else None, 1036 | encoder_head_mask[idx] if encoder_head_mask is not None else None, 1037 | None, 1038 | ) 1039 | else: 1040 | 1041 | layer_outputs = decoder_layer( 1042 | hidden_states, 1043 | attention_mask=attention_mask, 1044 | encoder_hidden_states=encoder_hidden_states, 1045 | encoder_attention_mask=encoder_attention_mask, 1046 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 1047 | encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None), 1048 | past_key_value=past_key_value, 1049 | output_attentions=output_attentions, 1050 | use_cache=use_cache, 1051 | ) 1052 | hidden_states = layer_outputs[0] 1053 | 1054 | if use_cache: 1055 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 1056 | 1057 | if output_attentions: 1058 | all_self_attns += (layer_outputs[1],) 1059 | 1060 | if encoder_hidden_states is not None: 1061 | all_cross_attentions += (layer_outputs[2],) 1062 | 1063 | # add hidden states from the last decoder layer 1064 | if output_hidden_states: 1065 | all_hidden_states += (hidden_states,) 1066 | 1067 | next_cache = next_decoder_cache if use_cache else None 1068 | if not return_dict: 1069 | return tuple( 1070 | v 1071 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 1072 | if v is not None 1073 | ) 1074 | return BaseModelOutputWithPastAndCrossAttentions( 1075 | last_hidden_state=hidden_states, 1076 | past_key_values=next_cache, 1077 | hidden_states=all_hidden_states, 1078 | attentions=all_self_attns, 1079 | cross_attentions=all_cross_attentions, 1080 | ) 1081 | 1082 | 1083 | @add_start_docstrings( 1084 | "The bare BART Model outputting raw hidden-states without any specific head on top.", 1085 | BART_START_DOCSTRING, 1086 | ) 1087 | class BartModel(BartPretrainedModel): 1088 | def __init__(self, config: BartConfig): 1089 | super().__init__(config) 1090 | 1091 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 1092 | self.shared = nn.Embedding(vocab_size, 
config.d_model, padding_idx) 1093 | 1094 | self.encoder = BartEncoder(config, self.shared) 1095 | self.decoder = BartDecoder(config, self.shared) 1096 | 1097 | self.init_weights() 1098 | 1099 | def get_input_embeddings(self): 1100 | return self.shared 1101 | 1102 | def set_input_embeddings(self, value): 1103 | self.shared = value 1104 | self.encoder.embed_tokens = self.shared 1105 | self.decoder.embed_tokens = self.shared 1106 | 1107 | def get_encoder(self): 1108 | return self.encoder 1109 | 1110 | def get_decoder(self): 1111 | return self.decoder 1112 | 1113 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1114 | @add_code_sample_docstrings( 1115 | tokenizer_class=_TOKENIZER_FOR_DOC, 1116 | checkpoint="facebook/bart-large", 1117 | output_type=Seq2SeqModelOutput, 1118 | config_class=_CONFIG_FOR_DOC, 1119 | ) 1120 | def forward( 1121 | self, 1122 | input_ids=None, 1123 | attention_mask=None, 1124 | passage_ids=None, 1125 | decoder_input_ids=None, 1126 | decoder_attention_mask=None, 1127 | head_mask=None, 1128 | decoder_head_mask=None, 1129 | encoder_outputs=None, 1130 | past_key_values=None, 1131 | inputs_embeds=None, 1132 | decoder_inputs_embeds=None, 1133 | use_cache=None, 1134 | output_attentions=None, 1135 | output_hidden_states=None, 1136 | return_dict=None, 1137 | ): 1138 | 1139 | # different to other models, Bart automatically creates decoder_input_ids from 1140 | # input_ids if no decoder_input_ids are provided 1141 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1142 | decoder_input_ids = shift_tokens_right( 1143 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 1144 | ) 1145 | 1146 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1147 | output_hidden_states = ( 1148 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1149 | ) 1150 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1151 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1152 | 1153 | if encoder_outputs is None: 1154 | encoder_outputs = self.encoder( 1155 | input_ids=input_ids, 1156 | attention_mask=attention_mask, 1157 | passage_ids=passage_ids, 1158 | head_mask=head_mask, 1159 | inputs_embeds=inputs_embeds, 1160 | output_attentions=output_attentions, 1161 | output_hidden_states=output_hidden_states, 1162 | return_dict=return_dict, 1163 | ) 1164 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 1165 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 1166 | encoder_outputs = BaseModelOutput( 1167 | last_hidden_state=encoder_outputs[0], 1168 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 1169 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 1170 | ) 1171 | 1172 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 1173 | decoder_outputs = self.decoder( 1174 | input_ids=decoder_input_ids, 1175 | attention_mask=decoder_attention_mask, 1176 | encoder_hidden_states=encoder_outputs[0], 1177 | encoder_attention_mask=attention_mask, 1178 | head_mask=decoder_head_mask, 1179 | encoder_head_mask=head_mask, 1180 | past_key_values=past_key_values, 1181 | inputs_embeds=decoder_inputs_embeds, 1182 | use_cache=use_cache, 1183 | output_attentions=output_attentions, 1184 | output_hidden_states=output_hidden_states, 1185 | return_dict=return_dict, 1186 | ) 1187 
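# Because the encoder above runs only when encoder_outputs is None, its output can be computed
# once and reused across repeated decoder calls — this is what generate() relies on during beam
# search. A minimal usage sketch (variable names are illustrative; passage_ids is required here
# because this encoder adds passage embeddings):
#
#     >>> enc = model.get_encoder()
#     >>> enc_out = enc(input_ids=input_ids, attention_mask=attention_mask,
#     ...               passage_ids=passage_ids, return_dict=True)
#     >>> out = model(input_ids=None, encoder_outputs=enc_out, attention_mask=attention_mask,
#     ...             decoder_input_ids=decoder_input_ids, return_dict=True)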
| 1188 | if not return_dict: 1189 | return decoder_outputs + encoder_outputs 1190 | 1191 | return Seq2SeqModelOutput( 1192 | last_hidden_state=decoder_outputs.last_hidden_state, 1193 | past_key_values=decoder_outputs.past_key_values, 1194 | decoder_hidden_states=decoder_outputs.hidden_states, 1195 | decoder_attentions=decoder_outputs.attentions, 1196 | cross_attentions=decoder_outputs.cross_attentions, 1197 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 1198 | encoder_hidden_states=encoder_outputs.hidden_states, 1199 | encoder_attentions=encoder_outputs.attentions, 1200 | ) 1201 | 1202 | 1203 | @add_start_docstrings( 1204 | "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING 1205 | ) 1206 | class BartForConditionalGeneration(BartPretrainedModel): 1207 | base_model_prefix = "model" 1208 | _keys_to_ignore_on_load_missing = [ 1209 | r"final_logits_bias", 1210 | r"encoder\.version", 1211 | r"decoder\.version", 1212 | r"lm_head\.weight", 1213 | ] 1214 | 1215 | def __init__(self, config: BartConfig): 1216 | super().__init__(config) 1217 | self.model = BartModel(config) 1218 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 1219 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 1220 | 1221 | self.init_weights() 1222 | 1223 | def get_encoder(self): 1224 | return self.model.get_encoder() 1225 | 1226 | def get_decoder(self): 1227 | return self.model.get_decoder() 1228 | 1229 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 1230 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 1231 | self._resize_final_logits_bias(new_num_tokens) 1232 | return new_embeddings 1233 | 1234 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 1235 | old_num_tokens = self.final_logits_bias.shape[-1] 1236 | if new_num_tokens <= old_num_tokens: 1237 | new_bias = self.final_logits_bias[:, :new_num_tokens] 1238 | else: 1239 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 1240 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 1241 | self.register_buffer("final_logits_bias", new_bias) 1242 | 1243 | def get_output_embeddings(self): 1244 | return self.lm_head 1245 | 1246 | def set_output_embeddings(self, new_embeddings): 1247 | self.lm_head = new_embeddings 1248 | 1249 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1250 | @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 1251 | @add_end_docstrings(BART_GENERATION_EXAMPLE) 1252 | def forward( 1253 | self, 1254 | input_ids=None, 1255 | attention_mask=None, 1256 | passage_ids=None, 1257 | decoder_input_ids=None, 1258 | decoder_attention_mask=None, 1259 | head_mask=None, 1260 | decoder_head_mask=None, 1261 | encoder_outputs=None, 1262 | past_key_values=None, 1263 | inputs_embeds=None, 1264 | decoder_inputs_embeds=None, 1265 | labels=None, 1266 | use_cache=None, 1267 | output_attentions=None, 1268 | output_hidden_states=None, 1269 | return_dict=None, 1270 | ): 1271 | r""" 1272 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1273 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., 1274 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). 
Tokens with indices set to ``-100`` are ignored 1275 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``. 1276 | 1277 | Returns: 1278 | """ 1279 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1280 | 1281 | if labels is not None: 1282 | if decoder_input_ids is None: 1283 | decoder_input_ids = shift_tokens_right( 1284 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 1285 | ) 1286 | 1287 | outputs = self.model( 1288 | input_ids, 1289 | attention_mask=attention_mask, 1290 | passage_ids=passage_ids, 1291 | decoder_input_ids=decoder_input_ids, 1292 | encoder_outputs=encoder_outputs, 1293 | decoder_attention_mask=decoder_attention_mask, 1294 | head_mask=head_mask, 1295 | decoder_head_mask=decoder_head_mask, 1296 | past_key_values=past_key_values, 1297 | inputs_embeds=inputs_embeds, 1298 | decoder_inputs_embeds=decoder_inputs_embeds, 1299 | use_cache=use_cache, 1300 | output_attentions=output_attentions, 1301 | output_hidden_states=output_hidden_states, 1302 | return_dict=return_dict, 1303 | ) 1304 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 1305 | 1306 | masked_lm_loss = None 1307 | if labels is not None: 1308 | loss_fct = CrossEntropyLoss() 1309 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1310 | 1311 | if not return_dict: 1312 | output = (lm_logits,) + outputs[1:] 1313 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1314 | 1315 | return Seq2SeqLMOutput( 1316 | loss=masked_lm_loss, 1317 | logits=lm_logits, 1318 | past_key_values=outputs.past_key_values, 1319 | decoder_hidden_states=outputs.decoder_hidden_states, 1320 | decoder_attentions=outputs.decoder_attentions, 1321 | cross_attentions=outputs.cross_attentions, 1322 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1323 | encoder_hidden_states=outputs.encoder_hidden_states, 1324 | encoder_attentions=outputs.encoder_attentions, 1325 | ) 1326 | 1327 | def prepare_inputs_for_generation( 1328 | self, 1329 | decoder_input_ids, 1330 | past=None, 1331 | attention_mask=None, 1332 | head_mask=None, 1333 | use_cache=None, 1334 | encoder_outputs=None, 1335 | **kwargs 1336 | ): 1337 | # cut decoder_input_ids if past is used 1338 | if past is not None: 1339 | decoder_input_ids = decoder_input_ids[:, -1:] 1340 | 1341 | return { 1342 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 1343 | "encoder_outputs": encoder_outputs, 1344 | "past_key_values": past, 1345 | "decoder_input_ids": decoder_input_ids, 1346 | "attention_mask": attention_mask, 1347 | "head_mask": head_mask, 1348 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 1349 | } 1350 | 1351 | def adjust_logits_during_generation(self, logits, cur_len, max_length): 1352 | if cur_len == 1 and self.config.force_bos_token_to_be_generated: 1353 | self._force_token_id_to_be_generated(logits, self.config.bos_token_id) 1354 | elif cur_len == max_length - 1 and self.config.eos_token_id is not None: 1355 | self._force_token_id_to_be_generated(logits, self.config.eos_token_id) 1356 | return logits 1357 | 1358 | @staticmethod 1359 | def _force_token_id_to_be_generated(scores, token_id) -> None: 1360 | """force one of token_ids to be generated by setting prob of all other tokens to 0 (logprob=-float("inf"))""" 1361 | scores[:, [x for x in range(scores.shape[1]) if x != token_id]] = -float("inf") 1362 | 1363 | @staticmethod 1364 | def _reorder_cache(past, beam_idx): 1365 | reordered_past = () 1366 | for layer_past in past: 1367 | # cached cross_attention states don't have to be reordered -> they are always the same 1368 | reordered_past += ( 1369 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], 1370 | ) 1371 | return reordered_past 1372 | 1373 | 1374 | @add_start_docstrings( 1375 | """ 1376 | Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE 1377 | tasks. 1378 | """, 1379 | BART_START_DOCSTRING, 1380 | ) 1381 | class BartForSequenceClassification(BartPretrainedModel): 1382 | def __init__(self, config: BartConfig, **kwargs): 1383 | super().__init__(config, **kwargs) 1384 | self.model = BartModel(config) 1385 | self.classification_head = BartClassificationHead( 1386 | config.d_model, 1387 | config.d_model, 1388 | config.num_labels, 1389 | config.classifier_dropout, 1390 | ) 1391 | self.model._init_weights(self.classification_head.dense) 1392 | self.model._init_weights(self.classification_head.out_proj) 1393 | 1394 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1395 | @add_code_sample_docstrings( 1396 | tokenizer_class=_TOKENIZER_FOR_DOC, 1397 | checkpoint="facebook/bart-large", 1398 | output_type=Seq2SeqSequenceClassifierOutput, 1399 | config_class=_CONFIG_FOR_DOC, 1400 | ) 1401 | def forward( 1402 | self, 1403 | input_ids=None, 1404 | attention_mask=None, 1405 | decoder_input_ids=None, 1406 | decoder_attention_mask=None, 1407 | head_mask=None, 1408 | decoder_head_mask=None, 1409 | encoder_outputs=None, 1410 | inputs_embeds=None, 1411 | decoder_inputs_embeds=None, 1412 | labels=None, 1413 | use_cache=None, 1414 | output_attentions=None, 1415 | output_hidden_states=None, 1416 | return_dict=None, 1417 | ): 1418 | r""" 1419 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1420 | Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., 1421 | config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1422 |         """
1423 |         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1424 |         if labels is not None:
1425 |             use_cache = False
1426 | 
1427 |         if input_ids is None and inputs_embeds is not None:
1428 |             raise NotImplementedError(
1429 |                 f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1430 |             )
1431 | 
1432 |         outputs = self.model(
1433 |             input_ids,
1434 |             attention_mask=attention_mask,
1435 |             decoder_input_ids=decoder_input_ids,
1436 |             decoder_attention_mask=decoder_attention_mask,
1437 |             head_mask=head_mask,
1438 |             decoder_head_mask=decoder_head_mask,
1439 |             encoder_outputs=encoder_outputs,
1440 |             inputs_embeds=inputs_embeds,
1441 |             decoder_inputs_embeds=decoder_inputs_embeds,
1442 |             use_cache=use_cache,
1443 |             output_attentions=output_attentions,
1444 |             output_hidden_states=output_hidden_states,
1445 |             return_dict=return_dict,
1446 |         )
1447 |         hidden_states = outputs[0]  # last hidden state
1448 | 
1449 |         eos_mask = input_ids.eq(self.config.eos_token_id)
1450 | 
1451 |         if len(torch.unique(eos_mask.sum(1))) > 1:
1452 |             raise ValueError("All examples must have the same number of <eos> tokens.")
1453 |         sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1454 |             :, -1, :
1455 |         ]
1456 |         logits = self.classification_head(sentence_representation)
1457 | 
1458 |         loss = None
1459 |         if labels is not None:
1460 |             loss_fct = CrossEntropyLoss()
1461 |             loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1462 | 
1463 |         if not return_dict:
1464 |             output = (logits,) + outputs[1:]
1465 |             return ((loss,) + output) if loss is not None else output
1466 | 
1467 |         return Seq2SeqSequenceClassifierOutput(
1468 |             loss=loss,
1469 |             logits=logits,
1470 |             past_key_values=outputs.past_key_values,
1471 |             decoder_hidden_states=outputs.decoder_hidden_states,
1472 |             decoder_attentions=outputs.decoder_attentions,
1473 |             cross_attentions=outputs.cross_attentions,
1474 |             encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1475 |             encoder_hidden_states=outputs.encoder_hidden_states,
1476 |             encoder_attentions=outputs.encoder_attentions,
1477 |         )
1478 | 
1479 | 
1480 | @add_start_docstrings(
1481 |     """
1482 |     BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1483 |     layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1484 | """, 1485 | BART_START_DOCSTRING, 1486 | ) 1487 | class BartForQuestionAnswering(BartPretrainedModel): 1488 | def __init__(self, config): 1489 | super().__init__(config) 1490 | 1491 | config.num_labels = 2 1492 | self.num_labels = config.num_labels 1493 | 1494 | self.model = BartModel(config) 1495 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 1496 | 1497 | self.model._init_weights(self.qa_outputs) 1498 | 1499 | @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) 1500 | @add_code_sample_docstrings( 1501 | tokenizer_class=_TOKENIZER_FOR_DOC, 1502 | checkpoint="facebook/bart-large", 1503 | output_type=Seq2SeqQuestionAnsweringModelOutput, 1504 | config_class=_CONFIG_FOR_DOC, 1505 | ) 1506 | def forward( 1507 | self, 1508 | input_ids=None, 1509 | attention_mask=None, 1510 | decoder_input_ids=None, 1511 | decoder_attention_mask=None, 1512 | head_mask=None, 1513 | decoder_head_mask=None, 1514 | encoder_outputs=None, 1515 | start_positions=None, 1516 | end_positions=None, 1517 | inputs_embeds=None, 1518 | decoder_inputs_embeds=None, 1519 | use_cache=None, 1520 | output_attentions=None, 1521 | output_hidden_states=None, 1522 | return_dict=None, 1523 | ): 1524 | r""" 1525 | start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1526 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1527 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 1528 | are not taken into account for computing the loss. 1529 | end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): 1530 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1531 | Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence 1532 | are not taken into account for computing the loss. 
1533 | """ 1534 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1535 | if start_positions is not None and end_positions is not None: 1536 | use_cache = False 1537 | 1538 | outputs = self.model( 1539 | input_ids, 1540 | attention_mask=attention_mask, 1541 | decoder_input_ids=decoder_input_ids, 1542 | decoder_attention_mask=decoder_attention_mask, 1543 | head_mask=head_mask, 1544 | decoder_head_mask=decoder_head_mask, 1545 | encoder_outputs=encoder_outputs, 1546 | inputs_embeds=inputs_embeds, 1547 | decoder_inputs_embeds=decoder_inputs_embeds, 1548 | use_cache=use_cache, 1549 | output_attentions=output_attentions, 1550 | output_hidden_states=output_hidden_states, 1551 | return_dict=return_dict, 1552 | ) 1553 | 1554 | sequence_output = outputs[0] 1555 | 1556 | logits = self.qa_outputs(sequence_output) 1557 | start_logits, end_logits = logits.split(1, dim=-1) 1558 | start_logits = start_logits.squeeze(-1) 1559 | end_logits = end_logits.squeeze(-1) 1560 | 1561 | total_loss = None 1562 | if start_positions is not None and end_positions is not None: 1563 | # If we are on multi-GPU, split add a dimension 1564 | if len(start_positions.size()) > 1: 1565 | start_positions = start_positions.squeeze(-1) 1566 | if len(end_positions.size()) > 1: 1567 | end_positions = end_positions.squeeze(-1) 1568 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1569 | ignored_index = start_logits.size(1) 1570 | start_positions.clamp_(0, ignored_index) 1571 | end_positions.clamp_(0, ignored_index) 1572 | 1573 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1574 | start_loss = loss_fct(start_logits, start_positions) 1575 | end_loss = loss_fct(end_logits, end_positions) 1576 | total_loss = (start_loss + end_loss) / 2 1577 | 1578 | if not return_dict: 1579 | output = ( 1580 | start_logits, 1581 | end_logits, 1582 | ) + outputs[1:] 1583 | return ((total_loss,) + output) if total_loss is not None else output 1584 | 1585 | return Seq2SeqQuestionAnsweringModelOutput( 1586 | loss=total_loss, 1587 | start_logits=start_logits, 1588 | end_logits=end_logits, 1589 | past_key_values=outputs.past_key_values, 1590 | decoder_hidden_states=outputs.decoder_hidden_states, 1591 | decoder_attentions=outputs.decoder_attentions, 1592 | cross_attentions=outputs.cross_attentions, 1593 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1594 | encoder_hidden_states=outputs.encoder_hidden_states, 1595 | encoder_attentions=outputs.encoder_attentions, 1596 | ) 1597 | 1598 | 1599 | class BartDecoderWrapper(BartPretrainedModel): 1600 | """ 1601 | This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is 1602 | used in combination with the :class:`~transformers.EncoderDecoderModel` framework. 
1603 | """ 1604 | 1605 | def __init__(self, config): 1606 | super().__init__(config) 1607 | self.decoder = BartDecoder(config) 1608 | 1609 | def forward(self, *args, **kwargs): 1610 | return self.decoder(*args, **kwargs) 1611 | 1612 | 1613 | class BartForCausalLM(BartPretrainedModel): 1614 | def __init__(self, config): 1615 | super().__init__(config) 1616 | config = copy.deepcopy(config) 1617 | config.is_decoder = True 1618 | config.is_encoder_decoder = False 1619 | self.model = BartDecoderWrapper(config) 1620 | 1621 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1622 | 1623 | self.init_weights() 1624 | 1625 | def get_input_embeddings(self): 1626 | return self.model.decoder.embed_tokens 1627 | 1628 | def set_input_embeddings(self, value): 1629 | self.model.decoder.embed_tokens = value 1630 | 1631 | def get_output_embeddings(self): 1632 | return self.lm_head 1633 | 1634 | def set_output_embeddings(self, new_embeddings): 1635 | self.lm_head = new_embeddings 1636 | 1637 | def set_decoder(self, decoder): 1638 | self.model.decoder = decoder 1639 | 1640 | def get_decoder(self): 1641 | return self.model.decoder 1642 | 1643 | @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) 1644 | def forward( 1645 | self, 1646 | input_ids=None, 1647 | attention_mask=None, 1648 | encoder_hidden_states=None, 1649 | encoder_attention_mask=None, 1650 | head_mask=None, 1651 | encoder_head_mask=None, 1652 | past_key_values=None, 1653 | inputs_embeds=None, 1654 | labels=None, 1655 | use_cache=None, 1656 | output_attentions=None, 1657 | output_hidden_states=None, 1658 | return_dict=None, 1659 | ): 1660 | r""" 1661 | Args: 1662 | input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`): 1663 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 1664 | provide it. 1665 | 1666 | Indices can be obtained using :class:`~transformers.BartTokenizer`. See 1667 | :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` 1668 | for details. 1669 | 1670 | `What are input IDs? <../glossary.html#input-ids>`__ 1671 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1672 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 1673 | 1674 | - 1 for tokens that are **not masked**, 1675 | - 0 for tokens that are **masked**. 1676 | 1677 | `What are attention masks? <../glossary.html#attention-mask>`__ 1678 | encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): 1679 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 1680 | if the model is configured as a decoder. 1681 | encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1682 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used 1683 | in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: 1684 | head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 1685 | Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``: 1686 | 1687 | - 1 indicates the head is **not masked**, 1688 | - 0 indicates the heas is **masked**. 
1689 | 1690 | encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`): 1691 | Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention 1692 | on hidden heads. Mask values selected in ``[0, 1]``: 1693 | 1694 | - 1 indicates the head is **not masked**, 1695 | - 0 indicates the heas is **masked**. 1696 | 1697 | past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 1698 | Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up 1699 | decoding. 1700 | 1701 | If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids`` 1702 | (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` 1703 | instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`. 1704 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1705 | Labels for computing the masked language modeling loss. Indices should either be in ``[0, ..., 1706 | config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are 1707 | ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., 1708 | config.vocab_size]``. 1709 | use_cache (:obj:`bool`, `optional`): 1710 | If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up 1711 | decoding (see :obj:`past_key_values`). 1712 | 1713 | - 1 for tokens that are **not masked**, 1714 | - 0 for tokens that are **masked**. 1715 | output_attentions (:obj:`bool`, `optional`): 1716 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 1717 | returned tensors for more detail. 1718 | output_hidden_states (:obj:`bool`, `optional`): 1719 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 1720 | for more detail. 1721 | return_dict (:obj:`bool`, `optional`): 1722 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 1723 | 1724 | Returns: 1725 | 1726 | Example:: 1727 | 1728 | >>> from transformers import BartTokenizer, BartForCausalLM 1729 | 1730 | >>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large') 1731 | >>> model = BartForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False) 1732 | >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." 
1733 |             >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1734 |             >>> outputs = model(**inputs)
1735 | 
1736 |             >>> logits = outputs.logits
1737 |         """
1738 | 
1739 |         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1740 |         output_hidden_states = (
1741 |             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1742 |         )
1743 |         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1744 | 
1745 |         # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1746 |         outputs = self.model.decoder(
1747 |             input_ids=input_ids,
1748 |             attention_mask=attention_mask,
1749 |             encoder_hidden_states=encoder_hidden_states,
1750 |             encoder_attention_mask=encoder_attention_mask,
1751 |             head_mask=head_mask,
1752 |             encoder_head_mask=encoder_head_mask,
1753 |             past_key_values=past_key_values,
1754 |             inputs_embeds=inputs_embeds,
1755 |             use_cache=use_cache,
1756 |             output_attentions=output_attentions,
1757 |             output_hidden_states=output_hidden_states,
1758 |             return_dict=return_dict,
1759 |         )
1760 | 
1761 |         logits = self.lm_head(outputs[0])
1762 | 
1763 |         loss = None
1764 |         if labels is not None:
1765 |             loss_fct = CrossEntropyLoss()
1766 |             loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
1767 | 
1768 |         if not return_dict:
1769 |             output = (logits,) + outputs[1:]
1770 |             return (loss,) + output if loss is not None else output
1771 | 
1772 |         return CausalLMOutputWithCrossAttentions(
1773 |             loss=loss,
1774 |             logits=logits,
1775 |             past_key_values=outputs.past_key_values,
1776 |             hidden_states=outputs.hidden_states,
1777 |             attentions=outputs.attentions,
1778 |             cross_attentions=outputs.cross_attentions,
1779 |         )
1780 | 
1781 |     def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
1782 |         # if the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly
1783 |         if attention_mask is None:
1784 |             attention_mask = input_ids.new_ones(input_ids.shape)
1785 | 
1786 |         if past:
1787 |             input_ids = input_ids[:, -1:]
1788 |         # on the first generation step, past_key_values is empty
1789 |         return {
1790 |             "input_ids": input_ids,  # only the last token is needed here once past_key_values are used (trimmed above)
1791 |             "attention_mask": attention_mask,
1792 |             "past_key_values": past,
1793 |             "use_cache": use_cache,
1794 |         }
1795 | 
1796 |     @staticmethod
1797 |     def _reorder_cache(past, beam_idx):
1798 |         reordered_past = ()
1799 |         for layer_past in past:
1800 |             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1801 |         return reordered_past
1802 | 
--------------------------------------------------------------------------------
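Usage sketch (illustrative only): the snippet below shows one way the customized `BartForConditionalGeneration` defined above could be called, based solely on the `forward` signature in `custom_bart.py`. The checkpoint name, the import path, and the interpretation of `passage_ids` as a token-level indicator aligned with `input_ids` are assumptions made for illustration.

# Minimal sketch, assuming: (1) this file is importable as `custom_bart` (hypothetical path),
# (2) a BART checkpoint such as "facebook/bart-base" (placeholder only), and
# (3) `passage_ids` is a LongTensor with the same shape as `input_ids` marking passage tokens.
import torch
from transformers import BartTokenizer

from custom_bart import BartForConditionalGeneration  # hypothetical import path

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

# Encoder input: an answer span followed by its passage; target: the question to generate.
enc = tokenizer("answer </s> some passage sentence containing the answer", return_tensors="pt")
labels = tokenizer("what is the answer?", return_tensors="pt").input_ids

# Assumed semantics: 0 for answer tokens, 1 for passage tokens.
passage_ids = torch.ones_like(enc.input_ids)
passage_ids[:, :3] = 0

outputs = model(
    input_ids=enc.input_ids,
    attention_mask=enc.attention_mask,
    passage_ids=passage_ids,
    labels=labels,  # decoder_input_ids are derived internally via shift_tokens_right
)
print(outputs.loss, outputs.logits.shape)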