├── .gitignore
├── filtered_dict.pkl
├── LICENSE.md
├── qfid
│   ├── __pycache__
│   │   ├── qfid.cpython-38.pyc
│   │   └── qfid.cpython-39.pyc
│   ├── test.sh
│   ├── train.sh
│   ├── run_summarization.py
│   └── qfid.py
├── requirements.txt
├── make_summarization_csv.py
├── json_to_df.py
├── README.md
└── make_section_df.py

/.gitignore:
--------------------------------------------------------------------------------
1 | dataset
2 | 
--------------------------------------------------------------------------------
/filtered_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetsu9923/SciReviewGen/HEAD/filtered_dict.pkl
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | SciReviewGen is licensed under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).
2 | 
--------------------------------------------------------------------------------
/qfid/__pycache__/qfid.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetsu9923/SciReviewGen/HEAD/qfid/__pycache__/qfid.cpython-38.pyc
--------------------------------------------------------------------------------
/qfid/__pycache__/qfid.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetsu9923/SciReviewGen/HEAD/qfid/__pycache__/qfid.cpython-39.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.16.0
2 | datasets==1.18.4
3 | numpy==1.22.3
4 | pandas==1.4.1
5 | torch==1.11.0
6 | filelock
7 | nltk
8 | tqdm
9 | wrapt
10 | h5py
11 | rouge_score
--------------------------------------------------------------------------------
/qfid/test.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | python run_summarization.py \
3 | --model_name_or_path facebook/bart-large-cnn \
4 | --do_train \
5 | --do_predict \
6 | --train_file ../dataset/train_qfid.csv \
7 | --validation_file ../dataset/val_qfid.csv \
8 | --test_file ../dataset/test_qfid.csv \
9 | --text_column reference \
10 | --summary_column target \
11 | --output_dir ./test \
12 | --per_device_train_batch_size=1 \
13 | --per_device_eval_batch_size=1 \
14 | --num_train_epochs=1 \
15 | --predict_with_generate \
16 | --save_strategy epoch \
17 | --learning_rate 5e-05 \
18 | --evaluation_strategy epoch \
19 | --load_best_model_at_end \
20 | --metric_for_best_model=eval_rouge2 \
21 | --greater_is_better True \
22 | --save_total_limit=1 \
23 | --max_source_length 2048 \
24 | --max_target_length 512 \
25 | --generation_max_length 256 \
26 | --num_beams 4 \
27 | 
--------------------------------------------------------------------------------
/qfid/train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | python run_summarization.py \
3 | --model_name_or_path facebook/bart-large-cnn \
4 | --do_train \
5 | --do_eval \
6 | --do_predict \
7 | --train_file ../dataset/train_qfid.csv \
8 | --validation_file ../dataset/val_qfid.csv \
9 | --test_file ../dataset/test_qfid.csv \
10 | --text_column reference \
11 | --summary_column target \
12 | --output_dir ./test \
13 | --per_device_train_batch_size=1 \
14 | 
--per_device_eval_batch_size=1 \ 15 | --num_train_epochs=1 \ 16 | --predict_with_generate \ 17 | --save_strategy epoch \ 18 | --learning_rate 5e-05 \ 19 | --evaluation_strategy epoch \ 20 | --load_best_model_at_end \ 21 | --metric_for_best_model=eval_rouge2 \ 22 | --greater_is_better True \ 23 | --save_total_limit=1 \ 24 | --max_source_length 2048 \ 25 | --max_target_length 512 \ 26 | --generation_max_length 256 \ 27 | --num_beams 4 \ 28 | --overwrite_output_dir \ 29 | -------------------------------------------------------------------------------- /make_summarization_csv.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import logging 4 | import pickle 5 | import argparse 6 | 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from tqdm import tqdm 11 | 12 | 13 | def make_summarization_csv(args): 14 | if args.for_qfid: 15 | logging.info('Making csv files for QFiD...') 16 | logging.info('Columns={"reference": literature review title chapter title literature review title chapter title abstract of cited paper 1 BIB001 literature review title chapter title abstract of cited paper 2 BIB002 ..., "target": literature review chapter}') 17 | else: 18 | logging.info('Making csv files for summarization...') 19 | logging.info('Columns={"reference": literature review title chapter title abstract of cited paper 1 BIB001 literature review title chapter title abstract of cited paper 2 BIB002 ..., "target": literature review chapter}') 20 | section_df = pd.read_pickle(os.path.join(args.dataset_path, 'split_survey_df.pkl')) 21 | 22 | dataset_df = section_df[section_df['n_bibs'].apply(lambda n_bibs: n_bibs >= 2)] 23 | 24 | dataset_df = dataset_df.rename(columns={'text': 'target'}) 25 | dataset_df = dataset_df.rename(columns={'bib_cinting_sentences': 'bib_citing_sentences'}) 26 | 27 | dataset_df['reference'] = dataset_df[['bib_abstracts', 'section', 'title']].apply(lambda bib_abstracts: ' '.join([' {} {} {} BIB{}'.format(bib_abstracts[2], bib_abstracts[1], abstract, bib) for bib, abstract in bib_abstracts[0].items()]), axis=1) 28 | if args.for_qfid: 29 | dataset_df['reference'] = dataset_df['title'] + ' ' + dataset_df['section'] + ' ' + dataset_df['reference'] 30 | else: 31 | dataset_df['reference'] = dataset_df['reference'].apply(lambda s: s[5:]) 32 | 33 | split_df = dataset_df['split'] 34 | dataset_df = dataset_df[['reference', 'target']] 35 | 36 | train_df = dataset_df[split_df == 'train'] 37 | val_df = dataset_df[split_df == 'val'] 38 | test_df = dataset_df[split_df == 'test'] 39 | 40 | if args.for_qfid: 41 | train_df.to_csv(os.path.join(args.dataset_path, 'train_qfid.csv'), index=False) 42 | val_df.to_csv(os.path.join(args.dataset_path, 'val_qfid.csv'), index=False) 43 | test_df.to_csv(os.path.join(args.dataset_path, 'test_qfid.csv'), index=False) 44 | else: 45 | train_df.to_csv(os.path.join(args.dataset_path, 'train.csv'), index=False) 46 | val_df.to_csv(os.path.join(args.dataset_path, 'val.csv'), index=False) 47 | test_df.to_csv(os.path.join(args.dataset_path, 'test.csv'), index=False) 48 | logging.info('Done!') 49 | 50 | 51 | def anonymize_bib(args): 52 | logging.info('Converting BIB identifiers...') 53 | for split in ['val', 'test', 'train']: 54 | if args.for_qfid: 55 | df = pd.read_csv(os.path.join(args.dataset_path, '{}_qfid.csv'.format(split))) 56 | else: 57 | df = pd.read_csv(os.path.join(args.dataset_path, '{}.csv'.format(split))) 58 | bar = tqdm(total=len(df)) 59 | for row in df.itertuples(): 60 | cnt = 1 61 | bib_dict = {} 
62 | for i in range(len(row.reference)): 63 | if row.reference[i:i+7] == ' BIB': 64 | bib_dict[row.reference[i+7:].split(' ')[0]] = cnt 65 | cnt += 1 66 | ref = row.reference 67 | tgt = row.target 68 | for key, value in bib_dict.items(): 69 | ref = re.sub('BIB{}'.format(key), 'BIB{:0>3}'.format(value), ref) 70 | tgt = re.sub('BIB{}'.format(key), 'BIB{:0>3}'.format(value), tgt) 71 | df.at[row.Index, 'reference'] = ref 72 | df.at[row.Index, 'target'] = tgt 73 | bar.update(1) 74 | logging.info('Saving...') 75 | if args.for_qfid: 76 | df.to_csv(os.path.join(args.dataset_path, '{}_qfid.csv'.format(split)), index=False) 77 | else: 78 | df.to_csv(os.path.join(args.dataset_path, '{}.csv'.format(split)), index=False) 79 | 80 | 81 | if __name__ == '__main__': 82 | logging.basicConfig(format='%(message)s', level=logging.DEBUG) 83 | 84 | parser = argparse.ArgumentParser(description='') 85 | parser.add_argument('-dataset_path', help='Path to the generated dataset') 86 | parser.add_argument('--for_qfid', action='store_true', help='Add if you train QFiD on the generated csv files') 87 | args = parser.parse_args() 88 | 89 | make_summarization_csv(args) # Convert split_survey_df into csv files suitable for summarization 90 | anonymize_bib(args) # Converting BIB{paper_id} into BIB{001, 002, ...} 91 | -------------------------------------------------------------------------------- /json_to_df.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import pickle 4 | import argparse 5 | 6 | import pandas as pd 7 | 8 | from tqdm import tqdm 9 | 10 | 11 | def save_metadata_survey_df(args): 12 | logging.info('Loading metadata of literature reviews...') 13 | keywords = ['survey', 'overview', 'literature review', 'a review'] 14 | 15 | metadata_survey_dfs = [] 16 | for i in tqdm(range(100)): 17 | metadata_df = pd.read_json(os.path.join(args.s2orc_path, 'metadata/metadata_{}.jsonl.gz'.format(i)), lines=True, compression='infer') 18 | metadata_survey_df = metadata_df[ 19 | (metadata_df.mag_field_of_study.apply(lambda field: args.field in field if field is not None else False)) 20 | & (metadata_df.title.apply(lambda title: any([word in title.lower() for word in keywords]))) 21 | & metadata_df.has_outbound_citations 22 | & metadata_df.has_pdf_body_text 23 | & ~metadata_df.abstract.isna() 24 | ] 25 | metadata_survey_dfs.append(metadata_survey_df) 26 | metadata_survey_df = pd.concat(metadata_survey_dfs) 27 | metadata_survey_df = metadata_survey_df.set_index('paper_id') 28 | metadata_survey_df.index = metadata_survey_df.index.astype('str') 29 | metadata_survey_df.to_pickle(os.path.join(args.dataset_path, 'metadata_survey_df.pkl')) 30 | logging.info('Done!') 31 | 32 | def save_metadata_outbound_df(args): 33 | logging.info('Loading metadata of cited papers...') 34 | metadata_survey_df = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_survey_df.pkl')) 35 | outbound_paper_ids = set([paper_id for paper_ids in metadata_survey_df.outbound_citations.values for paper_id in paper_ids]) 36 | 37 | metadata_outbound_dfs = [] 38 | for i in tqdm(range(100)): 39 | metadata_df = pd.read_json(os.path.join(args.s2orc_path, 'metadata/metadata_{}.jsonl.gz'.format(i)), lines=True, compression='infer') 40 | metadata_df = metadata_df.set_index('paper_id') 41 | metadata_df.index = metadata_df.index.astype('str') 42 | cs_survey_outbound_paper_ids = list(outbound_paper_ids & set(metadata_df.index)) 43 | metadata_outbound_df = metadata_df.loc[cs_survey_outbound_paper_ids] 44 | 
metadata_outbound_dfs.append(metadata_outbound_df) 45 | metadata_outbound_df = pd.concat(metadata_outbound_dfs) 46 | metadata_outbound_df.to_pickle(os.path.join(args.dataset_path, 'metadata_outbound_df.pkl')) 47 | logging.info('Done!') 48 | 49 | def save_pdf_df(args): 50 | logging.info('Loading pdf parses of literature review papers and cited papers...') 51 | metadata_survey_df = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_survey_df.pkl')) 52 | survey_index = metadata_survey_df.index 53 | outbound_index = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_outbound_df.pkl')).index 54 | 55 | pdf_survey_dfs = [] 56 | pdf_outbound_dfs = [] 57 | for i in tqdm(range(100)): 58 | pdf_df = pd.read_json(os.path.join(args.s2orc_path, 'pdf_parses/pdf_parses_{}.jsonl.gz'.format(i)), lines=True, compression='infer') 59 | pdf_df = pdf_df.set_index('paper_id') 60 | pdf_df.index = pdf_df.index.astype('str') 61 | 62 | pdf_survey_paper_ids = list(set(survey_index) & set(pdf_df.index)) 63 | pdf_survey_df = pdf_df.loc[pdf_survey_paper_ids] 64 | pdf_survey_df = pdf_survey_df[['body_text', 'bib_entries']] 65 | 66 | pdf_survey_df['title'] = '' 67 | pdf_survey_df['abstract'] = '' 68 | for i, row in enumerate(pdf_survey_df.itertuples()): 69 | pdf_survey_df.at[pdf_survey_df.index[i], 'title'] = metadata_survey_df.query('paper_id == @row.Index')['title'].item() 70 | pdf_survey_df.at[pdf_survey_df.index[i], 'abstract'] = metadata_survey_df.query('paper_id == @row.Index')['abstract'].item() 71 | pdf_survey_dfs.append(pdf_survey_df) 72 | 73 | pdf_outbound_paper_ids = list(set(outbound_index) & set(pdf_df.index)) 74 | pdf_outbound_df = pdf_df.loc[pdf_outbound_paper_ids] 75 | pdf_outbound_df = pdf_outbound_df[pdf_outbound_df.body_text.apply(lambda text: len(text) > 0)] 76 | pdf_outbound_df = pdf_outbound_df[['body_text', 'bib_entries']] 77 | pdf_outbound_dfs.append(pdf_outbound_df) 78 | 79 | pdf_survey_df = pd.concat(pdf_survey_dfs) 80 | pdf_survey_df.to_pickle(os.path.join(args.dataset_path, 'pdf_survey_df.pkl')) 81 | pdf_outbound_df = pd.concat(pdf_outbound_dfs) 82 | pdf_outbound_df.to_pickle(os.path.join(args.dataset_path, 'pdf_outbound_df.pkl')) 83 | logging.info('Done!') 84 | 85 | 86 | if __name__ == '__main__': 87 | logging.basicConfig(format='%(message)s', level=logging.DEBUG) 88 | 89 | parser = argparse.ArgumentParser(description='') 90 | parser.add_argument('-s2orc_path', help='Path to the S2ORC full dataset directory (Typically ".../s2orc/full/20200705v1/full")') 91 | parser.add_argument('-dataset_path', help='Path to the generated dataset') 92 | parser.add_argument('--field', default='Computer Science', help='Field of literature reviews') 93 | args = parser.parse_args() 94 | 95 | save_metadata_survey_df(args) # collect metadata of the literature reviews 96 | save_metadata_outbound_df(args) # collect metadata of the cited papers 97 | save_pdf_df(args) # collect pdf parses of the literature reviews and the cited papers -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SciReviewGen 2 | **This is the official dataset repository for [SciReviewGen: A Large-scale Dataset for Automatic Literature Review Generation](https://arxiv.org/pdf/2305.15186.pdf) in ACL findings 2023.** 3 | 4 | ## Dataset 5 | - [split_survey_df](https://drive.google.com/file/d/1S6v-xaCDND4ilK38sEpkfcOoMnffX7Zf/view?usp=sharing): The split version of SciReviewGen, which aims to generate literature 
review **chapters**
6 | - [original_survey_df](https://drive.google.com/file/d/1MnjQ2fQ_fJjcqKvIwj2w7P6IGh4GszXH/view?usp=sharing): The original version of SciReviewGen, which aims to generate **the entire text** of literature reviews
7 | - [summarization_csv](https://drive.google.com/file/d/1okvILkxfrpTQYWLxbV4lM9BQnuVaAfbY/view?usp=sharing): CSV files suitable for the summarization task. You can apply them to [HuggingFace's official sample code](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization#custom-csv-files)
8 | 
9 | ### Data format
10 | #### split_survey_df & original_survey_df
11 | - Row:
12 |   - literature review chapter or the entire text of a literature review
13 | - Column:
14 |   - paper_id: paper_id used in [S2ORC](https://github.com/allenai/s2orc)
15 |   - title: title of the literature review
16 |   - abstract: abstract of the literature review
17 |   - section: chapter title
18 |   - text: body text of the literature review chapter or literature review paper
19 |   - n_bibs: number of the cited papers that can be used as inputs
20 |   - n_nonbibs: number of the cited papers that cannot be used as inputs
21 |   - bib_titles: titles of the cited papers
22 |   - bib_abstracts: abstracts of the cited papers
23 |   - bib_citing_sentences: citing sentences that cite the cited papers
24 |   - split: train/val/test split
25 | 
26 | #### summarization_csv
27 | - Row:
28 |   - literature review chapter
29 | - Column:
30 |   - reference: `literature review title chapter title abstract of cited paper 1 BIB001 literature review title chapter title abstract of cited paper 2 BIB002 ...`
31 |   - target: literature review chapter
32 | 
33 | 
34 | ## How to create SciReviewGen from S2ORC
35 | ### 0. Environment
36 | - Python 3.9
37 | - Run the following command to clone the repository and install the required packages:
38 | ```
39 | git clone https://github.com/tetsu9923/SciReviewGen.git
40 | cd SciReviewGen
41 | pip install -r requirements.txt
42 | ```
43 | 
44 | ### 1. Preprocessing
45 | - Download [S2ORC](https://github.com/allenai/s2orc) (We use the version released on **2020-07-05**, which contains papers up until 2020-04-14)
46 | - Run the following command:
47 | ```
48 | python json_to_df.py \
49 | -s2orc_path <path to the S2ORC full dataset directory> \
50 | -dataset_path <path to the generated dataset> \
51 | --field <field of the literature reviews>
52 | ```
53 | The metadata and pdf parses of the candidates for the literature reviews and the cited papers are stored in *dataset_path* (in the form of pandas dataframes).
54 | 
55 | ### 2. Construct SciReviewGen
56 | - Run the following command:
57 | ```
58 | python make_section_df.py \
59 | -dataset_path <path to the generated dataset> \
60 | --version <"split" or "original">
61 | ```
62 | The SciReviewGen dataset (**split_survey_df.pkl** or **original_survey_df.pkl**) is stored in *dataset_path* (in the form of a pandas dataframe).
63 | `filtered_dict.pkl` gives the list of literature reviews after filtering by the [SciBERT](https://arxiv.org/abs/1903.10676)-based classifier (Section 3.2 of the paper).
64 | 
65 | ### 3. Construct csv data for summarization
66 | - Run the following command:
67 | ```
68 | python make_summarization_csv.py \
69 | -dataset_path <path to the generated dataset>
70 | ```
71 | The csv files for summarization (**train.csv**, **val.csv**, and **test.csv**) are stored in *dataset_path*.
72 | If you train QFiD on the generated csv files, add the `--for_qfid` argument as below.
73 | ```
74 | python make_summarization_csv.py \
75 | -dataset_path <path to the generated dataset> \
76 | --for_qfid
77 | ```
78 | 
79 | 
80 | ## Additional resources
81 | ### SciBERT-based literature review classifier
82 | We trained the [SciBERT](https://arxiv.org/abs/1903.10676)-based literature review classifier.
83 | The model weights are available [here](https://drive.google.com/file/d/1cPGJpvCFQkHX2td99YyFitBirG-eCcLC/view?usp=sharing).
84 | 
85 | ### Query-weighted Fusion-in-Decoder (QFiD)
86 | We proposed Query-weighted Fusion-in-Decoder (QFiD), which explicitly considers the relevance of each input document to the queries.
87 | You can train QFiD on the SciReviewGen csv data (**Make sure that you passed the** `--for_qfid` **argument when executing** `make_summarization_csv.py`).
88 | #### Train
89 | - Modify qfid/train.sh (CUDA_VISIBLE_DEVICES, csv file paths, output_dir, and num_train_epochs)
90 | - Run the following command:
91 | ```
92 | cd qfid
93 | ./train.sh
94 | ```
95 | #### Test
96 | - Modify qfid/test.sh (CUDA_VISIBLE_DEVICES, csv file paths, output_dir, and num_train_epochs. **Please set *num_train_epochs* to the total number of epochs you trained**)
97 | - Run the following command:
98 | ```
99 | ./test.sh
100 | ```
101 | 
102 | ## Licenses
103 | - SciReviewGen is released under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/). **You can use SciReviewGen only for non-commercial purposes.**
104 | - SciReviewGen is created based on [S2ORC](https://github.com/allenai/s2orc). Note that S2ORC is released under CC BY-NC 4.0, which allows users to copy and redistribute it only for non-commercial purposes.
105 | 
--------------------------------------------------------------------------------
/make_section_df.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import pickle
4 | import argparse
5 | 
6 | import nltk.data
7 | import numpy as np
8 | import pandas as pd
9 | 
10 | from tqdm import tqdm
11 | 
12 | 
13 | def append_citing_sentence(args):
14 |     extra_abbreviations = ['dr', 'vs', 'mr', 'mrs', 'prof', 'inc', 'i.e', 'al']
15 |     tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
16 |     tokenizer._params.abbrev_types.update(extra_abbreviations)
17 | 
18 |     logging.info('Loading citing sentences...')
19 |     cs_survey_ids = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_survey_df.pkl')).index
20 |     pdf_outbound_df = pd.read_pickle(os.path.join(args.dataset_path, 'pdf_outbound_df.pkl'))
21 |     metadata_outbound_df = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_outbound_df.pkl'))
22 |     metadata_outbound_df = metadata_outbound_df.dropna(subset=['abstract'])
23 | 
24 |     citing_df = pdf_outbound_df.copy()
25 |     citing_df['bib_entries'] = citing_df['bib_entries'].apply(
26 |         lambda bib_entries: {key: value['link'] for key, value in bib_entries.items()}
27 |     )
28 |     citing_df['bib_mark'] = citing_df['body_text'].apply(
29 |         lambda body_text: [cite['text'] for cite in sum([body['cite_spans'] for body in body_text], [])]
30 |     )
31 | 
32 |     citing_dict = {}
33 |     bar = tqdm(total=len(citing_df))
34 |     # citing_dict: {paper A: {paper B cited in A: 'citing sentence in A (explaining B)', paper C cited in A: 'citing sentence in A (explaining C)'}, paperD: ...}
35 |     for i, row in enumerate(citing_df.itertuples()):
36 |         tmp_dict = {}
37 |         for section in row.body_text:
38 |             section_text = section['text'].split()
39 |             for citing in section['cite_spans']:
40 |                 for sentence in tokenizer.tokenize(section['text']):
41 |                     if citing['text'] in
sentence: 42 | text = sentence 43 | if citing['ref_id'] in row.bib_entries.keys(): 44 | if row.bib_entries[citing['ref_id']] != None: 45 | tmp_dict[row.bib_entries[citing['ref_id']]] = text 46 | break 47 | citing_dict[citing_df.index[i]] = tmp_dict 48 | bar.update(1) 49 | logging.info('Done!') 50 | 51 | logging.info('Appending citing sentences to metadata_outbound_df...') 52 | citing_sentence_list = [] 53 | bar = tqdm(total=len(metadata_outbound_df)) 54 | for index, citing_papers in metadata_outbound_df['inbound_citations'].iteritems(): 55 | citing_sentence = [] 56 | n_of_citing = 0 57 | for citing_paper in citing_papers: 58 | if citing_paper in citing_dict.keys() and citing_paper not in cs_survey_ids: 59 | if index in citing_dict[citing_paper].keys(): 60 | citing_sentence.append(citing_dict[citing_paper][index]) 61 | n_of_citing += 1 62 | citing_sentence_list.append(citing_sentence) 63 | bar.update(1) 64 | 65 | metadata_outbound_df['citing_sentence'] = '' 66 | bar = tqdm(total=len(metadata_outbound_df)) 67 | for i, row in enumerate(metadata_outbound_df.itertuples()): 68 | append_sentence = citing_sentence_list[i] 69 | metadata_outbound_df.at[metadata_outbound_df.index[i], 'citing_sentence'] = append_sentence 70 | bar.update(1) 71 | 72 | metadata_outbound_df.to_pickle(os.path.join(args.dataset_path, 'metadata_outbound_citation_df.pkl')) 73 | logging.info('Done!') 74 | 75 | 76 | def make_scireviewgen(args): 77 | logging.info('Making section_df...') 78 | metadata_outbound_df = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_outbound_citation_df.pkl')) 79 | metadata_outbound_df = metadata_outbound_df.dropna(subset=['abstract']) 80 | pdf_survey_df = pd.read_pickle(os.path.join(args.dataset_path, 'pdf_survey_df.pkl')) 81 | pdf_survey_df = pdf_survey_df[pdf_survey_df['abstract'].apply(lambda s: type(s) == str)] 82 | 83 | def get_section_df(row): 84 | sections_duplicate = [paragraph['section'] for paragraph in row.body_text] 85 | sections = sorted(set(sections_duplicate), key=sections_duplicate.index) 86 | bib_df = pd.DataFrame.from_dict(row.bib_entries, orient='index') 87 | bib_df = bib_df[bib_df.link.apply(lambda paper_id: paper_id in metadata_outbound_df.index)] 88 | bib_dict = bib_df.link.dropna().to_dict() 89 | 90 | def replace_cite(body_row): 91 | body_text = '' 92 | start_index = 0 93 | for cite_span in body_row.cite_spans: 94 | end_index = cite_span['start'] 95 | ref_id = cite_span['ref_id'] 96 | body_text += body_row.text_raw[start_index:end_index] 97 | body_text += 'BIB{}'.format(bib_dict[ref_id]) if ref_id in bib_dict else '' 98 | start_index = cite_span['end'] 99 | body_text += body_row.text_raw[start_index:] 100 | return body_text 101 | 102 | body_df = pd.DataFrame(row.body_text).rename(columns={'text': 'text_raw'}) 103 | body_df['text'] = body_df[['text_raw', 'cite_spans']].apply(replace_cite, axis=1) 104 | body_df['title'] = row.title 105 | body_df['abstract'] = row.abstract 106 | 107 | section_df = body_df.groupby('section').agg({ 108 | 'text': lambda text_series: ' '.join([text for text in text_series]), 109 | 'title': lambda text_series: [text for text in text_series][0], 110 | 'abstract': lambda text_series: [text for text in text_series][0], 111 | 'cite_spans': lambda cite_spans_series: [cite['ref_id'] for cite_spans in cite_spans_series for cite in cite_spans], 112 | }) 113 | section_df = section_df.loc[sections] 114 | section_df['bibs'] = section_df['cite_spans'].apply(lambda spans: [bib_dict[span] for span in spans if span in bib_dict]) 115 | section_df['n_bibs'] = 
section_df[['cite_spans', 'bibs']].apply(lambda row: len(row['bibs']), axis=1) 116 | section_df['n_nonbibs'] = section_df[['cite_spans', 'bibs']].apply(lambda row: len(row['cite_spans']) - len(row['bibs']), axis=1) 117 | 118 | section_df['paper_id'] = row.name 119 | section_df['section_id'] = section_df['paper_id'] + '-' + np.arange(len(section_df)).astype('str') 120 | section_df['section'] = section_df.index 121 | section_df = section_df.set_index('section_id') 122 | 123 | section_df['bib_titles'] = section_df['bibs'].apply(lambda bibs: metadata_outbound_df.loc[bibs]['title'].to_dict()) 124 | section_df['bib_abstracts'] = section_df['bibs'].apply(lambda bibs: metadata_outbound_df.loc[bibs]['abstract'].to_dict()) 125 | section_df['bib_years'] = section_df['bibs'].apply(lambda bibs: metadata_outbound_df.loc[bibs]['year'].to_dict()) 126 | section_df['bib_abstracts'] = section_df[['bib_abstracts', 'bib_years']].apply(lambda bib: dict(sorted(bib[0].items(), key=lambda x: bib[1][x[0]])), axis=1) # Sort by publication year 127 | section_df['bib_citing_sentences'] = section_df['bibs'].apply(lambda bibs: metadata_outbound_df.loc[bibs]['citing_sentence'].to_dict()) 128 | section_df = section_df[['paper_id', 'title', 'abstract', 'section', 'text', 'n_bibs', 'n_nonbibs', 'bib_titles', 'bib_abstracts', 'bib_citing_sentences']] 129 | return section_df 130 | 131 | section_survey_df = pd.concat(pdf_survey_df.apply(get_section_df, axis=1).values) 132 | section_survey_df = section_survey_df[section_survey_df['text'].apply(len) >= 1] # Remove sections without body text 133 | 134 | with open ('filtered_dict.pkl', 'rb') as f: 135 | filtering_dict = pickle.load(f) 136 | section_survey_df = section_survey_df[section_survey_df['paper_id'].isin(filtering_dict.keys())] 137 | section_survey_df['split'] = section_survey_df['paper_id'].apply(lambda s: filtering_dict[s]) 138 | 139 | if args.version == 'split': 140 | section_survey_df = section_survey_df[section_survey_df['bib_abstracts'].apply(lambda _dict: len(_dict) >= 2)] # Remove sections with less than two cited papers 141 | section_survey_df.to_pickle(os.path.join(args.dataset_path, 'split_survey_df.pkl')) 142 | else: 143 | section_survey_df = section_survey_df.groupby('paper_id').agg({ 144 | 'title': lambda l: list(l)[0], 145 | 'abstract': lambda l: list(l)[0], 146 | 'section': lambda l: list(l), 147 | 'text': lambda l: list(l), 148 | 'n_bibs': lambda l: list(l), 149 | 'n_nonbibs': lambda l: list(l), 150 | 'bib_titles': lambda l: list(l), 151 | 'bib_abstracts': lambda l: list(l), 152 | 'bib_citing_sentences': lambda l: list(l), 153 | 'split': lambda l: list(l)[0], 154 | }) 155 | section_survey_df.to_pickle(os.path.join(args.dataset_path, 'original_survey_df.pkl')) 156 | 157 | logging.info('Done!') 158 | 159 | 160 | if __name__ == '__main__': 161 | logging.basicConfig(format='%(message)s', level=logging.DEBUG) 162 | 163 | parser = argparse.ArgumentParser(description='') 164 | parser.add_argument('-dataset_path', help='Path to the generated dataset') 165 | parser.add_argument('--version', default='split', help='Specify the version ("split" or "original")', choices=['split', 'original']) 166 | parser.add_argument('--n_val', type=int, default=1000, help='Number of literature review papers in validation set') 167 | parser.add_argument('--n_test', type=int, default=1000, help='Number of literature review papers in test set') 168 | args = parser.parse_args() 169 | 170 | append_citing_sentence(args) # collect citing sentences for the cited papers 171 | 
make_scireviewgen(args) # make scireviewgen dataset in the form of pandas dataframe -------------------------------------------------------------------------------- /qfid/run_summarization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import datasets 28 | import nltk # Here to have a nice missing dependency error message early on 29 | import numpy as np 30 | from datasets import load_dataset, load_metric 31 | 32 | import transformers 33 | from filelock import FileLock 34 | from transformers import ( 35 | AutoConfig, 36 | AutoModelForSeq2SeqLM, 37 | AutoTokenizer, 38 | DataCollatorForSeq2Seq, 39 | HfArgumentParser, 40 | MBart50Tokenizer, 41 | MBart50TokenizerFast, 42 | MBartTokenizer, 43 | MBartTokenizerFast, 44 | Seq2SeqTrainer, 45 | Seq2SeqTrainingArguments, 46 | set_seed, 47 | ) 48 | from transformers.file_utils import is_offline_mode 49 | from transformers.trainer_utils import get_last_checkpoint 50 | from transformers.utils import check_min_version 51 | from transformers.utils.versions import require_version 52 | from qfid import BartForConditionalGeneration 53 | 54 | 55 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 56 | check_min_version("4.16.0.dev0") 57 | 58 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 59 | 60 | logger = logging.getLogger(__name__) 61 | 62 | try: 63 | nltk.data.find("tokenizers/punkt") 64 | except (LookupError, OSError): 65 | if is_offline_mode(): 66 | raise LookupError( 67 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 68 | ) 69 | with FileLock(".lock") as lock: 70 | nltk.download("punkt", quiet=True) 71 | 72 | # A list of all multilingual tokenizer which require lang attribute. 73 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast] 74 | 75 | 76 | @dataclass 77 | class ModelArguments: 78 | """ 79 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
80 | """ 81 | 82 | model_name_or_path: str = field( 83 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 84 | ) 85 | config_name: Optional[str] = field( 86 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 87 | ) 88 | tokenizer_name: Optional[str] = field( 89 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 90 | ) 91 | cache_dir: Optional[str] = field( 92 | default=None, 93 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 94 | ) 95 | use_fast_tokenizer: bool = field( 96 | default=True, 97 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 98 | ) 99 | model_revision: str = field( 100 | default="main", 101 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 102 | ) 103 | use_auth_token: bool = field( 104 | default=False, 105 | metadata={ 106 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 107 | "with private models)." 108 | }, 109 | ) 110 | resize_position_embeddings: Optional[bool] = field( 111 | default=None, 112 | metadata={ 113 | "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 114 | "the model's position embeddings." 115 | }, 116 | ) 117 | 118 | 119 | @dataclass 120 | class DataTrainingArguments: 121 | """ 122 | Arguments pertaining to what data we are going to input our model for training and eval. 123 | """ 124 | 125 | lang: str = field(default=None, metadata={"help": "Language id for summarization."}) 126 | 127 | dataset_name: Optional[str] = field( 128 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 129 | ) 130 | dataset_config_name: Optional[str] = field( 131 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 132 | ) 133 | text_column: Optional[str] = field( 134 | default=None, 135 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 136 | ) 137 | summary_column: Optional[str] = field( 138 | default=None, 139 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 140 | ) 141 | train_file: Optional[str] = field( 142 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 143 | ) 144 | validation_file: Optional[str] = field( 145 | default=None, 146 | metadata={ 147 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 148 | "(a jsonlines or csv file)." 149 | }, 150 | ) 151 | test_file: Optional[str] = field( 152 | default=None, 153 | metadata={ 154 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)." 155 | }, 156 | ) 157 | overwrite_cache: bool = field( 158 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 159 | ) 160 | preprocessing_num_workers: Optional[int] = field( 161 | default=None, 162 | metadata={"help": "The number of processes to use for the preprocessing."}, 163 | ) 164 | max_source_length: Optional[int] = field( 165 | default=1024, 166 | metadata={ 167 | "help": "The maximum total input sequence length after tokenization. 
Sequences longer " 168 | "than this will be truncated, sequences shorter will be padded." 169 | }, 170 | ) 171 | max_target_length: Optional[int] = field( 172 | default=128, 173 | metadata={ 174 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 175 | "than this will be truncated, sequences shorter will be padded." 176 | }, 177 | ) 178 | val_max_target_length: Optional[int] = field( 179 | default=None, 180 | metadata={ 181 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 182 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 183 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 184 | "during ``evaluate`` and ``predict``." 185 | }, 186 | ) 187 | pad_to_max_length: bool = field( 188 | default=False, 189 | metadata={ 190 | "help": "Whether to pad all samples to model maximum sentence length. " 191 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 192 | "efficient on GPU but very bad for TPU." 193 | }, 194 | ) 195 | max_train_samples: Optional[int] = field( 196 | default=None, 197 | metadata={ 198 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 199 | "value if set." 200 | }, 201 | ) 202 | max_eval_samples: Optional[int] = field( 203 | default=None, 204 | metadata={ 205 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 206 | "value if set." 207 | }, 208 | ) 209 | max_predict_samples: Optional[int] = field( 210 | default=None, 211 | metadata={ 212 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 213 | "value if set." 214 | }, 215 | ) 216 | num_beams: Optional[int] = field( 217 | default=None, 218 | metadata={ 219 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 220 | "which is used during ``evaluate`` and ``predict``." 221 | }, 222 | ) 223 | ignore_pad_token_for_loss: bool = field( 224 | default=True, 225 | metadata={ 226 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 227 | }, 228 | ) 229 | source_prefix: Optional[str] = field( 230 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 231 | ) 232 | 233 | forced_bos_token: Optional[str] = field( 234 | default=None, 235 | metadata={ 236 | "help": "The token to force as the first generated token after the decoder_start_token_id." 237 | "Useful for multilingual models like mBART where the first generated token" 238 | "needs to be the target language token (Usually it is the target language token)" 239 | }, 240 | ) 241 | 242 | def __post_init__(self): 243 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 244 | raise ValueError("Need either a dataset name or a training/validation file.") 245 | else: 246 | if self.train_file is not None: 247 | extension = self.train_file.split(".")[-1] 248 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 249 | if self.validation_file is not None: 250 | extension = self.validation_file.split(".")[-1] 251 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 
252 | if self.val_max_target_length is None: 253 | self.val_max_target_length = self.max_target_length 254 | 255 | 256 | summarization_name_mapping = { 257 | "amazon_reviews_multi": ("review_body", "review_title"), 258 | "big_patent": ("description", "abstract"), 259 | "cnn_dailymail": ("article", "highlights"), 260 | "orange_sum": ("text", "summary"), 261 | "pn_summary": ("article", "summary"), 262 | "psc": ("extract_text", "summary_text"), 263 | "samsum": ("dialogue", "summary"), 264 | "thaisum": ("body", "summary"), 265 | "xglue": ("news_body", "news_title"), 266 | "xsum": ("document", "summary"), 267 | "wiki_summary": ("article", "highlights"), 268 | } 269 | 270 | 271 | def main(): 272 | # See all possible arguments in src/transformers/training_args.py 273 | # or by passing the --help flag to this script. 274 | # We now keep distinct sets of args, for a cleaner separation of concerns. 275 | 276 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 277 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 278 | # If we pass only one argument to the script and it's the path to a json file, 279 | # let's parse it to get our arguments. 280 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 281 | else: 282 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 283 | 284 | # Setup logging 285 | logging.basicConfig( 286 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 287 | datefmt="%m/%d/%Y %H:%M:%S", 288 | handlers=[logging.StreamHandler(sys.stdout)], 289 | ) 290 | log_level = training_args.get_process_log_level() 291 | logger.setLevel(log_level) 292 | datasets.utils.logging.set_verbosity(log_level) 293 | transformers.utils.logging.set_verbosity(log_level) 294 | transformers.utils.logging.enable_default_handler() 295 | transformers.utils.logging.enable_explicit_format() 296 | 297 | # Log on each process the small summary: 298 | logger.warning( 299 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 300 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 301 | ) 302 | logger.info(f"Training/evaluation parameters {training_args}") 303 | 304 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 305 | "t5-small", 306 | "t5-base", 307 | "t5-large", 308 | "t5-3b", 309 | "t5-11b", 310 | ]: 311 | logger.warning( 312 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 313 | "`--source_prefix 'summarize: ' `" 314 | ) 315 | 316 | # Detecting last checkpoint. 317 | last_checkpoint = None 318 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 319 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 320 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 321 | raise ValueError( 322 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 323 | "Use --overwrite_output_dir to overcome." 324 | ) 325 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 326 | logger.info( 327 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 328 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 329 | ) 330 | 331 | # Set seed before initializing model. 
332 | set_seed(training_args.seed) 333 | 334 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 335 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 336 | # (the dataset will be downloaded automatically from the datasets Hub). 337 | # 338 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 339 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 340 | # 341 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 342 | # download the dataset. 343 | if data_args.dataset_name is not None: 344 | # Downloading and loading a dataset from the hub. 345 | raw_datasets = load_dataset( 346 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir 347 | ) 348 | else: 349 | data_files = {} 350 | if data_args.train_file is not None: 351 | data_files["train"] = data_args.train_file 352 | extension = data_args.train_file.split(".")[-1] 353 | if data_args.validation_file is not None: 354 | data_files["validation"] = data_args.validation_file 355 | extension = data_args.validation_file.split(".")[-1] 356 | if data_args.test_file is not None: 357 | data_files["test"] = data_args.test_file 358 | extension = data_args.test_file.split(".")[-1] 359 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir) 360 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 361 | # https://huggingface.co/docs/datasets/loading_datasets.html. 362 | 363 | # Load pretrained model and tokenizer 364 | # 365 | # Distributed training: 366 | # The .from_pretrained methods guarantee that only one local process can concurrently 367 | # download model & vocab. 
368 | config = AutoConfig.from_pretrained( 369 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 370 | cache_dir=model_args.cache_dir, 371 | revision=model_args.model_revision, 372 | use_auth_token=True if model_args.use_auth_token else None, 373 | ) 374 | tokenizer = AutoTokenizer.from_pretrained( 375 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 376 | cache_dir=model_args.cache_dir, 377 | use_fast=model_args.use_fast_tokenizer, 378 | revision=model_args.model_revision, 379 | use_auth_token=True if model_args.use_auth_token else None, 380 | ) 381 | model = BartForConditionalGeneration.from_pretrained( 382 | model_args.model_name_or_path, 383 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 384 | config=config, 385 | cache_dir=model_args.cache_dir, 386 | revision=model_args.model_revision, 387 | use_auth_token=True if model_args.use_auth_token else None, 388 | ) 389 | 390 | #special_tokens_dict = {'additional_special_tokens': ['[BOT]', '[BOCT]', '[BOA]', '[BOC]', '[BOB]']}# + ['BIB{:0>3}'.format(i+1) for i in range(100)]} 391 | #tokenizer.add_special_tokens(special_tokens_dict) 392 | 393 | model.resize_token_embeddings(len(tokenizer)) 394 | 395 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): 396 | if isinstance(tokenizer, MBartTokenizer): 397 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang] 398 | else: 399 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang) 400 | 401 | if model.config.decoder_start_token_id is None: 402 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 403 | 404 | #if ( 405 | # hasattr(model.config, "max_position_embeddings") 406 | # and model.config.max_position_embeddings < data_args.max_source_length 407 | #): 408 | # if model_args.resize_position_embeddings is None: 409 | # logger.warning( 410 | # f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} " 411 | # f"to {data_args.max_source_length}." 412 | # ) 413 | # model.resize_position_embeddings(data_args.max_source_length) 414 | # elif model_args.resize_position_embeddings: 415 | # model.resize_position_embeddings(data_args.max_source_length) 416 | # else: 417 | # raise ValueError( 418 | # f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}" 419 | # f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically " 420 | # "resize the model's position encodings by passing `--resize_position_embeddings`." 421 | # ) 422 | 423 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 424 | 425 | # Preprocessing the datasets. 426 | # We need to tokenize inputs and targets. 427 | if training_args.do_train: 428 | column_names = raw_datasets["train"].column_names 429 | elif training_args.do_eval: 430 | column_names = raw_datasets["validation"].column_names 431 | elif training_args.do_predict: 432 | column_names = raw_datasets["test"].column_names 433 | else: 434 | logger.info("There is nothing to do. 
Please pass `do_train`, `do_eval` and/or `do_predict`.") 435 | return 436 | 437 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): 438 | assert ( 439 | data_args.lang is not None 440 | ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument" 441 | 442 | tokenizer.src_lang = data_args.lang 443 | tokenizer.tgt_lang = data_args.lang 444 | 445 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token 446 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. 447 | forced_bos_token_id = ( 448 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None 449 | ) 450 | model.config.forced_bos_token_id = forced_bos_token_id 451 | 452 | # Get the column names for input/target. 453 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 454 | if data_args.text_column is None: 455 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 456 | else: 457 | text_column = data_args.text_column 458 | if text_column not in column_names: 459 | raise ValueError( 460 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 461 | ) 462 | if data_args.summary_column is None: 463 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 464 | else: 465 | summary_column = data_args.summary_column 466 | if summary_column not in column_names: 467 | raise ValueError( 468 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 469 | ) 470 | 471 | # Temporarily set max_target_length for training. 472 | max_target_length = data_args.max_target_length 473 | padding = "max_length" if data_args.pad_to_max_length else False 474 | 475 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 476 | logger.warning( 477 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 478 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 479 | ) 480 | 481 | def preprocess_function(examples): 482 | # remove pairs where at least one record is None 483 | 484 | inputs, targets = [], [] 485 | for i in range(len(examples[text_column])): 486 | if examples[text_column][i] is not None and examples[summary_column][i] is not None: 487 | inputs.append(examples[text_column][i]) 488 | targets.append(examples[summary_column][i]) 489 | 490 | inputs = examples[text_column] 491 | targets = examples[summary_column] 492 | inputs = [prefix + inp for inp in inputs] 493 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 494 | 495 | # Setup the tokenizer for targets 496 | with tokenizer.as_target_tokenizer(): 497 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 498 | 499 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 500 | # padding in the loss. 
501 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 502 | labels["input_ids"] = [ 503 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 504 | ] 505 | 506 | model_inputs["labels"] = labels["input_ids"] 507 | return model_inputs 508 | 509 | if training_args.do_train: 510 | if "train" not in raw_datasets: 511 | raise ValueError("--do_train requires a train dataset") 512 | train_dataset = raw_datasets["train"] 513 | if data_args.max_train_samples is not None: 514 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 515 | with training_args.main_process_first(desc="train dataset map pre-processing"): 516 | train_dataset = train_dataset.map( 517 | preprocess_function, 518 | batched=True, 519 | num_proc=data_args.preprocessing_num_workers, 520 | remove_columns=column_names, 521 | load_from_cache_file=not data_args.overwrite_cache, 522 | desc="Running tokenizer on train dataset", 523 | ) 524 | 525 | if training_args.do_eval: 526 | max_target_length = data_args.val_max_target_length 527 | if "validation" not in raw_datasets: 528 | raise ValueError("--do_eval requires a validation dataset") 529 | eval_dataset = raw_datasets["validation"] 530 | if data_args.max_eval_samples is not None: 531 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 532 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 533 | eval_dataset = eval_dataset.map( 534 | preprocess_function, 535 | batched=True, 536 | num_proc=data_args.preprocessing_num_workers, 537 | remove_columns=column_names, 538 | load_from_cache_file=not data_args.overwrite_cache, 539 | desc="Running tokenizer on validation dataset", 540 | ) 541 | 542 | if training_args.do_predict: 543 | max_target_length = data_args.val_max_target_length 544 | if "test" not in raw_datasets: 545 | raise ValueError("--do_predict requires a test dataset") 546 | predict_dataset = raw_datasets["test"] 547 | if data_args.max_predict_samples is not None: 548 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 549 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 550 | predict_dataset = predict_dataset.map( 551 | preprocess_function, 552 | batched=True, 553 | num_proc=data_args.preprocessing_num_workers, 554 | remove_columns=column_names, 555 | load_from_cache_file=not data_args.overwrite_cache, 556 | desc="Running tokenizer on prediction dataset", 557 | ) 558 | 559 | # Data collator 560 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 561 | data_collator = DataCollatorForSeq2Seq( 562 | tokenizer, 563 | model=model, 564 | label_pad_token_id=label_pad_token_id, 565 | pad_to_multiple_of=8 if training_args.fp16 else None, 566 | ) 567 | 568 | # Metric 569 | metric = load_metric("rouge") 570 | 571 | def postprocess_text(preds, labels): 572 | preds = [pred.strip() for pred in preds] 573 | labels = [label.strip() for label in labels] 574 | 575 | # rougeLSum expects newline after each sentence 576 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 577 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 578 | 579 | return preds, labels 580 | 581 | def compute_metrics(eval_preds): 582 | preds, labels = eval_preds 583 | if isinstance(preds, tuple): 584 | preds = preds[0] 585 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 586 | if data_args.ignore_pad_token_for_loss: 587 | # 
Replace -100 in the labels as we can't decode them. 588 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 589 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 590 | 591 | # Some simple post-processing 592 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 593 | 594 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 595 | # Extract a few results from ROUGE 596 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 597 | 598 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 599 | result["gen_len"] = np.mean(prediction_lens) 600 | result = {k: round(v, 4) for k, v in result.items()} 601 | return result 602 | 603 | # Initialize our Trainer 604 | trainer = Seq2SeqTrainer( 605 | model=model, 606 | args=training_args, 607 | train_dataset=train_dataset if training_args.do_train else None, 608 | eval_dataset=eval_dataset if training_args.do_eval else None, 609 | tokenizer=tokenizer, 610 | data_collator=data_collator, 611 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 612 | ) 613 | 614 | # Training 615 | if training_args.do_train: 616 | checkpoint = None 617 | if training_args.resume_from_checkpoint is not None: 618 | checkpoint = training_args.resume_from_checkpoint 619 | elif last_checkpoint is not None: 620 | checkpoint = last_checkpoint 621 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 622 | trainer.save_model() # Saves the tokenizer too for easy upload 623 | 624 | metrics = train_result.metrics 625 | max_train_samples = ( 626 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 627 | ) 628 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 629 | 630 | trainer.log_metrics("train", metrics) 631 | trainer.save_metrics("train", metrics) 632 | trainer.save_state() 633 | 634 | # Evaluation 635 | results = {} 636 | max_length = ( 637 | training_args.generation_max_length 638 | if training_args.generation_max_length is not None 639 | else data_args.val_max_target_length 640 | ) 641 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 642 | if training_args.do_eval: 643 | logger.info("*** Evaluate ***") 644 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 645 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 646 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 647 | 648 | trainer.log_metrics("eval", metrics) 649 | trainer.save_metrics("eval", metrics) 650 | 651 | if training_args.do_predict: 652 | logger.info("*** Predict ***") 653 | 654 | predict_results = trainer.predict( 655 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 656 | ) 657 | metrics = predict_results.metrics 658 | max_predict_samples = ( 659 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 660 | ) 661 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 662 | 663 | trainer.log_metrics("predict", metrics) 664 | trainer.save_metrics("predict", metrics) 665 | 666 | if trainer.is_world_process_zero(): 667 | if training_args.predict_with_generate: 668 | predictions = tokenizer.batch_decode( 669 | predict_results.predictions, 
skip_special_tokens=True, clean_up_tokenization_spaces=True 670 | ) 671 | predictions = [pred.strip() for pred in predictions] 672 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 673 | with open(output_prediction_file, "w") as writer: 674 | writer.write("\n".join(predictions)) 675 | 676 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 677 | if data_args.dataset_name is not None: 678 | kwargs["dataset_tags"] = data_args.dataset_name 679 | if data_args.dataset_config_name is not None: 680 | kwargs["dataset_args"] = data_args.dataset_config_name 681 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 682 | else: 683 | kwargs["dataset"] = data_args.dataset_name 684 | 685 | if data_args.lang is not None: 686 | kwargs["language"] = data_args.lang 687 | 688 | if training_args.push_to_hub: 689 | trainer.push_to_hub(**kwargs) 690 | else: 691 | trainer.create_model_card(**kwargs) 692 | 693 | return results 694 | 695 | 696 | def _mp_fn(index): 697 | # For xla_spawn (TPUs) 698 | main() 699 | 700 | 701 | if __name__ == "__main__": 702 | main() -------------------------------------------------------------------------------- /qfid/qfid.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import random 4 | import warnings 5 | from typing import List, Optional, Tuple, Union 6 | 7 | import torch 8 | import torch.utils.checkpoint 9 | from torch import nn 10 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 11 | 12 | from transformers.activations import ACT2FN 13 | from transformers.modeling_outputs import ( 14 | BaseModelOutput, 15 | BaseModelOutputWithPastAndCrossAttentions, 16 | CausalLMOutputWithCrossAttentions, 17 | Seq2SeqLMOutput, 18 | Seq2SeqModelOutput, 19 | Seq2SeqQuestionAnsweringModelOutput, 20 | Seq2SeqSequenceClassifierOutput, 21 | ) 22 | from transformers.modeling_utils import PreTrainedModel 23 | from transformers.models.bart.configuration_bart import BartConfig 24 | 25 | 26 | _CHECKPOINT_FOR_DOC = "facebook/bart-base" 27 | _CONFIG_FOR_DOC = "BartConfig" 28 | _TOKENIZER_FOR_DOC = "BartTokenizer" 29 | 30 | # Base model docstring 31 | _EXPECTED_OUTPUT_SHAPE = [1, 8, 768] 32 | 33 | # SequenceClassification docstring 34 | _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2" 35 | _SEQ_CLASS_EXPECTED_LOSS = 0.0 36 | _SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'" 37 | 38 | # QuestionAsnwering docstring 39 | _CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1" 40 | _QA_EXPECTED_LOSS = 0.59 41 | _QA_EXPECTED_OUTPUT = "' nice puppet'" 42 | 43 | 44 | BART_PRETRAINED_MODEL_ARCHIVE_LIST = [ 45 | "facebook/bart-large", 46 | # see all BART models at https://huggingface.co/models?filter=bart 47 | ] 48 | 49 | 50 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 51 | """ 52 | Shift input ids one token to the right. 
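For example (an illustrative check, using BART's usual `pad_token_id=1` and `decoder_start_token_id=2`):

```python
import torch

labels = torch.tensor([[0, 42, 7, 2, -100, -100]])  # -100 marks padded label positions
shifted = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
# tensor([[2, 0, 42, 7, 2, 1]]): the start token is prepended, everything moves one
# position to the right, and the -100 carried over is replaced by the pad token.
```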
53 | """ 54 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 55 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 56 | shifted_input_ids[:, 0] = decoder_start_token_id 57 | 58 | if pad_token_id is None: 59 | raise ValueError("self.model.config.pad_token_id has to be defined.") 60 | # replace possible -100 values in labels by `pad_token_id` 61 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 62 | 63 | return shifted_input_ids 64 | 65 | 66 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0): 67 | """ 68 | Make causal mask used for bi-directional self-attention. 69 | """ 70 | bsz, tgt_len = input_ids_shape 71 | mask = torch.full((tgt_len, tgt_len), float("-inf")) 72 | mask_cond = torch.arange(mask.size(-1)) 73 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 74 | mask = mask.to(dtype) 75 | 76 | if past_key_values_length > 0: 77 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1) 78 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 79 | 80 | 81 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 82 | """ 83 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 84 | """ 85 | bsz, src_len = mask.size() 86 | tgt_len = tgt_len if tgt_len is not None else src_len 87 | 88 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 89 | 90 | inverted_mask = 1.0 - expanded_mask 91 | 92 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 93 | 94 | 95 | class BartLearnedPositionalEmbedding(nn.Embedding): 96 | """ 97 | This module learns positional embeddings up to a fixed maximum size. 98 | """ 99 | 100 | def __init__(self, num_embeddings: int, embedding_dim: int): 101 | # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2 102 | # and adjust num_embeddings appropriately. Other models don't have this hack 103 | self.offset = 2 104 | super().__init__(num_embeddings + self.offset, embedding_dim) 105 | 106 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): 107 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 108 | bsz, seq_len = input_ids_shape[:2] 109 | positions = torch.arange( 110 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device 111 | ) 112 | return super().forward(positions + self.offset) 113 | 114 | 115 | class BartAttention(nn.Module): 116 | """Multi-headed attention from 'Attention Is All You Need' paper""" 117 | 118 | def __init__( 119 | self, 120 | embed_dim: int, 121 | num_heads: int, 122 | dropout: float = 0.0, 123 | is_decoder: bool = False, 124 | bias: bool = True, 125 | ): 126 | super().__init__() 127 | self.embed_dim = embed_dim 128 | self.num_heads = num_heads 129 | self.dropout = dropout 130 | self.head_dim = embed_dim // num_heads 131 | 132 | if (self.head_dim * num_heads) != self.embed_dim: 133 | raise ValueError( 134 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 135 | f" and `num_heads`: {num_heads})." 
136 | ) 137 | self.scaling = self.head_dim**-0.5 138 | self.is_decoder = is_decoder 139 | 140 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 141 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 142 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 143 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 144 | 145 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 146 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 147 | 148 | def forward( 149 | self, 150 | hidden_states: torch.Tensor, 151 | key_value_states: Optional[torch.Tensor] = None, 152 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 153 | attention_mask: Optional[torch.Tensor] = None, 154 | layer_head_mask: Optional[torch.Tensor] = None, 155 | output_attentions: bool = False, 156 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 157 | """Input shape: Batch x Time x Channel""" 158 | 159 | # if key_value_states are provided this layer is used as a cross-attention layer 160 | # for the decoder 161 | is_cross_attention = key_value_states is not None 162 | 163 | bsz, tgt_len, _ = hidden_states.size() 164 | 165 | # get query proj 166 | query_states = self.q_proj(hidden_states) * self.scaling 167 | # get key, value proj 168 | if is_cross_attention and past_key_value is not None: 169 | # reuse k,v, cross_attentions 170 | key_states = past_key_value[0] 171 | value_states = past_key_value[1] 172 | elif is_cross_attention: 173 | # cross_attentions 174 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 175 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 176 | elif past_key_value is not None: 177 | # reuse k, v, self_attention 178 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 179 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 180 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 181 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 182 | else: 183 | # self_attention 184 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 185 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 186 | 187 | if self.is_decoder: 188 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 189 | # Further calls to cross_attention layer can then reuse all cross-attention 190 | # key/value_states (first "if" case) 191 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 192 | # all previous decoder key/value_states. 
Further calls to uni-directional self-attention 193 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 194 | # if encoder bi-directional self-attention `past_key_value` is always `None` 195 | past_key_value = (key_states, value_states) 196 | 197 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 198 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 199 | key_states = key_states.view(*proj_shape) 200 | value_states = value_states.view(*proj_shape) 201 | 202 | src_len = key_states.size(1) 203 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 204 | 205 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 206 | raise ValueError( 207 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 208 | ) 209 | 210 | if attention_mask is not None: 211 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 212 | raise ValueError( 213 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 214 | ) 215 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 216 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 217 | 218 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 219 | 220 | if layer_head_mask is not None: 221 | if layer_head_mask.size() != (self.num_heads,): 222 | raise ValueError( 223 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}" 224 | ) 225 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 226 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 227 | 228 | if output_attentions: 229 | # this operation is a bit awkward, but it's required to 230 | # make sure that attn_weights keeps its gradient. 231 | # In order to do so, attn_weights have to be reshaped 232 | # twice and have to be reused in the following 233 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 234 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 235 | else: 236 | attn_weights_reshaped = None 237 | 238 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 239 | 240 | attn_output = torch.bmm(attn_probs, value_states) 241 | 242 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 243 | raise ValueError( 244 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 245 | ) 246 | 247 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 248 | attn_output = attn_output.transpose(1, 2) 249 | 250 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 251 | # partitioned aross GPUs when using tensor-parallelism. 
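The reshape that follows folds the attention heads back into `embed_dim`; a standalone sketch of the shape bookkeeping used in this class (illustrative sizes only):

```python
import torch

bsz, num_heads, tgt_len, src_len, head_dim = 2, 4, 5, 7, 16
embed_dim = num_heads * head_dim

attn_probs = torch.softmax(torch.randn(bsz * num_heads, tgt_len, src_len), dim=-1)
value_states = torch.randn(bsz * num_heads, src_len, head_dim)

attn_output = torch.bmm(attn_probs, value_states)           # (bsz*num_heads, tgt_len, head_dim)
attn_output = attn_output.view(bsz, num_heads, tgt_len, head_dim).transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)  # heads merged back into embed_dim
assert attn_output.shape == (bsz, tgt_len, embed_dim)
```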
252 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 253 | 254 | attn_output = self.out_proj(attn_output) 255 | 256 | return attn_output, attn_weights_reshaped, past_key_value 257 | 258 | 259 | class BartEncoderLayer(nn.Module): 260 | def __init__(self, config: BartConfig): 261 | super().__init__() 262 | self.embed_dim = config.d_model 263 | self.self_attn = BartAttention( 264 | embed_dim=self.embed_dim, 265 | num_heads=config.encoder_attention_heads, 266 | dropout=config.attention_dropout, 267 | ) 268 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 269 | self.dropout = config.dropout 270 | self.activation_fn = ACT2FN[config.activation_function] 271 | self.activation_dropout = config.activation_dropout 272 | self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) 273 | self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) 274 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 275 | 276 | def forward( 277 | self, 278 | hidden_states: torch.FloatTensor, 279 | attention_mask: torch.FloatTensor, 280 | layer_head_mask: torch.FloatTensor, 281 | output_attentions: Optional[bool] = False, 282 | ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: 283 | """ 284 | Args: 285 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)` 286 | attention_mask (`torch.FloatTensor`): attention mask of size 287 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 288 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size 289 | `(encoder_attention_heads,)`. 290 | output_attentions (`bool`, *optional*): 291 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 292 | returned tensors for more detail. 
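The layer body follows BART's post-layer-norm residual pattern; a standalone sketch with stock PyTorch modules standing in for `BartAttention` (illustrative only, dropout omitted):

```python
import torch
from torch import nn

embed_dim, ffn_dim, seq_len = 16, 64, 5
x = torch.randn(1, seq_len, embed_dim)

attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)
attn_norm, final_norm = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)
fc1, fc2 = nn.Linear(embed_dim, ffn_dim), nn.Linear(ffn_dim, embed_dim)

residual = x
x, _ = attn(x, x, x)            # self-attention block
x = attn_norm(residual + x)     # residual added first, then layer norm (post-LN)

residual = x
x = fc2(torch.relu(fc1(x)))     # feed-forward block, same residual-then-norm pattern
x = final_norm(residual + x)
assert x.shape == (1, seq_len, embed_dim)
```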
293 | """ 294 | residual = hidden_states 295 | hidden_states, attn_weights, _ = self.self_attn( 296 | hidden_states=hidden_states, 297 | attention_mask=attention_mask, 298 | layer_head_mask=layer_head_mask, 299 | output_attentions=output_attentions, 300 | ) 301 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 302 | hidden_states = residual + hidden_states 303 | hidden_states = self.self_attn_layer_norm(hidden_states) 304 | 305 | residual = hidden_states 306 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 307 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 308 | hidden_states = self.fc2(hidden_states) 309 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 310 | hidden_states = residual + hidden_states 311 | hidden_states = self.final_layer_norm(hidden_states) 312 | 313 | if hidden_states.dtype == torch.float16 and ( 314 | torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() 315 | ): 316 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 317 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 318 | 319 | outputs = (hidden_states,) 320 | 321 | if output_attentions: 322 | outputs += (attn_weights,) 323 | 324 | return outputs 325 | 326 | 327 | class BartDecoderLayer(nn.Module): 328 | def __init__(self, config: BartConfig): 329 | super().__init__() 330 | self.embed_dim = config.d_model 331 | 332 | self.self_attn = BartAttention( 333 | embed_dim=self.embed_dim, 334 | num_heads=config.decoder_attention_heads, 335 | dropout=config.attention_dropout, 336 | is_decoder=True, 337 | ) 338 | self.dropout = config.dropout 339 | self.activation_fn = ACT2FN[config.activation_function] 340 | self.activation_dropout = config.activation_dropout 341 | 342 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 343 | self.encoder_attn = BartAttention( 344 | self.embed_dim, 345 | config.decoder_attention_heads, 346 | dropout=config.attention_dropout, 347 | is_decoder=True, 348 | ) 349 | self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) 350 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 351 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 352 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 353 | 354 | def forward( 355 | self, 356 | hidden_states: torch.Tensor, 357 | attention_mask: Optional[torch.Tensor] = None, 358 | encoder_hidden_states: Optional[torch.Tensor] = None, 359 | encoder_attention_mask: Optional[torch.Tensor] = None, 360 | layer_head_mask: Optional[torch.Tensor] = None, 361 | cross_attn_layer_head_mask: Optional[torch.Tensor] = None, 362 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 363 | output_attentions: Optional[bool] = False, 364 | use_cache: Optional[bool] = True, 365 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 366 | """ 367 | Args: 368 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 369 | attention_mask (`torch.FloatTensor`): attention mask of size 370 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
371 | encoder_hidden_states (`torch.FloatTensor`):
372 | cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
373 | encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
374 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
375 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
376 | `(encoder_attention_heads,)`.
377 | cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
378 | size `(decoder_attention_heads,)`.
379 | past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
380 | output_attentions (`bool`, *optional*):
381 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
382 | returned tensors for more detail.
383 | """
384 | residual = hidden_states
385 | # Self Attention
386 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
387 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
388 | # add present self-attn cache to positions 1,2 of present_key_value tuple
389 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
390 | hidden_states=hidden_states,
391 | past_key_value=self_attn_past_key_value,
392 | attention_mask=attention_mask,
393 | layer_head_mask=layer_head_mask,
394 | output_attentions=output_attentions,
395 | )
396 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
397 | hidden_states = residual + hidden_states
398 | hidden_states = self.self_attn_layer_norm(hidden_states)
399 |
400 | # Cross-Attention Block
401 | cross_attn_present_key_value = None
402 | cross_attn_weights = None
403 | if encoder_hidden_states is not None:
404 | residual = hidden_states
405 |
406 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
407 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
408 | #print("BartDecoderLayer: {}".format(encoder_attention_mask.shape)) # turn this from tensor([], size=(1, 1, 512, 0)) into size=(1, 1, 512, 1024)
409 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
410 | hidden_states=hidden_states,
411 | key_value_states=encoder_hidden_states,
412 | attention_mask=encoder_attention_mask,
413 | layer_head_mask=cross_attn_layer_head_mask,
414 | past_key_value=cross_attn_past_key_value,
415 | output_attentions=output_attentions,
416 | )
417 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
418 | hidden_states = residual + hidden_states
419 | hidden_states = self.encoder_attn_layer_norm(hidden_states)
420 |
421 | # add cross-attn to positions 3,4 of present_key_value tuple
422 | present_key_value = present_key_value + cross_attn_present_key_value
423 |
424 | # Fully Connected
425 | residual = hidden_states
426 | hidden_states = self.activation_fn(self.fc1(hidden_states))
427 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
428 | hidden_states = self.fc2(hidden_states)
429 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
430 | hidden_states = residual + hidden_states
431 | hidden_states = self.final_layer_norm(hidden_states)
432 |
433 | outputs = (hidden_states,)
434 |
435 | if output_attentions:
436 | outputs += (self_attn_weights,
cross_attn_weights) 437 | 438 | if use_cache: 439 | outputs += (present_key_value,) 440 | 441 | return outputs 442 | 443 | 444 | class BartClassificationHead(nn.Module): 445 | """Head for sentence-level classification tasks.""" 446 | 447 | def __init__( 448 | self, 449 | input_dim: int, 450 | inner_dim: int, 451 | num_classes: int, 452 | pooler_dropout: float, 453 | ): 454 | super().__init__() 455 | self.dense = nn.Linear(input_dim, inner_dim) 456 | self.dropout = nn.Dropout(p=pooler_dropout) 457 | self.out_proj = nn.Linear(inner_dim, num_classes) 458 | 459 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 460 | hidden_states = self.dropout(hidden_states) 461 | hidden_states = self.dense(hidden_states) 462 | hidden_states = torch.tanh(hidden_states) 463 | hidden_states = self.dropout(hidden_states) 464 | hidden_states = self.out_proj(hidden_states) 465 | return hidden_states 466 | 467 | 468 | class BartPretrainedModel(PreTrainedModel): 469 | config_class = BartConfig 470 | base_model_prefix = "model" 471 | supports_gradient_checkpointing = True 472 | _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] 473 | 474 | def _init_weights(self, module): 475 | std = self.config.init_std 476 | if isinstance(module, nn.Linear): 477 | module.weight.data.normal_(mean=0.0, std=std) 478 | if module.bias is not None: 479 | module.bias.data.zero_() 480 | elif isinstance(module, nn.Embedding): 481 | module.weight.data.normal_(mean=0.0, std=std) 482 | if module.padding_idx is not None: 483 | module.weight.data[module.padding_idx].zero_() 484 | 485 | def _set_gradient_checkpointing(self, module, value=False): 486 | if isinstance(module, (BartDecoder, BartEncoder)): 487 | module.gradient_checkpointing = value 488 | 489 | @property 490 | def dummy_inputs(self): 491 | pad_token = self.config.pad_token_id 492 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 493 | dummy_inputs = { 494 | "attention_mask": input_ids.ne(pad_token), 495 | "input_ids": input_ids, 496 | } 497 | return dummy_inputs 498 | 499 | 500 | class PretrainedBartModel(BartPretrainedModel): 501 | def __init_subclass__(self): 502 | warnings.warn( 503 | "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.", 504 | FutureWarning, 505 | ) 506 | 507 | 508 | BART_START_DOCSTRING = r""" 509 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 510 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 511 | etc.) 512 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 513 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 514 | and behavior. 515 | Parameters: 516 | config ([`BartConfig`]): 517 | Model configuration class with all the parameters of the model. Initializing with a config file does not 518 | load the weights associated with the model, only the configuration. Check out the 519 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
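Because the classes in this file keep the upstream BART parameter names, they can be initialized straight from a released checkpoint. A minimal sketch, assuming the module is importable as `qfid` (the import wiring is not shown in this excerpt):

```python
from transformers import BartTokenizer

from qfid import BartForConditionalGeneration  # the patched class defined in this file

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
```

All weights and special-token ids come from the checkpoint; only the encoder/decoder forward logic defined here differs from stock BART.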
520 | """ 521 | 522 | BART_GENERATION_EXAMPLE = r""" 523 | Summarization example: 524 | ```python 525 | >>> from transformers import BartTokenizer, BartForConditionalGeneration 526 | >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn") 527 | >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn") 528 | >>> ARTICLE_TO_SUMMARIZE = ( 529 | ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds " 530 | ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were " 531 | ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow." 532 | ... ) 533 | >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt") 534 | >>> # Generate Summary 535 | >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, max_length=20) 536 | >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 537 | 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions' 538 | ``` 539 | Mask filling example: 540 | ```python 541 | >>> from transformers import BartTokenizer, BartForConditionalGeneration 542 | >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 543 | >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base") 544 | >>> TXT = "My friends are but they eat too many carbs." 545 | >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"] 546 | >>> logits = model(input_ids).logits 547 | >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() 548 | >>> probs = logits[0, masked_index].softmax(dim=0) 549 | >>> values, predictions = probs.topk(5) 550 | >>> tokenizer.decode(predictions).split() 551 | ['not', 'good', 'healthy', 'great', 'very'] 552 | ``` 553 | """ 554 | 555 | BART_INPUTS_DOCSTRING = r""" 556 | Args: 557 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 558 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 559 | it. 560 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and 561 | [`PreTrainedTokenizer.__call__`] for details. 562 | [What are input IDs?](../glossary#input-ids) 563 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 564 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 565 | - 1 for tokens that are **not masked**, 566 | - 0 for tokens that are **masked**. 567 | [What are attention masks?](../glossary#attention-mask) 568 | decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 569 | Indices of decoder input sequence tokens in the vocabulary. 570 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and 571 | [`PreTrainedTokenizer.__call__`] for details. 572 | [What are decoder input IDs?](../glossary#decoder-input-ids) 573 | Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` 574 | is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). 575 | For translation and summarization training, `decoder_input_ids` should be provided. 
If no 576 | `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right 577 | for denoising pre-training following the paper. 578 | decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 579 | Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also 580 | be used by default. 581 | If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_inputs`] and 582 | modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information 583 | on the default strategy. 584 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): 585 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`: 586 | - 1 indicates the head is **not masked**, 587 | - 0 indicates the head is **masked**. 588 | decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 589 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`: 590 | - 1 indicates the head is **not masked**, 591 | - 0 indicates the head is **masked**. 592 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 593 | Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 594 | 1]`: 595 | - 1 indicates the head is **not masked**, 596 | - 0 indicates the head is **masked**. 597 | encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): 598 | Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) 599 | `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of 600 | hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. 601 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 602 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 603 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 604 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 605 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 606 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 607 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 608 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 609 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape 610 | `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you 611 | can choose to directly pass an embedded representation. This is useful if you want more control over how to 612 | convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 
613 | decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): 614 | Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded 615 | representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be 616 | input (see `past_key_values`). This is useful if you want more control over how to convert 617 | `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. 618 | If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value 619 | of `inputs_embeds`. 620 | use_cache (`bool`, *optional*): 621 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 622 | `past_key_values`). 623 | output_attentions (`bool`, *optional*): 624 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 625 | tensors for more detail. 626 | output_hidden_states (`bool`, *optional*): 627 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 628 | more detail. 629 | return_dict (`bool`, *optional*): 630 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 631 | """ 632 | 633 | 634 | class BartEncoder(BartPretrainedModel): 635 | """ 636 | Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a 637 | [`BartEncoderLayer`]. 638 | Args: 639 | config: BartConfig 640 | embed_tokens (nn.Embedding): output embedding 641 | """ 642 | 643 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 644 | super().__init__(config) 645 | 646 | self.dropout = config.dropout 647 | self.layerdrop = config.encoder_layerdrop 648 | 649 | embed_dim = config.d_model 650 | self.padding_idx = config.pad_token_id 651 | self.max_source_positions = config.max_position_embeddings 652 | self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 653 | 654 | if embed_tokens is not None: 655 | self.embed_tokens = embed_tokens 656 | else: 657 | self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx) 658 | 659 | self.embed_positions = BartLearnedPositionalEmbedding( 660 | config.max_position_embeddings, 661 | embed_dim, 662 | ) 663 | self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) 664 | self.layernorm_embedding = nn.LayerNorm(embed_dim) 665 | 666 | self.gradient_checkpointing = False 667 | # Initialize weights and apply final processing 668 | self.post_init() 669 | 670 | def get_input_embeddings(self): 671 | return self.embed_tokens 672 | 673 | def set_input_embeddings(self, value): 674 | self.embed_tokens = value 675 | 676 | def forward( 677 | self, 678 | input_ids: torch.LongTensor = None, 679 | attention_mask: Optional[torch.Tensor] = None, 680 | head_mask: Optional[torch.Tensor] = None, 681 | inputs_embeds: Optional[torch.FloatTensor] = None, 682 | output_attentions: Optional[bool] = None, 683 | output_hidden_states: Optional[bool] = None, 684 | return_dict: Optional[bool] = None, 685 | ) -> Union[Tuple, BaseModelOutput]: 686 | r""" 687 | Args: 688 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 689 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 690 | provide it. 691 | Indices can be obtained using [`BartTokenizer`]. 
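Beyond these standard arguments, the forward body below implements the QFiD-specific encoding: the flat input is split into per-paper chunks at the `</s>` separators (token id 2 in BART's vocabulary), each chunk is encoded as a separate sequence, and the encoded abstract chunks are re-weighted by their similarity to the leading query chunk (the review and chapter titles) before being concatenated for the decoder. A standalone numeric sketch of that fusion step (names and sizes are illustrative):

```python
import torch

hidden_dim = 8
# mean-pooled hidden states: index 0 is the query (title) chunk, the rest are cited abstracts
pooled = torch.randn(4, hidden_dim)
title, absts = pooled[0], pooled[1:]

scores = absts @ title                      # dot-product similarity per abstract chunk
weights = 1 + torch.softmax(scores, dim=0)  # the same 1 + softmax(...) scaling as below, in (1, 2)

# token-level hidden states of each abstract chunk (variable lengths in the real model)
chunk_states = [torch.randn(n, hidden_dim) for n in (5, 3, 6)]
fused = torch.cat([w * h for w, h in zip(weights, chunk_states)], dim=0)
assert fused.shape == (5 + 3 + 6, hidden_dim)
# the real forward also keeps the query chunk's own token states, unweighted, at the front
```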
See [`PreTrainedTokenizer.encode`] and 692 | [`PreTrainedTokenizer.__call__`] for details. 693 | [What are input IDs?](../glossary#input-ids) 694 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 695 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 696 | - 1 for tokens that are **not masked**, 697 | - 0 for tokens that are **masked**. 698 | [What are attention masks?](../glossary#attention-mask) 699 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): 700 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 701 | - 1 indicates the head is **not masked**, 702 | - 0 indicates the head is **masked**. 703 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 704 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 705 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 706 | than the model's internal embedding lookup matrix. 707 | output_attentions (`bool`, *optional*): 708 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 709 | returned tensors for more detail. 710 | output_hidden_states (`bool`, *optional*): 711 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 712 | for more detail. 713 | return_dict (`bool`, *optional*): 714 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 715 | """ 716 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 717 | output_hidden_states = ( 718 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 719 | ) 720 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 721 | 722 | # retrieve input_ids and inputs_embeds 723 | if input_ids is not None and inputs_embeds is not None: 724 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 725 | elif input_ids is not None: 726 | input_shape = input_ids.size() 727 | input_ids = input_ids.view(-1, input_shape[-1]) 728 | elif inputs_embeds is not None: 729 | raise ValueError("You cannot specify inputs_embeds for now") 730 | else: 731 | raise ValueError("You have to specify either input_ids or inputs_embeds") 732 | 733 | # divide input_ids and attention_mask by '' 734 | input_ids_list = [] 735 | attention_mask_list = [] 736 | prev_token = 0 737 | prev_i = 1 738 | n_sep = 0 739 | for i, token in enumerate(input_ids[0]): 740 | if token == 2 and prev_token == 1437: 741 | input_ids_list.append(torch.cat((torch.cat((torch.tensor([0], device="cuda"), torch.tensor(input_ids[0][prev_i:i-1])))[:min(len(input_ids[0][prev_i-1:i-1]), self.config.max_position_embeddings-1)], torch.tensor([2], device="cuda"))).unsqueeze(dim=0)) 742 | attention_mask_list.append(torch.tensor(attention_mask[0][prev_i-1:i])[:self.config.max_position_embeddings].unsqueeze(dim=0)) 743 | prev_i = i + 1 744 | n_sep += 1 745 | prev_token = token 746 | 747 | input_ids_list.append(torch.cat((torch.cat((torch.tensor([0], device="cuda"), torch.tensor(input_ids[0][prev_i:])))[:min(len(input_ids[0][prev_i:]), self.config.max_position_embeddings-1)], torch.tensor([2], device="cuda"))).unsqueeze(dim=0)) 748 | 
attention_mask_list.append((torch.tensor(attention_mask[0][prev_i-1:]))[:self.config.max_position_embeddings].unsqueeze(dim=0)) 749 | 750 | l = max(list(map(lambda x: len(x[0]), input_ids_list))) 751 | input_ids = torch.ones(len(input_ids_list), l, dtype=torch.int32, device="cuda") 752 | attention_mask = torch.zeros(len(attention_mask_list), l, dtype=torch.int32, device="cuda") 753 | for i, (_input_ids, _attention_mask) in enumerate(zip(input_ids_list, attention_mask_list)): 754 | input_ids[i, :len(_input_ids[0])] = _input_ids[0] 755 | attention_mask[i, :len(_attention_mask[0])] = _attention_mask[0] 756 | 757 | if inputs_embeds is None: 758 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale 759 | 760 | input_shape = input_ids.size() 761 | embed_pos = self.embed_positions(input_shape) 762 | 763 | hidden_states = inputs_embeds + embed_pos 764 | hidden_states = self.layernorm_embedding(hidden_states) 765 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 766 | 767 | # expand attention_mask 768 | if attention_mask is not None: 769 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 770 | original_attention_mask = attention_mask 771 | attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype) 772 | 773 | encoder_states = () if output_hidden_states else None 774 | all_attentions = () if output_attentions else None 775 | 776 | # check if head_mask has a correct number of layers specified if desired 777 | if head_mask is not None: 778 | if head_mask.size()[0] != (len(self.layers)): 779 | raise ValueError( 780 | f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 781 | ) 782 | 783 | for idx, encoder_layer in enumerate(self.layers): 784 | if output_hidden_states: 785 | encoder_states = encoder_states + (hidden_states,) 786 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 787 | dropout_probability = random.uniform(0, 1) 788 | if self.training and (dropout_probability < self.layerdrop): # skip the layer 789 | layer_outputs = (None, None) 790 | else: 791 | if self.gradient_checkpointing and self.training: 792 | 793 | def create_custom_forward(module): 794 | def custom_forward(*inputs): 795 | return module(*inputs, output_attentions) 796 | 797 | return custom_forward 798 | 799 | layer_outputs = torch.utils.checkpoint.checkpoint( 800 | create_custom_forward(encoder_layer), 801 | hidden_states, 802 | attention_mask, 803 | (head_mask[idx] if head_mask is not None else None), 804 | ) 805 | else: 806 | layer_outputs = encoder_layer( 807 | hidden_states, 808 | attention_mask, 809 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 810 | output_attentions=output_attentions, 811 | ) 812 | 813 | hidden_states = layer_outputs[0] 814 | 815 | if output_attentions: 816 | all_attentions = all_attentions + (layer_outputs[1],) 817 | 818 | if output_hidden_states: 819 | encoder_states = encoder_states + (hidden_states,) 820 | 821 | matching_list = [] 822 | for i, (_hidden_states, _attention_mask) in enumerate(zip(hidden_states, original_attention_mask)): 823 | if i == 0: 824 | title_hidden_states = torch.mean(_hidden_states.index_select(0, torch.tensor([i for i, mask in enumerate(_attention_mask) if mask == 1], device="cuda")), dim=0) 825 | else: 826 | abst_hidden_states = torch.mean(_hidden_states.index_select(0, torch.tensor([i for i, mask in enumerate(_attention_mask) if mask == 1], device="cuda")), dim=0) 827 | 
matching_list.append(torch.dot(title_hidden_states, abst_hidden_states).item()) 828 | 829 | f = torch.nn.Softmax(dim=0) 830 | matching_list = f(torch.tensor(matching_list, dtype=torch.float32)) 831 | try: 832 | matching_list = (1 + matching_list).tolist() 833 | except: 834 | matching_list = [1 for i in range(len(matching_list))] 835 | 836 | for i, (_hidden_states, _attention_mask) in enumerate(zip(hidden_states, original_attention_mask)): 837 | if i == 0: 838 | return_hidden_states = _hidden_states.index_select(0, torch.tensor([i for i, mask in enumerate(_attention_mask) if mask == 1], device="cuda")) 839 | else: 840 | return_hidden_states = torch.cat((return_hidden_states, matching_list[i-1] * _hidden_states.index_select(0, torch.tensor([i for i, mask in enumerate(_attention_mask) if mask == 1], device="cuda"))), 0) 841 | 842 | if not return_dict: 843 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) 844 | return BaseModelOutput( 845 | last_hidden_state=return_hidden_states.unsqueeze(0), hidden_states=encoder_states, attentions=all_attentions 846 | ) 847 | 848 | 849 | class BartDecoder(BartPretrainedModel): 850 | """ 851 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`] 852 | Args: 853 | config: BartConfig 854 | embed_tokens (nn.Embedding): output embedding 855 | """ 856 | 857 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): 858 | super().__init__(config) 859 | self.dropout = config.dropout 860 | self.layerdrop = config.decoder_layerdrop 861 | self.padding_idx = config.pad_token_id 862 | self.max_target_positions = config.max_position_embeddings 863 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 864 | 865 | if embed_tokens is not None: 866 | self.embed_tokens = embed_tokens 867 | else: 868 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) 869 | 870 | self.embed_positions = BartLearnedPositionalEmbedding( 871 | config.max_position_embeddings, 872 | config.d_model, 873 | ) 874 | self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) 875 | self.layernorm_embedding = nn.LayerNorm(config.d_model) 876 | 877 | self.gradient_checkpointing = False 878 | # Initialize weights and apply final processing 879 | self.post_init() 880 | 881 | def get_input_embeddings(self): 882 | return self.embed_tokens 883 | 884 | def set_input_embeddings(self, value): 885 | self.embed_tokens = value 886 | 887 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 888 | # create causal mask 889 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 890 | combined_attention_mask = None 891 | if input_shape[-1] > 1: 892 | combined_attention_mask = _make_causal_mask( 893 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length 894 | ).to(self.device) 895 | 896 | if attention_mask is not None: 897 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 898 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]) 899 | combined_attention_mask = ( 900 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 901 | ) 902 | 903 | return combined_attention_mask 904 | 905 | def forward( 906 | self, 907 | input_ids: torch.LongTensor = None, 908 | attention_mask: Optional[torch.Tensor] = None, 909 | 
encoder_hidden_states: Optional[torch.FloatTensor] = None, 910 | encoder_attention_mask: Optional[torch.LongTensor] = None, 911 | head_mask: Optional[torch.Tensor] = None, 912 | cross_attn_head_mask: Optional[torch.Tensor] = None, 913 | past_key_values: Optional[List[torch.FloatTensor]] = None, 914 | inputs_embeds: Optional[torch.FloatTensor] = None, 915 | use_cache: Optional[bool] = None, 916 | output_attentions: Optional[bool] = None, 917 | output_hidden_states: Optional[bool] = None, 918 | return_dict: Optional[bool] = None, 919 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 920 | r""" 921 | Args: 922 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 923 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 924 | provide it. 925 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and 926 | [`PreTrainedTokenizer.__call__`] for details. 927 | [What are input IDs?](../glossary#input-ids) 928 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 929 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 930 | - 1 for tokens that are **not masked**, 931 | - 0 for tokens that are **masked**. 932 | [What are attention masks?](../glossary#attention-mask) 933 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): 934 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 935 | of the decoder. 936 | encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): 937 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 938 | selected in `[0, 1]`: 939 | - 1 for tokens that are **not masked**, 940 | - 0 for tokens that are **masked**. 941 | [What are attention masks?](../glossary#attention-mask) 942 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 943 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 944 | - 1 indicates the head is **not masked**, 945 | - 0 indicates the head is **masked**. 946 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 947 | Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing 948 | cross-attention on hidden heads. Mask values selected in `[0, 1]`: 949 | - 1 indicates the head is **not masked**, 950 | - 0 indicates the head is **masked**. 951 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 952 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 953 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 954 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 955 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 956 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
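One QFiD-specific adjustment appears in the body below: the modified encoder returns a concatenation of the kept (non-padding) token states, so its output length no longer matches the `attention_mask` built for the original flat input, and the mask is therefore narrowed to `encoder_hidden_states.shape[1]` before being expanded for cross-attention. A shape-only sketch (illustrative sizes):

```python
import torch

orig_src_len, fused_len = 2048, 1536  # flat input length vs. fused encoder output length
encoder_hidden_states = torch.randn(1, fused_len, 8)
attention_mask = torch.ones(1, orig_src_len)

encoder_attention_mask = attention_mask.narrow(dim=1, start=0, length=encoder_hidden_states.shape[1])
assert encoder_attention_mask.shape == (1, fused_len)
# _expand_mask then broadcasts this to (1, 1, tgt_len, fused_len) for the cross-attention layers
```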
957 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
958 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
959 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of
960 | shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
961 | `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
962 | control over how to convert `input_ids` indices into associated vectors than the model's internal
963 | embedding lookup matrix.
964 | output_attentions (`bool`, *optional*):
965 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
966 | returned tensors for more detail.
967 | output_hidden_states (`bool`, *optional*):
968 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
969 | for more detail.
970 | return_dict (`bool`, *optional*):
971 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
972 | """
973 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
974 | output_hidden_states = (
975 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
976 | )
977 | use_cache = use_cache if use_cache is not None else self.config.use_cache
978 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
979 |
980 | # retrieve input_ids and inputs_embeds
981 | if input_ids is not None and inputs_embeds is not None:
982 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
983 | elif input_ids is not None:
984 | input_shape = input_ids.size()
985 | input_ids = input_ids.view(-1, input_shape[-1])
986 | elif inputs_embeds is not None:
987 | input_shape = inputs_embeds.size()[:-1]
988 | else:
989 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
990 |
991 | # past_key_values_length
992 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
993 |
994 | if inputs_embeds is None:
995 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
996 |
997 | attention_mask = self._prepare_decoder_attention_mask(
998 | attention_mask, input_shape, inputs_embeds, past_key_values_length
999 | )
1000 |
1001 | # create encoder_attention_mask filled with ones
1002 | #encoder_attention_mask = torch.ones(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1], device="cuda")
1003 | encoder_attention_mask = encoder_attention_mask.narrow(dim=1, start=0, length=encoder_hidden_states.shape[1])
1004 | # expand encoder attention mask
1005 | if encoder_hidden_states is not None and encoder_attention_mask is not None:
1006 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1007 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
1008 |
1009 | # embed positions
1010 | positions = self.embed_positions(input_shape, past_key_values_length)
1011 |
1012 | hidden_states = inputs_embeds + positions
1013 | hidden_states = self.layernorm_embedding(hidden_states)
1014 |
1015 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1016 |
1017 | # decoder layers
1018 | all_hidden_states = () if output_hidden_states else None
1019 |
all_self_attns = () if output_attentions else None 1020 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 1021 | next_decoder_cache = () if use_cache else None 1022 | 1023 | # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired 1024 | for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): 1025 | if attn_mask is not None: 1026 | if attn_mask.size()[0] != (len(self.layers)): 1027 | raise ValueError( 1028 | "The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 1029 | ) 1030 | 1031 | for idx, decoder_layer in enumerate(self.layers): 1032 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 1033 | if output_hidden_states: 1034 | all_hidden_states += (hidden_states,) 1035 | dropout_probability = random.uniform(0, 1) 1036 | if self.training and (dropout_probability < self.layerdrop): 1037 | continue 1038 | 1039 | past_key_value = past_key_values[idx] if past_key_values is not None else None 1040 | 1041 | if self.gradient_checkpointing and self.training: 1042 | 1043 | if use_cache: 1044 | print( 1045 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 1046 | ) 1047 | use_cache = False 1048 | 1049 | def create_custom_forward(module): 1050 | def custom_forward(*inputs): 1051 | # None for past_key_value 1052 | return module(*inputs, output_attentions, use_cache) 1053 | 1054 | return custom_forward 1055 | 1056 | layer_outputs = torch.utils.checkpoint.checkpoint( 1057 | create_custom_forward(decoder_layer), 1058 | hidden_states, 1059 | attention_mask, 1060 | encoder_hidden_states, 1061 | encoder_attention_mask, 1062 | head_mask[idx] if head_mask is not None else None, 1063 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, 1064 | None, 1065 | ) 1066 | else: 1067 | layer_outputs = decoder_layer( 1068 | hidden_states, 1069 | attention_mask=attention_mask, 1070 | encoder_hidden_states=encoder_hidden_states, 1071 | encoder_attention_mask=encoder_attention_mask, 1072 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 1073 | cross_attn_layer_head_mask=( 1074 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None 1075 | ), 1076 | past_key_value=past_key_value, 1077 | output_attentions=output_attentions, 1078 | use_cache=use_cache, 1079 | ) 1080 | hidden_states = layer_outputs[0] 1081 | 1082 | if use_cache: 1083 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 1084 | 1085 | if output_attentions: 1086 | all_self_attns += (layer_outputs[1],) 1087 | 1088 | if encoder_hidden_states is not None: 1089 | all_cross_attentions += (layer_outputs[2],) 1090 | 1091 | # add hidden states from the last decoder layer 1092 | if output_hidden_states: 1093 | all_hidden_states += (hidden_states,) 1094 | 1095 | next_cache = next_decoder_cache if use_cache else None 1096 | if not return_dict: 1097 | return tuple( 1098 | v 1099 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 1100 | if v is not None 1101 | ) 1102 | return BaseModelOutputWithPastAndCrossAttentions( 1103 | last_hidden_state=hidden_states, 1104 | past_key_values=next_cache, 1105 | hidden_states=all_hidden_states, 1106 | attentions=all_self_attns, 1107 | cross_attentions=all_cross_attentions, 1108 | ) 1109 | 1110 | 1111 | class BartModel(BartPretrainedModel): 1112 | def __init__(self, 
config: BartConfig): 1113 | super().__init__(config) 1114 | 1115 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 1116 | self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) 1117 | 1118 | self.encoder = BartEncoder(config, self.shared) 1119 | self.decoder = BartDecoder(config, self.shared) 1120 | 1121 | # Initialize weights and apply final processing 1122 | self.post_init() 1123 | 1124 | def get_input_embeddings(self): 1125 | return self.shared 1126 | 1127 | def set_input_embeddings(self, value): 1128 | self.shared = value 1129 | self.encoder.embed_tokens = self.shared 1130 | self.decoder.embed_tokens = self.shared 1131 | 1132 | def get_encoder(self): 1133 | return self.encoder 1134 | 1135 | def get_decoder(self): 1136 | return self.decoder 1137 | 1138 | def forward( 1139 | self, 1140 | input_ids: torch.LongTensor = None, 1141 | attention_mask: Optional[torch.Tensor] = None, 1142 | decoder_input_ids: Optional[torch.LongTensor] = None, 1143 | decoder_attention_mask: Optional[torch.LongTensor] = None, 1144 | head_mask: Optional[torch.Tensor] = None, 1145 | decoder_head_mask: Optional[torch.Tensor] = None, 1146 | cross_attn_head_mask: Optional[torch.Tensor] = None, 1147 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 1148 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1149 | inputs_embeds: Optional[torch.FloatTensor] = None, 1150 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 1151 | use_cache: Optional[bool] = None, 1152 | output_attentions: Optional[bool] = None, 1153 | output_hidden_states: Optional[bool] = None, 1154 | return_dict: Optional[bool] = None, 1155 | ) -> Union[Tuple, Seq2SeqModelOutput]: 1156 | 1157 | # different to other models, Bart automatically creates decoder_input_ids from 1158 | # input_ids if no decoder_input_ids are provided 1159 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1160 | if input_ids is None: 1161 | raise ValueError( 1162 | "If no `decoder_input_ids` or `decoder_inputs_embeds` are " 1163 | "passed, `input_ids` cannot be `None`. Please pass either " 1164 | "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 
1165 | ) 1166 | 1167 | decoder_input_ids = shift_tokens_right( 1168 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 1169 | ) 1170 | 1171 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1172 | output_hidden_states = ( 1173 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1174 | ) 1175 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1176 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1177 | 1178 | if encoder_outputs is None: 1179 | encoder_outputs = self.encoder( 1180 | input_ids=input_ids, 1181 | attention_mask=attention_mask, 1182 | head_mask=head_mask, 1183 | inputs_embeds=inputs_embeds, 1184 | output_attentions=output_attentions, 1185 | output_hidden_states=output_hidden_states, 1186 | return_dict=return_dict, 1187 | ) 1188 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 1189 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 1190 | encoder_outputs = BaseModelOutput( 1191 | last_hidden_state=encoder_outputs[0], 1192 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 1193 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 1194 | ) 1195 | 1196 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 1197 | decoder_outputs = self.decoder( 1198 | input_ids=decoder_input_ids, 1199 | attention_mask=decoder_attention_mask, 1200 | encoder_hidden_states=encoder_outputs[0], 1201 | encoder_attention_mask=attention_mask, 1202 | head_mask=decoder_head_mask, 1203 | cross_attn_head_mask=cross_attn_head_mask, 1204 | past_key_values=past_key_values, 1205 | inputs_embeds=decoder_inputs_embeds, 1206 | use_cache=use_cache, 1207 | output_attentions=output_attentions, 1208 | output_hidden_states=output_hidden_states, 1209 | return_dict=return_dict, 1210 | ) 1211 | 1212 | if not return_dict: 1213 | return decoder_outputs + encoder_outputs 1214 | 1215 | return Seq2SeqModelOutput( 1216 | last_hidden_state=decoder_outputs.last_hidden_state, 1217 | past_key_values=decoder_outputs.past_key_values, 1218 | decoder_hidden_states=decoder_outputs.hidden_states, 1219 | decoder_attentions=decoder_outputs.attentions, 1220 | cross_attentions=decoder_outputs.cross_attentions, 1221 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 1222 | encoder_hidden_states=encoder_outputs.hidden_states, 1223 | encoder_attentions=encoder_outputs.attentions, 1224 | ) 1225 | 1226 | 1227 | class BartForConditionalGeneration(BartPretrainedModel): 1228 | base_model_prefix = "model" 1229 | _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"] 1230 | 1231 | def __init__(self, config: BartConfig): 1232 | super().__init__(config) 1233 | self.model = BartModel(config) 1234 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 1235 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 1236 | 1237 | # Initialize weights and apply final processing 1238 | self.post_init() 1239 | 1240 | def get_encoder(self): 1241 | return self.model.get_encoder() 1242 | 1243 | def get_decoder(self): 1244 | return self.model.get_decoder() 1245 | 1246 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding: 1247 | new_embeddings = super().resize_token_embeddings(new_num_tokens) 1248 | 
self._resize_final_logits_bias(new_num_tokens) 1249 | return new_embeddings 1250 | 1251 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 1252 | old_num_tokens = self.final_logits_bias.shape[-1] 1253 | if new_num_tokens <= old_num_tokens: 1254 | new_bias = self.final_logits_bias[:, :new_num_tokens] 1255 | else: 1256 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 1257 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 1258 | self.register_buffer("final_logits_bias", new_bias) 1259 | 1260 | def get_output_embeddings(self): 1261 | return self.lm_head 1262 | 1263 | def set_output_embeddings(self, new_embeddings): 1264 | self.lm_head = new_embeddings 1265 | 1266 | def forward( 1267 | self, 1268 | input_ids: torch.LongTensor = None, 1269 | attention_mask: Optional[torch.Tensor] = None, 1270 | decoder_input_ids: Optional[torch.LongTensor] = None, 1271 | decoder_attention_mask: Optional[torch.LongTensor] = None, 1272 | head_mask: Optional[torch.Tensor] = None, 1273 | decoder_head_mask: Optional[torch.Tensor] = None, 1274 | cross_attn_head_mask: Optional[torch.Tensor] = None, 1275 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 1276 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1277 | inputs_embeds: Optional[torch.FloatTensor] = None, 1278 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 1279 | labels: Optional[torch.LongTensor] = None, 1280 | use_cache: Optional[bool] = None, 1281 | output_attentions: Optional[bool] = None, 1282 | output_hidden_states: Optional[bool] = None, 1283 | return_dict: Optional[bool] = None, 1284 | ) -> Union[Tuple, Seq2SeqLMOutput]: 1285 | r""" 1286 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1287 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1288 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1289 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
1290 | Returns: 1291 | """ 1292 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1293 | 1294 | if labels is not None: 1295 | if use_cache: 1296 | print("The `use_cache` argument is changed to `False` since `labels` is provided.") 1297 | use_cache = False 1298 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1299 | decoder_input_ids = shift_tokens_right( 1300 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 1301 | ) 1302 | 1303 | outputs = self.model( 1304 | input_ids, 1305 | attention_mask=attention_mask, 1306 | decoder_input_ids=decoder_input_ids, 1307 | encoder_outputs=encoder_outputs, 1308 | decoder_attention_mask=decoder_attention_mask, 1309 | head_mask=head_mask, 1310 | decoder_head_mask=decoder_head_mask, 1311 | cross_attn_head_mask=cross_attn_head_mask, 1312 | past_key_values=past_key_values, 1313 | inputs_embeds=inputs_embeds, 1314 | decoder_inputs_embeds=decoder_inputs_embeds, 1315 | use_cache=use_cache, 1316 | output_attentions=output_attentions, 1317 | output_hidden_states=output_hidden_states, 1318 | return_dict=return_dict, 1319 | ) 1320 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias 1321 | 1322 | masked_lm_loss = None 1323 | if labels is not None: 1324 | loss_fct = CrossEntropyLoss() 1325 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 1326 | 1327 | if not return_dict: 1328 | output = (lm_logits,) + outputs[1:] 1329 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1330 | 1331 | return Seq2SeqLMOutput( 1332 | loss=masked_lm_loss, 1333 | logits=lm_logits, 1334 | past_key_values=outputs.past_key_values, 1335 | decoder_hidden_states=outputs.decoder_hidden_states, 1336 | decoder_attentions=outputs.decoder_attentions, 1337 | cross_attentions=outputs.cross_attentions, 1338 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1339 | encoder_hidden_states=outputs.encoder_hidden_states, 1340 | encoder_attentions=outputs.encoder_attentions, 1341 | ) 1342 | 1343 | def prepare_inputs_for_generation( 1344 | self, 1345 | decoder_input_ids, 1346 | past=None, 1347 | attention_mask=None, 1348 | head_mask=None, 1349 | decoder_head_mask=None, 1350 | cross_attn_head_mask=None, 1351 | use_cache=None, 1352 | encoder_outputs=None, 1353 | **kwargs 1354 | ): 1355 | # cut decoder_input_ids if past is used 1356 | if past is not None: 1357 | decoder_input_ids = decoder_input_ids[:, -1:] 1358 | 1359 | return { 1360 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 1361 | "encoder_outputs": encoder_outputs, 1362 | "past_key_values": past, 1363 | "decoder_input_ids": decoder_input_ids, 1364 | "attention_mask": attention_mask, 1365 | "head_mask": head_mask, 1366 | "decoder_head_mask": decoder_head_mask, 1367 | "cross_attn_head_mask": cross_attn_head_mask, 1368 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 1369 | } 1370 | 1371 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 1372 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 1373 | 1374 | @staticmethod 1375 | def _reorder_cache(past, beam_idx): 1376 | reordered_past = () 1377 | for layer_past in past: 1378 | # cached cross_attention states don't have to be reordered -> they are always the same 1379 | reordered_past += ( 1380 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], 1381 | ) 1382 | return reordered_past 1383 | 1384 | 1385 | class BartForSequenceClassification(BartPretrainedModel): 1386 | def __init__(self, config: BartConfig, **kwargs): 1387 | super().__init__(config, **kwargs) 1388 | self.model = BartModel(config) 1389 | self.classification_head = BartClassificationHead( 1390 | config.d_model, 1391 | config.d_model, 1392 | config.num_labels, 1393 | config.classifier_dropout, 1394 | ) 1395 | self.model._init_weights(self.classification_head.dense) 1396 | self.model._init_weights(self.classification_head.out_proj) 1397 | 1398 | def forward( 1399 | self, 1400 | input_ids: torch.LongTensor = None, 1401 | attention_mask: Optional[torch.Tensor] = None, 1402 | decoder_input_ids: Optional[torch.LongTensor] = None, 1403 | decoder_attention_mask: Optional[torch.LongTensor] = None, 1404 | head_mask: Optional[torch.Tensor] = None, 1405 | decoder_head_mask: Optional[torch.Tensor] = None, 1406 | cross_attn_head_mask: Optional[torch.Tensor] = None, 1407 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 1408 | inputs_embeds: Optional[torch.FloatTensor] = None, 1409 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 1410 | labels: Optional[torch.LongTensor] = None, 1411 | use_cache: Optional[bool] = None, 1412 | output_attentions: Optional[bool] = None, 1413 | output_hidden_states: Optional[bool] = None, 1414 | return_dict: Optional[bool] = None, 1415 | ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: 1416 | r""" 1417 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1418 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 1419 | config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
1420 | """ 1421 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1422 | if labels is not None: 1423 | use_cache = False 1424 | 1425 | if input_ids is None and inputs_embeds is not None: 1426 | raise NotImplementedError( 1427 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}" 1428 | ) 1429 | 1430 | outputs = self.model( 1431 | input_ids, 1432 | attention_mask=attention_mask, 1433 | decoder_input_ids=decoder_input_ids, 1434 | decoder_attention_mask=decoder_attention_mask, 1435 | head_mask=head_mask, 1436 | decoder_head_mask=decoder_head_mask, 1437 | cross_attn_head_mask=cross_attn_head_mask, 1438 | encoder_outputs=encoder_outputs, 1439 | inputs_embeds=inputs_embeds, 1440 | decoder_inputs_embeds=decoder_inputs_embeds, 1441 | use_cache=use_cache, 1442 | output_attentions=output_attentions, 1443 | output_hidden_states=output_hidden_states, 1444 | return_dict=return_dict, 1445 | ) 1446 | hidden_states = outputs[0] # last hidden state 1447 | 1448 | eos_mask = input_ids.eq(self.config.eos_token_id) 1449 | 1450 | if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: 1451 | raise ValueError("All examples must have the same number of tokens.") 1452 | sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ 1453 | :, -1, : 1454 | ] 1455 | logits = self.classification_head(sentence_representation) 1456 | 1457 | loss = None 1458 | if labels is not None: 1459 | if self.config.problem_type is None: 1460 | if self.config.num_labels == 1: 1461 | self.config.problem_type = "regression" 1462 | elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1463 | self.config.problem_type = "single_label_classification" 1464 | else: 1465 | self.config.problem_type = "multi_label_classification" 1466 | 1467 | if self.config.problem_type == "regression": 1468 | loss_fct = MSELoss() 1469 | if self.config.num_labels == 1: 1470 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 1471 | else: 1472 | loss = loss_fct(logits, labels) 1473 | elif self.config.problem_type == "single_label_classification": 1474 | loss_fct = CrossEntropyLoss() 1475 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) 1476 | elif self.config.problem_type == "multi_label_classification": 1477 | loss_fct = BCEWithLogitsLoss() 1478 | loss = loss_fct(logits, labels) 1479 | if not return_dict: 1480 | output = (logits,) + outputs[1:] 1481 | return ((loss,) + output) if loss is not None else output 1482 | 1483 | return Seq2SeqSequenceClassifierOutput( 1484 | loss=loss, 1485 | logits=logits, 1486 | past_key_values=outputs.past_key_values, 1487 | decoder_hidden_states=outputs.decoder_hidden_states, 1488 | decoder_attentions=outputs.decoder_attentions, 1489 | cross_attentions=outputs.cross_attentions, 1490 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1491 | encoder_hidden_states=outputs.encoder_hidden_states, 1492 | encoder_attentions=outputs.encoder_attentions, 1493 | ) 1494 | 1495 | 1496 | class BartForQuestionAnswering(BartPretrainedModel): 1497 | def __init__(self, config): 1498 | super().__init__(config) 1499 | 1500 | config.num_labels = 2 1501 | self.num_labels = config.num_labels 1502 | 1503 | self.model = BartModel(config) 1504 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) 1505 | 1506 | self.model._init_weights(self.qa_outputs) 1507 | 1508 | def forward( 1509 | self, 1510 | input_ids: torch.Tensor = None, 1511 | 
attention_mask: Optional[torch.Tensor] = None, 1512 | decoder_input_ids: Optional[torch.LongTensor] = None, 1513 | decoder_attention_mask: Optional[torch.LongTensor] = None, 1514 | head_mask: Optional[torch.Tensor] = None, 1515 | decoder_head_mask: Optional[torch.Tensor] = None, 1516 | cross_attn_head_mask: Optional[torch.Tensor] = None, 1517 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 1518 | start_positions: Optional[torch.LongTensor] = None, 1519 | end_positions: Optional[torch.LongTensor] = None, 1520 | inputs_embeds: Optional[torch.FloatTensor] = None, 1521 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 1522 | use_cache: Optional[bool] = None, 1523 | output_attentions: Optional[bool] = None, 1524 | output_hidden_states: Optional[bool] = None, 1525 | return_dict: Optional[bool] = None, 1526 | ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]: 1527 | r""" 1528 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1529 | Labels for position (index) of the start of the labelled span for computing the token classification loss. 1530 | Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence 1531 | are not taken into account for computing the loss. 1532 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 1533 | Labels for position (index) of the end of the labelled span for computing the token classification loss. 1534 | Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence 1535 | are not taken into account for computing the loss. 1536 | """ 1537 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1538 | if start_positions is not None and end_positions is not None: 1539 | use_cache = False 1540 | 1541 | outputs = self.model( 1542 | input_ids, 1543 | attention_mask=attention_mask, 1544 | decoder_input_ids=decoder_input_ids, 1545 | decoder_attention_mask=decoder_attention_mask, 1546 | head_mask=head_mask, 1547 | decoder_head_mask=decoder_head_mask, 1548 | cross_attn_head_mask=cross_attn_head_mask, 1549 | encoder_outputs=encoder_outputs, 1550 | inputs_embeds=inputs_embeds, 1551 | decoder_inputs_embeds=decoder_inputs_embeds, 1552 | use_cache=use_cache, 1553 | output_attentions=output_attentions, 1554 | output_hidden_states=output_hidden_states, 1555 | return_dict=return_dict, 1556 | ) 1557 | 1558 | sequence_output = outputs[0] 1559 | 1560 | logits = self.qa_outputs(sequence_output) 1561 | start_logits, end_logits = logits.split(1, dim=-1) 1562 | start_logits = start_logits.squeeze(-1).contiguous() 1563 | end_logits = end_logits.squeeze(-1).contiguous() 1564 | 1565 | total_loss = None 1566 | if start_positions is not None and end_positions is not None: 1567 | # If we are on multi-GPU, split add a dimension 1568 | if len(start_positions.size()) > 1: 1569 | start_positions = start_positions.squeeze(-1) 1570 | if len(end_positions.size()) > 1: 1571 | end_positions = end_positions.squeeze(-1) 1572 | # sometimes the start/end positions are outside our model inputs, we ignore these terms 1573 | ignored_index = start_logits.size(1) 1574 | start_positions = start_positions.clamp(0, ignored_index) 1575 | end_positions = end_positions.clamp(0, ignored_index) 1576 | 1577 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index) 1578 | start_loss = loss_fct(start_logits, start_positions) 1579 | end_loss = loss_fct(end_logits, end_positions) 1580 | total_loss = (start_loss + end_loss) / 2 1581 
| 1582 | if not return_dict: 1583 | output = ( 1584 | start_logits, 1585 | end_logits, 1586 | ) + outputs[1:] 1587 | return ((total_loss,) + output) if total_loss is not None else output 1588 | 1589 | return Seq2SeqQuestionAnsweringModelOutput( 1590 | loss=total_loss, 1591 | start_logits=start_logits, 1592 | end_logits=end_logits, 1593 | past_key_values=outputs.past_key_values, 1594 | decoder_hidden_states=outputs.decoder_hidden_states, 1595 | decoder_attentions=outputs.decoder_attentions, 1596 | cross_attentions=outputs.cross_attentions, 1597 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 1598 | encoder_hidden_states=outputs.encoder_hidden_states, 1599 | encoder_attentions=outputs.encoder_attentions, 1600 | ) 1601 | 1602 | 1603 | class BartDecoderWrapper(BartPretrainedModel): 1604 | """ 1605 | This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is 1606 | used in combination with the [`EncoderDecoderModel`] framework. 1607 | """ 1608 | 1609 | def __init__(self, config): 1610 | super().__init__(config) 1611 | self.decoder = BartDecoder(config) 1612 | 1613 | def forward(self, *args, **kwargs): 1614 | return self.decoder(*args, **kwargs) 1615 | 1616 | 1617 | class BartForCausalLM(BartPretrainedModel): 1618 | def __init__(self, config): 1619 | config = copy.deepcopy(config) 1620 | config.is_decoder = True 1621 | config.is_encoder_decoder = False 1622 | super().__init__(config) 1623 | self.model = BartDecoderWrapper(config) 1624 | 1625 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 1626 | 1627 | # Initialize weights and apply final processing 1628 | self.post_init() 1629 | 1630 | def get_input_embeddings(self): 1631 | return self.model.decoder.embed_tokens 1632 | 1633 | def set_input_embeddings(self, value): 1634 | self.model.decoder.embed_tokens = value 1635 | 1636 | def get_output_embeddings(self): 1637 | return self.lm_head 1638 | 1639 | def set_output_embeddings(self, new_embeddings): 1640 | self.lm_head = new_embeddings 1641 | 1642 | def set_decoder(self, decoder): 1643 | self.model.decoder = decoder 1644 | 1645 | def get_decoder(self): 1646 | return self.model.decoder 1647 | 1648 | def forward( 1649 | self, 1650 | input_ids: torch.LongTensor = None, 1651 | attention_mask: Optional[torch.Tensor] = None, 1652 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 1653 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 1654 | head_mask: Optional[torch.Tensor] = None, 1655 | cross_attn_head_mask: Optional[torch.Tensor] = None, 1656 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1657 | inputs_embeds: Optional[torch.FloatTensor] = None, 1658 | labels: Optional[torch.LongTensor] = None, 1659 | use_cache: Optional[bool] = None, 1660 | output_attentions: Optional[bool] = None, 1661 | output_hidden_states: Optional[bool] = None, 1662 | return_dict: Optional[bool] = None, 1663 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: 1664 | r""" 1665 | Args: 1666 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1667 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 1668 | provide it. 1669 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1670 | [`PreTrainedTokenizer.__call__`] for details. 
1671 | [What are input IDs?](../glossary#input-ids) 1672 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 1673 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1674 | - 1 for tokens that are **not masked**, 1675 | - 0 for tokens that are **masked**. 1676 | [What are attention masks?](../glossary#attention-mask) 1677 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1678 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 1679 | if the model is configured as a decoder. 1680 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): 1681 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used 1682 | in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: 1683 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 1684 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 1685 | - 1 indicates the head is **not masked**, 1686 | - 0 indicates the head is **masked**. 1687 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 1688 | Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: 1689 | - 1 indicates the head is **not masked**, 1690 | - 0 indicates the head is **masked**. 1691 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 1692 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 1693 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 1694 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional 1695 | tensors are only required when the model is used as a decoder in a Sequence to Sequence model. 1696 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 1697 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 1698 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 1699 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 1700 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 1701 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1702 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1703 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1704 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1705 | use_cache (`bool`, *optional*): 1706 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 1707 | (see `past_key_values`). 1708 | - 1 for tokens that are **not masked**, 1709 | - 0 for tokens that are **masked**. 1710 | output_attentions (`bool`, *optional*): 1711 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1712 | returned tensors for more detail. 
1713 | output_hidden_states (`bool`, *optional*): 1714 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 1715 | for more detail. 1716 | return_dict (`bool`, *optional*): 1717 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 1718 | Returns: 1719 | Example: 1720 | ```python 1721 | >>> from transformers import BartTokenizer, BartForCausalLM 1722 | >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base") 1723 | >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False) 1724 | >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." 1725 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") 1726 | >>> outputs = model(**inputs) 1727 | >>> logits = outputs.logits 1728 | >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] 1729 | >>> list(logits.shape) == expected_shape 1730 | True 1731 | ```""" 1732 | 1733 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1734 | output_hidden_states = ( 1735 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1736 | ) 1737 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1738 | 1739 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1740 | outputs = self.model.decoder( 1741 | input_ids=input_ids, 1742 | attention_mask=attention_mask, 1743 | encoder_hidden_states=encoder_hidden_states, 1744 | encoder_attention_mask=encoder_attention_mask, 1745 | head_mask=head_mask, 1746 | cross_attn_head_mask=cross_attn_head_mask, 1747 | past_key_values=past_key_values, 1748 | inputs_embeds=inputs_embeds, 1749 | use_cache=use_cache, 1750 | output_attentions=output_attentions, 1751 | output_hidden_states=output_hidden_states, 1752 | return_dict=return_dict, 1753 | ) 1754 | 1755 | logits = self.lm_head(outputs[0]) 1756 | 1757 | loss = None 1758 | if labels is not None: 1759 | loss_fct = CrossEntropyLoss() 1760 | loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) 1761 | 1762 | if not return_dict: 1763 | output = (logits,) + outputs[1:] 1764 | return (loss,) + output if loss is not None else output 1765 | 1766 | return CausalLMOutputWithCrossAttentions( 1767 | loss=loss, 1768 | logits=logits, 1769 | past_key_values=outputs.past_key_values, 1770 | hidden_states=outputs.hidden_states, 1771 | attentions=outputs.attentions, 1772 | cross_attentions=outputs.cross_attentions, 1773 | ) 1774 | 1775 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs): 1776 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly 1777 | if attention_mask is None: 1778 | attention_mask = input_ids.new_ones(input_ids.shape) 1779 | 1780 | if past: 1781 | input_ids = input_ids[:, -1:] 1782 | # first step, decoder_cached_states are empty 1783 | return { 1784 | "input_ids": input_ids, # encoder_outputs is defined. 
input_ids not needed 1785 | "attention_mask": attention_mask, 1786 | "past_key_values": past, 1787 | "use_cache": use_cache, 1788 | } 1789 | 1790 | @staticmethod 1791 | def _reorder_cache(past, beam_idx): 1792 | reordered_past = () 1793 | for layer_past in past: 1794 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 1795 | return reordered_past --------------------------------------------------------------------------------
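The `BartForConditionalGeneration` class defined above follows the standard Hugging Face seq2seq interface: when `labels` are passed, `forward()` builds `decoder_input_ids` via `shift_tokens_right()` and returns a cross-entropy loss next to the LM logits, while `generate()` routes its decoding steps through `prepare_inputs_for_generation()`. The following is a minimal usage sketch only, assuming the class is importable from this module (the import path, checkpoint choice, and input strings below are illustrative, not taken from the code above).

```python
# Minimal usage sketch (assumptions: the class above is importable as shown,
# and the stock "facebook/bart-base" checkpoint loads into it).
import torch
from transformers import BartTokenizer

from qfid.qfid import BartForConditionalGeneration  # hypothetical import path

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.eval()

source = "Abstracts of the cited papers to be summarized into one chapter."
target = "A short chapter summarizing the cited papers."

inputs = tokenizer(source, return_tensors="pt", truncation=True)
labels = tokenizer(target, return_tensors="pt", truncation=True).input_ids

# Passing `labels` makes forward() shift them right into decoder_input_ids
# and return a cross-entropy loss alongside the LM logits.
with torch.no_grad():
    outputs = model(**inputs, labels=labels)
print(float(outputs.loss), tuple(outputs.logits.shape))

# Beam search goes through prepare_inputs_for_generation() defined above.
summary_ids = model.generate(inputs.input_ids, num_beams=4, max_length=64)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
```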