├── .gitignore
├── filtered_dict.pkl
├── LICENSE.md
├── qfid
│   ├── __pycache__
│   │   ├── qfid.cpython-38.pyc
│   │   └── qfid.cpython-39.pyc
│   ├── test.sh
│   ├── train.sh
│   ├── run_summarization.py
│   └── qfid.py
├── requirements.txt
├── make_summarization_csv.py
├── json_to_df.py
├── README.md
└── make_section_df.py
/.gitignore:
--------------------------------------------------------------------------------
1 | dataset
2 |
--------------------------------------------------------------------------------
/filtered_dict.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetsu9923/SciReviewGen/HEAD/filtered_dict.pkl
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | SciReviewGen is licensed under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/).
2 |
--------------------------------------------------------------------------------
/qfid/__pycache__/qfid.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetsu9923/SciReviewGen/HEAD/qfid/__pycache__/qfid.cpython-38.pyc
--------------------------------------------------------------------------------
/qfid/__pycache__/qfid.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetsu9923/SciReviewGen/HEAD/qfid/__pycache__/qfid.cpython-39.pyc
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.16.0
2 | datasets==1.18.4
3 | numpy==1.22.3
4 | pandas==1.4.1
5 | torch==1.11.0
6 | filelock
7 | nltk
8 | tqdm
9 | wrapt
10 | h5py
11 | rouge_score
--------------------------------------------------------------------------------
/qfid/test.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | python run_summarization.py \
3 | --model_name_or_path facebook/bart-large-cnn \
4 | --do_train \
5 | --do_predict \
6 | --train_file ../dataset/train_qfid.csv \
7 | --validation_file ../dataset/val_qfid.csv \
8 | --test_file ../dataset/test_qfid.csv \
9 | --text_column reference \
10 | --summary_column target \
11 | --output_dir ./test \
12 | --per_device_train_batch_size=1 \
13 | --per_device_eval_batch_size=1 \
14 | --num_train_epochs=1 \
15 | --predict_with_generate \
16 | --save_strategy epoch \
17 | --learning_rate 5e-05 \
18 | --evaluation_strategy epoch \
19 | --load_best_model_at_end \
20 | --metric_for_best_model=eval_rouge2 \
21 | --greater_is_better True \
22 | --save_total_limit=1 \
23 | --max_source_length 2048 \
24 | --max_target_length 512 \
25 | --generation_max_length 256 \
26 | --num_beams 4
27 |
--------------------------------------------------------------------------------
/qfid/train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | python run_summarization.py \
3 | --model_name_or_path facebook/bart-large-cnn \
4 | --do_train \
5 | --do_eval \
6 | --do_predict \
7 | --train_file ../dataset/train_qfid.csv \
8 | --validation_file ../dataset/val_qfid.csv \
9 | --test_file ../dataset/test_qfid.csv \
10 | --text_column reference \
11 | --summary_column target \
12 | --output_dir ./test \
13 | --per_device_train_batch_size=1 \
14 | --per_device_eval_batch_size=1 \
15 | --num_train_epochs=1 \
16 | --predict_with_generate \
17 | --save_strategy epoch \
18 | --learning_rate 5e-05 \
19 | --evaluation_strategy epoch \
20 | --load_best_model_at_end \
21 | --metric_for_best_model=eval_rouge2 \
22 | --greater_is_better True \
23 | --save_total_limit=1 \
24 | --max_source_length 2048 \
25 | --max_target_length 512 \
26 | --generation_max_length 256 \
27 | --num_beams 4 \
28 | --overwrite_output_dir
29 |
--------------------------------------------------------------------------------
/make_summarization_csv.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import logging
4 | import pickle
5 | import argparse
6 |
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from tqdm import tqdm
11 |
12 |
13 | def make_summarization_csv(args):
14 | if args.for_qfid:
15 | logging.info('Making csv files for QFiD...')
16 | logging.info('Columns={"reference": literature review title chapter title literature review title chapter title abstract of cited paper 1 BIB001 literature review title chapter title abstract of cited paper 2 BIB002 ..., "target": literature review chapter}')
17 | else:
18 | logging.info('Making csv files for summarization...')
19 | logging.info('Columns={"reference": literature review title chapter title abstract of cited paper 1 BIB001 literature review title chapter title abstract of cited paper 2 BIB002 ..., "target": literature review chapter}')
20 | section_df = pd.read_pickle(os.path.join(args.dataset_path, 'split_survey_df.pkl'))
21 |
22 | dataset_df = section_df[section_df['n_bibs'].apply(lambda n_bibs: n_bibs >= 2)]
23 |
24 | dataset_df = dataset_df.rename(columns={'text': 'target'})
25 | dataset_df = dataset_df.rename(columns={'bib_cinting_sentences': 'bib_citing_sentences'})
26 |
27 | dataset_df['reference'] = dataset_df[['bib_abstracts', 'section', 'title']].apply(lambda bib_abstracts: ' '.join([' {} {} {} BIB{}'.format(bib_abstracts[2], bib_abstracts[1], abstract, bib) for bib, abstract in bib_abstracts[0].items()]), axis=1)
28 | if args.for_qfid:
29 | dataset_df['reference'] = dataset_df['title'] + ' ' + dataset_df['section'] + ' ' + dataset_df['reference']
30 | else:
31 | dataset_df['reference'] = dataset_df['reference'].apply(lambda s: s[5:])
32 |
33 | split_df = dataset_df['split']
34 | dataset_df = dataset_df[['reference', 'target']]
35 |
36 | train_df = dataset_df[split_df == 'train']
37 | val_df = dataset_df[split_df == 'val']
38 | test_df = dataset_df[split_df == 'test']
39 |
40 | if args.for_qfid:
41 | train_df.to_csv(os.path.join(args.dataset_path, 'train_qfid.csv'), index=False)
42 | val_df.to_csv(os.path.join(args.dataset_path, 'val_qfid.csv'), index=False)
43 | test_df.to_csv(os.path.join(args.dataset_path, 'test_qfid.csv'), index=False)
44 | else:
45 | train_df.to_csv(os.path.join(args.dataset_path, 'train.csv'), index=False)
46 | val_df.to_csv(os.path.join(args.dataset_path, 'val.csv'), index=False)
47 | test_df.to_csv(os.path.join(args.dataset_path, 'test.csv'), index=False)
48 | logging.info('Done!')
49 |
50 |
51 | def anonymize_bib(args):
52 | logging.info('Converting BIB identifiers...')
53 | for split in ['val', 'test', 'train']:
54 | if args.for_qfid:
55 | df = pd.read_csv(os.path.join(args.dataset_path, '{}_qfid.csv'.format(split)))
56 | else:
57 | df = pd.read_csv(os.path.join(args.dataset_path, '{}.csv'.format(split)))
58 | bar = tqdm(total=len(df))
59 | for row in df.itertuples():
60 | cnt = 1
61 | bib_dict = {}
62 | for i in range(len(row.reference)):
63 | if row.reference[i:i+7] == ' BIB':
64 | bib_dict[row.reference[i+7:].split(' ')[0]] = cnt
65 | cnt += 1
66 | ref = row.reference
67 | tgt = row.target
68 | for key, value in bib_dict.items():
69 | ref = re.sub('BIB{}'.format(key), 'BIB{:0>3}'.format(value), ref)
70 | tgt = re.sub('BIB{}'.format(key), 'BIB{:0>3}'.format(value), tgt)
71 | df.at[row.Index, 'reference'] = ref
72 | df.at[row.Index, 'target'] = tgt
73 | bar.update(1)
74 | logging.info('Saving...')
75 | if args.for_qfid:
76 | df.to_csv(os.path.join(args.dataset_path, '{}_qfid.csv'.format(split)), index=False)
77 | else:
78 | df.to_csv(os.path.join(args.dataset_path, '{}.csv'.format(split)), index=False)
79 |
80 |
81 | if __name__ == '__main__':
82 | logging.basicConfig(format='%(message)s', level=logging.DEBUG)
83 |
84 | parser = argparse.ArgumentParser(description='')
85 | parser.add_argument('-dataset_path', help='Path to the generated dataset')
86 | parser.add_argument('--for_qfid', action='store_true', help='Add if you train QFiD on the generated csv files')
87 | args = parser.parse_args()
88 |
89 | make_summarization_csv(args) # Convert split_survey_df into csv files suitable for summarization
90 | anonymize_bib(args) # Converting BIB{paper_id} into BIB{001, 002, ...}
91 |
--------------------------------------------------------------------------------
/json_to_df.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import pickle
4 | import argparse
5 |
6 | import pandas as pd
7 |
8 | from tqdm import tqdm
9 |
10 |
11 | def save_metadata_survey_df(args):
12 | logging.info('Loading metadata of literature reviews...')
13 | keywords = ['survey', 'overview', 'literature review', 'a review']
14 |
15 | metadata_survey_dfs = []
16 | for i in tqdm(range(100)):
17 | metadata_df = pd.read_json(os.path.join(args.s2orc_path, 'metadata/metadata_{}.jsonl.gz'.format(i)), lines=True, compression='infer')
18 | metadata_survey_df = metadata_df[
19 | (metadata_df.mag_field_of_study.apply(lambda field: args.field in field if field is not None else False))
20 | & (metadata_df.title.apply(lambda title: any([word in title.lower() for word in keywords])))
21 | & metadata_df.has_outbound_citations
22 | & metadata_df.has_pdf_body_text
23 | & ~metadata_df.abstract.isna()
24 | ]
25 | metadata_survey_dfs.append(metadata_survey_df)
26 | metadata_survey_df = pd.concat(metadata_survey_dfs)
27 | metadata_survey_df = metadata_survey_df.set_index('paper_id')
28 | metadata_survey_df.index = metadata_survey_df.index.astype('str')
29 | metadata_survey_df.to_pickle(os.path.join(args.dataset_path, 'metadata_survey_df.pkl'))
30 | logging.info('Done!')
31 |
32 | def save_metadata_outbound_df(args):
33 | logging.info('Loading metadata of cited papers...')
34 | metadata_survey_df = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_survey_df.pkl'))
35 | outbound_paper_ids = set([paper_id for paper_ids in metadata_survey_df.outbound_citations.values for paper_id in paper_ids])
36 |
37 | metadata_outbound_dfs = []
38 | for i in tqdm(range(100)):
39 | metadata_df = pd.read_json(os.path.join(args.s2orc_path, 'metadata/metadata_{}.jsonl.gz'.format(i)), lines=True, compression='infer')
40 | metadata_df = metadata_df.set_index('paper_id')
41 | metadata_df.index = metadata_df.index.astype('str')
42 | cs_survey_outbound_paper_ids = list(outbound_paper_ids & set(metadata_df.index))
43 | metadata_outbound_df = metadata_df.loc[cs_survey_outbound_paper_ids]
44 | metadata_outbound_dfs.append(metadata_outbound_df)
45 | metadata_outbound_df = pd.concat(metadata_outbound_dfs)
46 | metadata_outbound_df.to_pickle(os.path.join(args.dataset_path, 'metadata_outbound_df.pkl'))
47 | logging.info('Done!')
48 |
49 | def save_pdf_df(args):
50 | logging.info('Loading pdf parses of literature review papers and cited papers...')
51 | metadata_survey_df = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_survey_df.pkl'))
52 | survey_index = metadata_survey_df.index
53 | outbound_index = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_outbound_df.pkl')).index
54 |
55 | pdf_survey_dfs = []
56 | pdf_outbound_dfs = []
57 | for i in tqdm(range(100)):
58 | pdf_df = pd.read_json(os.path.join(args.s2orc_path, 'pdf_parses/pdf_parses_{}.jsonl.gz'.format(i)), lines=True, compression='infer')
59 | pdf_df = pdf_df.set_index('paper_id')
60 | pdf_df.index = pdf_df.index.astype('str')
61 |
62 | pdf_survey_paper_ids = list(set(survey_index) & set(pdf_df.index))
63 | pdf_survey_df = pdf_df.loc[pdf_survey_paper_ids]
64 | pdf_survey_df = pdf_survey_df[['body_text', 'bib_entries']]
65 |
66 | pdf_survey_df['title'] = ''
67 | pdf_survey_df['abstract'] = ''
68 | for j, row in enumerate(pdf_survey_df.itertuples()):
69 | pdf_survey_df.at[pdf_survey_df.index[j], 'title'] = metadata_survey_df.query('paper_id == @row.Index')['title'].item()
70 | pdf_survey_df.at[pdf_survey_df.index[j], 'abstract'] = metadata_survey_df.query('paper_id == @row.Index')['abstract'].item()
71 | pdf_survey_dfs.append(pdf_survey_df)
72 |
73 | pdf_outbound_paper_ids = list(set(outbound_index) & set(pdf_df.index))
74 | pdf_outbound_df = pdf_df.loc[pdf_outbound_paper_ids]
75 | pdf_outbound_df = pdf_outbound_df[pdf_outbound_df.body_text.apply(lambda text: len(text) > 0)]
76 | pdf_outbound_df = pdf_outbound_df[['body_text', 'bib_entries']]
77 | pdf_outbound_dfs.append(pdf_outbound_df)
78 |
79 | pdf_survey_df = pd.concat(pdf_survey_dfs)
80 | pdf_survey_df.to_pickle(os.path.join(args.dataset_path, 'pdf_survey_df.pkl'))
81 | pdf_outbound_df = pd.concat(pdf_outbound_dfs)
82 | pdf_outbound_df.to_pickle(os.path.join(args.dataset_path, 'pdf_outbound_df.pkl'))
83 | logging.info('Done!')
84 |
85 |
86 | if __name__ == '__main__':
87 | logging.basicConfig(format='%(message)s', level=logging.DEBUG)
88 |
89 | parser = argparse.ArgumentParser(description='')
90 | parser.add_argument('-s2orc_path', help='Path to the S2ORC full dataset directory (Typically ".../s2orc/full/20200705v1/full")')
91 | parser.add_argument('-dataset_path', help='Path to the generated dataset')
92 | parser.add_argument('--field', default='Computer Science', help='Field of literature reviews')
93 | args = parser.parse_args()
94 |
95 | save_metadata_survey_df(args) # collect metadata of the literature reviews
96 | save_metadata_outbound_df(args) # collect metadata of the cited papers
97 | save_pdf_df(args) # collect pdf parses of the literature reviews and the cited papers
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SciReviewGen
2 | **This is the official dataset repository for [SciReviewGen: A Large-scale Dataset for Automatic Literature Review Generation](https://arxiv.org/pdf/2305.15186.pdf), published in Findings of ACL 2023.**
3 |
4 | ## Dataset
5 | - [split_survey_df](https://drive.google.com/file/d/1S6v-xaCDND4ilK38sEpkfcOoMnffX7Zf/view?usp=sharing): The split version of SciReviewGen, which aims to generate literature review **chapters**
6 | - [original_survey_df](https://drive.google.com/file/d/1MnjQ2fQ_fJjcqKvIwj2w7P6IGh4GszXH/view?usp=sharing): The original version of SciReviewGen, which aims to generate **the entire text** of literature reviews
7 | - [summarization_csv](https://drive.google.com/file/d/1okvILkxfrpTQYWLxbV4lM9BQnuVaAfbY/view?usp=sharing): CSV files suitable for the summarization task. They can be used directly with [HuggingFace's official example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/summarization#custom-csv-files)
8 |
9 | ### Data format
10 | #### split_survey_df & original_survey_df
11 | - Row:
12 | - a literature review chapter (split version) or the entire text of a literature review (original version)
13 | - Column:
14 | - paper_id: paper_id used in [S2ORC](https://github.com/allenai/s2orc)
15 | - title: title of the literature review
16 | - abstract: abstract of the literature review
17 | - section: chapter title
18 | - text: body text of literature review chapter or literature review paper
19 | - n_bibs: number of the cited papers that can be used as inputs
20 | - n_nonbibs: number of the cited papers that cannot be used as inputs
21 | - bib_titles: titles of the cited papers
22 | - bib_abstracts: abstracts of the cited papers
23 | - bib_citing_sentences: citing sentences that cite the cited papers
24 | - split: train/val/test split
25 |
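Both dataframes are pickled pandas dataframes, so they can be inspected directly. A minimal loading sketch, assuming `split_survey_df.pkl` has been downloaded to the working directory:
```
import pandas as pd

# Each row is a literature review chapter (split version)
df = pd.read_pickle("split_survey_df.pkl")

print(df.columns.tolist())          # the columns listed above
print(df["split"].value_counts())   # train/val/test sizes

# bib_titles, bib_abstracts, and bib_citing_sentences are dicts keyed by S2ORC paper_id
row = df.iloc[0]
print(row["section"])
print(list(row["bib_abstracts"].keys())[:3])
```
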
26 | #### summarization_csv
27 | - Row:
28 | - literature review chapter
29 | - Column:
30 | - reference: `literature review title chapter title abstract of cited paper 1 BIB001 literature review title chapter title abstract of cited paper 2 BIB002 ...`
31 | - target: literature review chapter
32 |
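These csv files can be loaded with pandas or, as in `qfid/run_summarization.py`, with the `datasets` library. A minimal sketch (file names as produced by `make_summarization_csv.py`; adjust the paths to your own *dataset_path*):
```
from datasets import load_dataset

# train/val/test csv files generated by make_summarization_csv.py (paths are examples)
data_files = {
    "train": "dataset/train.csv",
    "validation": "dataset/val.csv",
    "test": "dataset/test.csv",
}
raw_datasets = load_dataset("csv", data_files=data_files)

example = raw_datasets["train"][0]
print(example["reference"][:200])  # review title, chapter title, and cited-paper abstracts
print(example["target"][:200])     # the literature review chapter
```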
33 |
34 | ## How to create SciReviewGen from S2ORC
35 | ### 0. Environment
36 | - Python 3.9
37 | - Run the following command to clone the repository and install the required packages
38 | ```
39 | git clone https://github.com/tetsu9923/SciReviewGen.git
40 | cd SciReviewGen
41 | pip install -r requirements.txt
42 | ```
43 |
44 | ### 1. Preprocessing
45 | - Download [S2ORC](https://github.com/allenai/s2orc) (We use the version released on **2020-07-05**, which contains papers up until 2020-04-14)
46 | - Run the following command:
47 | ```
48 | python json_to_df.py \
49 | -s2orc_path <path to the S2ORC full dataset directory> \
50 | -dataset_path <path to the generated dataset> \
51 | --field <field of the literature reviews (default: "Computer Science")>
52 | ```
53 | The metadata and pdf parses of the candidate literature reviews and their cited papers are stored in *dataset_path* (as pandas dataframes).
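
As a quick sanity check, the intermediate dataframes written by `json_to_df.py` can be loaded back with pandas. A minimal sketch, assuming *dataset_path* is `./dataset`:
```
import pandas as pd

# Intermediate outputs of json_to_df.py (file names come from the script)
metadata_survey_df = pd.read_pickle("dataset/metadata_survey_df.pkl")
pdf_survey_df = pd.read_pickle("dataset/pdf_survey_df.pkl")

print(len(metadata_survey_df), "candidate literature reviews")
print(len(pdf_survey_df), "candidate reviews with parsed pdf bodies")
```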
54 |
55 | ### 2. Construct SciReviewGen
56 | - Run the following command:
57 | ```
58 | python make_section_df.py \
59 | -dataset_path <path to the generated dataset> \
60 | --version <"split" or "original" (default: "split")>
61 | ```
62 | The SciReviewGen dataset (**split_survey_df.pkl** or **original_survey_df.pkl**) is stored in *dataset_path* (as a pandas dataframe).
63 | `filtered_dict.pkl` gives the list of literature reviews retained after filtering by the [SciBERT](https://arxiv.org/abs/1903.10676)-based classifier (Section 3.2 of the paper).
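
As used in `make_section_df.py`, `filtered_dict.pkl` is a pickled dict that maps the paper_id of each retained literature review to its train/val/test split. A minimal inspection sketch, run from the repository root (where `filtered_dict.pkl` lives):
```
import pickle

with open("filtered_dict.pkl", "rb") as f:
    filtered_dict = pickle.load(f)

print(len(filtered_dict), "literature reviews kept after filtering")
print(list(filtered_dict.items())[:3])  # [(paper_id, split), ...]
```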
64 |
65 | ### 3. Construct csv data for summarization
66 | - Run the following command:
67 | ```
68 | python make_summarization_csv.py \
69 | -dataset_path <path to the generated dataset>
70 | ```
71 | The csv files for summarization (**train.csv**, **val.csv**, and **test.csv**) are stored in *dataset_path*.
72 | If you plan to train QFiD on the generated csv files, add the `--for_qfid` argument as below:
73 | ```
74 | python make_summarization_csv.py \
75 | -dataset_path <path to the generated dataset> \
76 | --for_qfid
77 | ```
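
The only difference from the plain csv files is that the `reference` column is additionally prefixed with the literature review title and the chapter title (see `make_summarization_csv.py`). A quick comparison sketch, assuming both variants were generated into `./dataset`:
```
import pandas as pd

plain = pd.read_csv("dataset/train.csv")
qfid = pd.read_csv("dataset/train_qfid.csv")

# Same columns; the QFiD reference is prefixed with the review title and chapter title
print(plain.loc[0, "reference"][:120])
print(qfid.loc[0, "reference"][:120])
```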
78 |
79 |
80 | ## Additional resources
81 | ### SciBERT-based literature review classifier
82 | We trained a [SciBERT](https://arxiv.org/abs/1903.10676)-based literature review classifier, which is used for the filtering described in Section 3.2 of the paper.
83 | The model weights are available [here](https://drive.google.com/file/d/1cPGJpvCFQkHX2td99YyFitBirG-eCcLC/view?usp=sharing).
84 |
85 | ### Query-weighted Fusion-in-Decoder (QFiD)
86 | We propose Query-weighted Fusion-in-Decoder (QFiD), which explicitly considers the relevance of each input document to the queries.
87 | You can train QFiD on the SciReviewGen csv data (**make sure you passed the** `--for_qfid` **argument when running** `make_summarization_csv.py`).
88 | #### Train
89 | - Modify `qfid/train.sh` (CUDA_VISIBLE_DEVICES, the csv file paths, output_dir, and num_train_epochs)
90 | - Run the following command:
91 | ```
92 | cd qfid
93 | ./train.sh
94 | ```
95 | #### Test
96 | - Modify `qfid/test.sh` (CUDA_VISIBLE_DEVICES, the csv file paths, output_dir, and num_train_epochs. **Set *num_train_epochs* to the total number of epochs you have already trained**, so that training resumes from the existing checkpoint and proceeds directly to prediction)
97 | - Run the following command:
98 | ```
99 | ./test.sh
100 | ```
101 |
102 | ## Licenses
103 | - SciReviewGen is released under [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/). **SciReviewGen may be used for non-commercial purposes only.**
104 | - SciReviewGen is built on [S2ORC](https://github.com/allenai/s2orc). Note that S2ORC is released under CC BY-NC 4.0, which allows copying and redistribution for non-commercial purposes only.
105 |
--------------------------------------------------------------------------------
/make_section_df.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import pickle
4 | import argparse
5 |
6 | import nltk.data
7 | import numpy as np
8 | import pandas as pd
9 |
10 | from tqdm import tqdm
11 |
12 |
13 | def append_citing_sentence(args):
14 | extra_abbreviations = ['dr', 'vs', 'mr', 'mrs', 'prof', 'inc', 'i.e', 'al']
15 | tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
16 | tokenizer._params.abbrev_types.update(extra_abbreviations)
17 |
18 | logging.info('Loading citing sentences...')
19 | cs_survey_ids = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_survey_df.pkl')).index
20 | pdf_outbound_df = pd.read_pickle(os.path.join(args.dataset_path, 'pdf_outbound_df.pkl'))
21 | metadata_outbound_df = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_outbound_df.pkl'))
22 | metadata_outbound_df = metadata_outbound_df.dropna(subset=['abstract'])
23 |
24 | citing_df = pdf_outbound_df.copy()
25 | citing_df['bib_entries'] = citing_df['bib_entries'].apply(
26 | lambda bib_entries: {key: value['link'] for key, value in bib_entries.items()}
27 | )
28 | citing_df['bib_mark'] = citing_df['body_text'].apply(
29 | lambda body_text: [cite['text'] for cite in sum([body['cite_spans'] for body in body_text], [])]
30 | )
31 |
32 | citing_dict = {}
33 | bar = tqdm(total=len(citing_df))
34 | # citing_dict: {paper A: {paper B cited in A: 'citing sentence in A (explaining B)', paper C cited in A: 'citing sentence in A (explaining C)'}, paperD: ...}
35 | for i, row in enumerate(citing_df.itertuples()):
36 | tmp_dict = {}
37 | for section in row.body_text:
38 | section_text = section['text'].split()
39 | for citing in section['cite_spans']:
40 | for sentence in tokenizer.tokenize(section['text']):
41 | if citing['text'] in sentence:
42 | text = sentence
43 | if citing['ref_id'] in row.bib_entries.keys():
44 | if row.bib_entries[citing['ref_id']] != None:
45 | tmp_dict[row.bib_entries[citing['ref_id']]] = text
46 | break
47 | citing_dict[citing_df.index[i]] = tmp_dict
48 | bar.update(1)
49 | logging.info('Done!')
50 |
51 | logging.info('Appending citing sentences to metadata_outbound_df...')
52 | citing_sentence_list = []
53 | bar = tqdm(total=len(metadata_outbound_df))
54 | for index, citing_papers in metadata_outbound_df['inbound_citations'].iteritems():
55 | citing_sentence = []
56 | n_of_citing = 0
57 | for citing_paper in citing_papers:
58 | if citing_paper in citing_dict.keys() and citing_paper not in cs_survey_ids:
59 | if index in citing_dict[citing_paper].keys():
60 | citing_sentence.append(citing_dict[citing_paper][index])
61 | n_of_citing += 1
62 | citing_sentence_list.append(citing_sentence)
63 | bar.update(1)
64 |
65 | metadata_outbound_df['citing_sentence'] = ''
66 | bar = tqdm(total=len(metadata_outbound_df))
67 | for i, row in enumerate(metadata_outbound_df.itertuples()):
68 | append_sentence = citing_sentence_list[i]
69 | metadata_outbound_df.at[metadata_outbound_df.index[i], 'citing_sentence'] = append_sentence
70 | bar.update(1)
71 |
72 | metadata_outbound_df.to_pickle(os.path.join(args.dataset_path, 'metadata_outbound_citation_df.pkl'))
73 | logging.info('Done!')
74 |
75 |
76 | def make_scireviewgen(args):
77 | logging.info('Making section_df...')
78 | metadata_outbound_df = pd.read_pickle(os.path.join(args.dataset_path, 'metadata_outbound_citation_df.pkl'))
79 | metadata_outbound_df = metadata_outbound_df.dropna(subset=['abstract'])
80 | pdf_survey_df = pd.read_pickle(os.path.join(args.dataset_path, 'pdf_survey_df.pkl'))
81 | pdf_survey_df = pdf_survey_df[pdf_survey_df['abstract'].apply(lambda s: type(s) == str)]
82 |
83 | def get_section_df(row):
84 | sections_duplicate = [paragraph['section'] for paragraph in row.body_text]
85 | sections = sorted(set(sections_duplicate), key=sections_duplicate.index)
86 | bib_df = pd.DataFrame.from_dict(row.bib_entries, orient='index')
87 | bib_df = bib_df[bib_df.link.apply(lambda paper_id: paper_id in metadata_outbound_df.index)]
88 | bib_dict = bib_df.link.dropna().to_dict()
89 |
90 | def replace_cite(body_row):
91 | body_text = ''
92 | start_index = 0
93 | for cite_span in body_row.cite_spans:
94 | end_index = cite_span['start']
95 | ref_id = cite_span['ref_id']
96 | body_text += body_row.text_raw[start_index:end_index]
97 | body_text += 'BIB{}'.format(bib_dict[ref_id]) if ref_id in bib_dict else ''
98 | start_index = cite_span['end']
99 | body_text += body_row.text_raw[start_index:]
100 | return body_text
101 |
102 | body_df = pd.DataFrame(row.body_text).rename(columns={'text': 'text_raw'})
103 | body_df['text'] = body_df[['text_raw', 'cite_spans']].apply(replace_cite, axis=1)
104 | body_df['title'] = row.title
105 | body_df['abstract'] = row.abstract
106 |
107 | section_df = body_df.groupby('section').agg({
108 | 'text': lambda text_series: ' '.join([text for text in text_series]),
109 | 'title': lambda text_series: [text for text in text_series][0],
110 | 'abstract': lambda text_series: [text for text in text_series][0],
111 | 'cite_spans': lambda cite_spans_series: [cite['ref_id'] for cite_spans in cite_spans_series for cite in cite_spans],
112 | })
113 | section_df = section_df.loc[sections]
114 | section_df['bibs'] = section_df['cite_spans'].apply(lambda spans: [bib_dict[span] for span in spans if span in bib_dict])
115 | section_df['n_bibs'] = section_df[['cite_spans', 'bibs']].apply(lambda row: len(row['bibs']), axis=1)
116 | section_df['n_nonbibs'] = section_df[['cite_spans', 'bibs']].apply(lambda row: len(row['cite_spans']) - len(row['bibs']), axis=1)
117 |
118 | section_df['paper_id'] = row.name
119 | section_df['section_id'] = section_df['paper_id'] + '-' + np.arange(len(section_df)).astype('str')
120 | section_df['section'] = section_df.index
121 | section_df = section_df.set_index('section_id')
122 |
123 | section_df['bib_titles'] = section_df['bibs'].apply(lambda bibs: metadata_outbound_df.loc[bibs]['title'].to_dict())
124 | section_df['bib_abstracts'] = section_df['bibs'].apply(lambda bibs: metadata_outbound_df.loc[bibs]['abstract'].to_dict())
125 | section_df['bib_years'] = section_df['bibs'].apply(lambda bibs: metadata_outbound_df.loc[bibs]['year'].to_dict())
126 | section_df['bib_abstracts'] = section_df[['bib_abstracts', 'bib_years']].apply(lambda bib: dict(sorted(bib[0].items(), key=lambda x: bib[1][x[0]])), axis=1) # Sort by publication year
127 | section_df['bib_citing_sentences'] = section_df['bibs'].apply(lambda bibs: metadata_outbound_df.loc[bibs]['citing_sentence'].to_dict())
128 | section_df = section_df[['paper_id', 'title', 'abstract', 'section', 'text', 'n_bibs', 'n_nonbibs', 'bib_titles', 'bib_abstracts', 'bib_citing_sentences']]
129 | return section_df
130 |
131 | section_survey_df = pd.concat(pdf_survey_df.apply(get_section_df, axis=1).values)
132 | section_survey_df = section_survey_df[section_survey_df['text'].apply(len) >= 1] # Remove sections without body text
133 |
134 | with open ('filtered_dict.pkl', 'rb') as f:
135 | filtering_dict = pickle.load(f)
136 | section_survey_df = section_survey_df[section_survey_df['paper_id'].isin(filtering_dict.keys())]
137 | section_survey_df['split'] = section_survey_df['paper_id'].apply(lambda s: filtering_dict[s])
138 |
139 | if args.version == 'split':
140 | section_survey_df = section_survey_df[section_survey_df['bib_abstracts'].apply(lambda _dict: len(_dict) >= 2)] # Remove sections with less than two cited papers
141 | section_survey_df.to_pickle(os.path.join(args.dataset_path, 'split_survey_df.pkl'))
142 | else:
143 | section_survey_df = section_survey_df.groupby('paper_id').agg({
144 | 'title': lambda l: list(l)[0],
145 | 'abstract': lambda l: list(l)[0],
146 | 'section': lambda l: list(l),
147 | 'text': lambda l: list(l),
148 | 'n_bibs': lambda l: list(l),
149 | 'n_nonbibs': lambda l: list(l),
150 | 'bib_titles': lambda l: list(l),
151 | 'bib_abstracts': lambda l: list(l),
152 | 'bib_citing_sentences': lambda l: list(l),
153 | 'split': lambda l: list(l)[0],
154 | })
155 | section_survey_df.to_pickle(os.path.join(args.dataset_path, 'original_survey_df.pkl'))
156 |
157 | logging.info('Done!')
158 |
159 |
160 | if __name__ == '__main__':
161 | logging.basicConfig(format='%(message)s', level=logging.DEBUG)
162 |
163 | parser = argparse.ArgumentParser(description='')
164 | parser.add_argument('-dataset_path', help='Path to the generated dataset')
165 | parser.add_argument('--version', default='split', help='Specify the version ("split" or "original")', choices=['split', 'original'])
166 | parser.add_argument('--n_val', type=int, default=1000, help='Number of literature review papers in validation set')
167 | parser.add_argument('--n_test', type=int, default=1000, help='Number of literature review papers in test set')
168 | args = parser.parse_args()
169 |
170 | append_citing_sentence(args) # collect citing sentences for the cited papers
171 | make_scireviewgen(args) # make scireviewgen dataset in the form of pandas dataframe
--------------------------------------------------------------------------------
/qfid/run_summarization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | from dataclasses import dataclass, field
25 | from typing import Optional
26 |
27 | import datasets
28 | import nltk # Here to have a nice missing dependency error message early on
29 | import numpy as np
30 | from datasets import load_dataset, load_metric
31 |
32 | import transformers
33 | from filelock import FileLock
34 | from transformers import (
35 | AutoConfig,
36 | AutoModelForSeq2SeqLM,
37 | AutoTokenizer,
38 | DataCollatorForSeq2Seq,
39 | HfArgumentParser,
40 | MBart50Tokenizer,
41 | MBart50TokenizerFast,
42 | MBartTokenizer,
43 | MBartTokenizerFast,
44 | Seq2SeqTrainer,
45 | Seq2SeqTrainingArguments,
46 | set_seed,
47 | )
48 | from transformers.file_utils import is_offline_mode
49 | from transformers.trainer_utils import get_last_checkpoint
50 | from transformers.utils import check_min_version
51 | from transformers.utils.versions import require_version
52 | from qfid import BartForConditionalGeneration
53 |
54 |
55 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
56 | check_min_version("4.16.0.dev0")
57 |
58 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
59 |
60 | logger = logging.getLogger(__name__)
61 |
62 | try:
63 | nltk.data.find("tokenizers/punkt")
64 | except (LookupError, OSError):
65 | if is_offline_mode():
66 | raise LookupError(
67 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
68 | )
69 | with FileLock(".lock") as lock:
70 | nltk.download("punkt", quiet=True)
71 |
72 | # A list of all multilingual tokenizer which require lang attribute.
73 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
74 |
75 |
76 | @dataclass
77 | class ModelArguments:
78 | """
79 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
80 | """
81 |
82 | model_name_or_path: str = field(
83 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
84 | )
85 | config_name: Optional[str] = field(
86 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
87 | )
88 | tokenizer_name: Optional[str] = field(
89 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
90 | )
91 | cache_dir: Optional[str] = field(
92 | default=None,
93 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
94 | )
95 | use_fast_tokenizer: bool = field(
96 | default=True,
97 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
98 | )
99 | model_revision: str = field(
100 | default="main",
101 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
102 | )
103 | use_auth_token: bool = field(
104 | default=False,
105 | metadata={
106 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
107 | "with private models)."
108 | },
109 | )
110 | resize_position_embeddings: Optional[bool] = field(
111 | default=None,
112 | metadata={
113 | "help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
114 | "the model's position embeddings."
115 | },
116 | )
117 |
118 |
119 | @dataclass
120 | class DataTrainingArguments:
121 | """
122 | Arguments pertaining to what data we are going to input our model for training and eval.
123 | """
124 |
125 | lang: str = field(default=None, metadata={"help": "Language id for summarization."})
126 |
127 | dataset_name: Optional[str] = field(
128 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
129 | )
130 | dataset_config_name: Optional[str] = field(
131 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
132 | )
133 | text_column: Optional[str] = field(
134 | default=None,
135 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
136 | )
137 | summary_column: Optional[str] = field(
138 | default=None,
139 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
140 | )
141 | train_file: Optional[str] = field(
142 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
143 | )
144 | validation_file: Optional[str] = field(
145 | default=None,
146 | metadata={
147 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on "
148 | "(a jsonlines or csv file)."
149 | },
150 | )
151 | test_file: Optional[str] = field(
152 | default=None,
153 | metadata={
154 | "help": "An optional input test data file to evaluate the metrics (rouge) on " "(a jsonlines or csv file)."
155 | },
156 | )
157 | overwrite_cache: bool = field(
158 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
159 | )
160 | preprocessing_num_workers: Optional[int] = field(
161 | default=None,
162 | metadata={"help": "The number of processes to use for the preprocessing."},
163 | )
164 | max_source_length: Optional[int] = field(
165 | default=1024,
166 | metadata={
167 | "help": "The maximum total input sequence length after tokenization. Sequences longer "
168 | "than this will be truncated, sequences shorter will be padded."
169 | },
170 | )
171 | max_target_length: Optional[int] = field(
172 | default=128,
173 | metadata={
174 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer "
175 | "than this will be truncated, sequences shorter will be padded."
176 | },
177 | )
178 | val_max_target_length: Optional[int] = field(
179 | default=None,
180 | metadata={
181 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
182 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
183 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
184 | "during ``evaluate`` and ``predict``."
185 | },
186 | )
187 | pad_to_max_length: bool = field(
188 | default=False,
189 | metadata={
190 | "help": "Whether to pad all samples to model maximum sentence length. "
191 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
192 | "efficient on GPU but very bad for TPU."
193 | },
194 | )
195 | max_train_samples: Optional[int] = field(
196 | default=None,
197 | metadata={
198 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
199 | "value if set."
200 | },
201 | )
202 | max_eval_samples: Optional[int] = field(
203 | default=None,
204 | metadata={
205 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
206 | "value if set."
207 | },
208 | )
209 | max_predict_samples: Optional[int] = field(
210 | default=None,
211 | metadata={
212 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
213 | "value if set."
214 | },
215 | )
216 | num_beams: Optional[int] = field(
217 | default=None,
218 | metadata={
219 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
220 | "which is used during ``evaluate`` and ``predict``."
221 | },
222 | )
223 | ignore_pad_token_for_loss: bool = field(
224 | default=True,
225 | metadata={
226 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
227 | },
228 | )
229 | source_prefix: Optional[str] = field(
230 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
231 | )
232 |
233 | forced_bos_token: Optional[str] = field(
234 | default=None,
235 | metadata={
236 | "help": "The token to force as the first generated token after the decoder_start_token_id."
237 | "Useful for multilingual models like mBART where the first generated token"
238 | "needs to be the target language token (Usually it is the target language token)"
239 | },
240 | )
241 |
242 | def __post_init__(self):
243 | if self.dataset_name is None and self.train_file is None and self.validation_file is None:
244 | raise ValueError("Need either a dataset name or a training/validation file.")
245 | else:
246 | if self.train_file is not None:
247 | extension = self.train_file.split(".")[-1]
248 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
249 | if self.validation_file is not None:
250 | extension = self.validation_file.split(".")[-1]
251 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
252 | if self.val_max_target_length is None:
253 | self.val_max_target_length = self.max_target_length
254 |
255 |
256 | summarization_name_mapping = {
257 | "amazon_reviews_multi": ("review_body", "review_title"),
258 | "big_patent": ("description", "abstract"),
259 | "cnn_dailymail": ("article", "highlights"),
260 | "orange_sum": ("text", "summary"),
261 | "pn_summary": ("article", "summary"),
262 | "psc": ("extract_text", "summary_text"),
263 | "samsum": ("dialogue", "summary"),
264 | "thaisum": ("body", "summary"),
265 | "xglue": ("news_body", "news_title"),
266 | "xsum": ("document", "summary"),
267 | "wiki_summary": ("article", "highlights"),
268 | }
269 |
270 |
271 | def main():
272 | # See all possible arguments in src/transformers/training_args.py
273 | # or by passing the --help flag to this script.
274 | # We now keep distinct sets of args, for a cleaner separation of concerns.
275 |
276 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
277 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
278 | # If we pass only one argument to the script and it's the path to a json file,
279 | # let's parse it to get our arguments.
280 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
281 | else:
282 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
283 |
284 | # Setup logging
285 | logging.basicConfig(
286 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
287 | datefmt="%m/%d/%Y %H:%M:%S",
288 | handlers=[logging.StreamHandler(sys.stdout)],
289 | )
290 | log_level = training_args.get_process_log_level()
291 | logger.setLevel(log_level)
292 | datasets.utils.logging.set_verbosity(log_level)
293 | transformers.utils.logging.set_verbosity(log_level)
294 | transformers.utils.logging.enable_default_handler()
295 | transformers.utils.logging.enable_explicit_format()
296 |
297 | # Log on each process the small summary:
298 | logger.warning(
299 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
300 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
301 | )
302 | logger.info(f"Training/evaluation parameters {training_args}")
303 |
304 | if data_args.source_prefix is None and model_args.model_name_or_path in [
305 | "t5-small",
306 | "t5-base",
307 | "t5-large",
308 | "t5-3b",
309 | "t5-11b",
310 | ]:
311 | logger.warning(
312 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with "
313 | "`--source_prefix 'summarize: ' `"
314 | )
315 |
316 | # Detecting last checkpoint.
317 | last_checkpoint = None
318 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
319 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
320 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
321 | raise ValueError(
322 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
323 | "Use --overwrite_output_dir to overcome."
324 | )
325 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
326 | logger.info(
327 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
328 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
329 | )
330 |
331 | # Set seed before initializing model.
332 | set_seed(training_args.seed)
333 |
334 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
335 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
336 | # (the dataset will be downloaded automatically from the datasets Hub).
337 | #
338 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the
339 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
340 | #
341 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
342 | # download the dataset.
343 | if data_args.dataset_name is not None:
344 | # Downloading and loading a dataset from the hub.
345 | raw_datasets = load_dataset(
346 | data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
347 | )
348 | else:
349 | data_files = {}
350 | if data_args.train_file is not None:
351 | data_files["train"] = data_args.train_file
352 | extension = data_args.train_file.split(".")[-1]
353 | if data_args.validation_file is not None:
354 | data_files["validation"] = data_args.validation_file
355 | extension = data_args.validation_file.split(".")[-1]
356 | if data_args.test_file is not None:
357 | data_files["test"] = data_args.test_file
358 | extension = data_args.test_file.split(".")[-1]
359 | raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
360 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
361 | # https://huggingface.co/docs/datasets/loading_datasets.html.
362 |
363 | # Load pretrained model and tokenizer
364 | #
365 | # Distributed training:
366 | # The .from_pretrained methods guarantee that only one local process can concurrently
367 | # download model & vocab.
368 | config = AutoConfig.from_pretrained(
369 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
370 | cache_dir=model_args.cache_dir,
371 | revision=model_args.model_revision,
372 | use_auth_token=True if model_args.use_auth_token else None,
373 | )
374 | tokenizer = AutoTokenizer.from_pretrained(
375 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
376 | cache_dir=model_args.cache_dir,
377 | use_fast=model_args.use_fast_tokenizer,
378 | revision=model_args.model_revision,
379 | use_auth_token=True if model_args.use_auth_token else None,
380 | )
381 | model = BartForConditionalGeneration.from_pretrained(
382 | model_args.model_name_or_path,
383 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
384 | config=config,
385 | cache_dir=model_args.cache_dir,
386 | revision=model_args.model_revision,
387 | use_auth_token=True if model_args.use_auth_token else None,
388 | )
389 |
390 | #special_tokens_dict = {'additional_special_tokens': ['[BOT]', '[BOCT]', '[BOA]', '[BOC]', '[BOB]']}# + ['BIB{:0>3}'.format(i+1) for i in range(100)]}
391 | #tokenizer.add_special_tokens(special_tokens_dict)
392 |
393 | model.resize_token_embeddings(len(tokenizer))
394 |
395 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
396 | if isinstance(tokenizer, MBartTokenizer):
397 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
398 | else:
399 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
400 |
401 | if model.config.decoder_start_token_id is None:
402 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
403 |
404 | #if (
405 | # hasattr(model.config, "max_position_embeddings")
406 | # and model.config.max_position_embeddings < data_args.max_source_length
407 | #):
408 | # if model_args.resize_position_embeddings is None:
409 | # logger.warning(
410 | # f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} "
411 | # f"to {data_args.max_source_length}."
412 | # )
413 | # model.resize_position_embeddings(data_args.max_source_length)
414 | # elif model_args.resize_position_embeddings:
415 | # model.resize_position_embeddings(data_args.max_source_length)
416 | # else:
417 | # raise ValueError(
418 | # f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
419 | # f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically "
420 | # "resize the model's position encodings by passing `--resize_position_embeddings`."
421 | # )
422 |
423 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
424 |
425 | # Preprocessing the datasets.
426 | # We need to tokenize inputs and targets.
427 | if training_args.do_train:
428 | column_names = raw_datasets["train"].column_names
429 | elif training_args.do_eval:
430 | column_names = raw_datasets["validation"].column_names
431 | elif training_args.do_predict:
432 | column_names = raw_datasets["test"].column_names
433 | else:
434 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
435 | return
436 |
437 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
438 | assert (
439 | data_args.lang is not None
440 | ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
441 |
442 | tokenizer.src_lang = data_args.lang
443 | tokenizer.tgt_lang = data_args.lang
444 |
445 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
446 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
447 | forced_bos_token_id = (
448 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
449 | )
450 | model.config.forced_bos_token_id = forced_bos_token_id
451 |
452 | # Get the column names for input/target.
453 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
454 | if data_args.text_column is None:
455 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
456 | else:
457 | text_column = data_args.text_column
458 | if text_column not in column_names:
459 | raise ValueError(
460 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
461 | )
462 | if data_args.summary_column is None:
463 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
464 | else:
465 | summary_column = data_args.summary_column
466 | if summary_column not in column_names:
467 | raise ValueError(
468 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
469 | )
470 |
471 | # Temporarily set max_target_length for training.
472 | max_target_length = data_args.max_target_length
473 | padding = "max_length" if data_args.pad_to_max_length else False
474 |
475 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
476 | logger.warning(
477 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
478 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
479 | )
480 |
481 | def preprocess_function(examples):
482 | # remove pairs where at least one record is None
483 |
484 | inputs, targets = [], []
485 | for i in range(len(examples[text_column])):
486 | if examples[text_column][i] is not None and examples[summary_column][i] is not None:
487 | inputs.append(examples[text_column][i])
488 | targets.append(examples[summary_column][i])
489 |
492 | inputs = [prefix + inp for inp in inputs]
493 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
494 |
495 | # Setup the tokenizer for targets
496 | with tokenizer.as_target_tokenizer():
497 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)
498 |
499 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
500 | # padding in the loss.
501 | if padding == "max_length" and data_args.ignore_pad_token_for_loss:
502 | labels["input_ids"] = [
503 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
504 | ]
505 |
506 | model_inputs["labels"] = labels["input_ids"]
507 | return model_inputs
508 |
509 | if training_args.do_train:
510 | if "train" not in raw_datasets:
511 | raise ValueError("--do_train requires a train dataset")
512 | train_dataset = raw_datasets["train"]
513 | if data_args.max_train_samples is not None:
514 | train_dataset = train_dataset.select(range(data_args.max_train_samples))
515 | with training_args.main_process_first(desc="train dataset map pre-processing"):
516 | train_dataset = train_dataset.map(
517 | preprocess_function,
518 | batched=True,
519 | num_proc=data_args.preprocessing_num_workers,
520 | remove_columns=column_names,
521 | load_from_cache_file=not data_args.overwrite_cache,
522 | desc="Running tokenizer on train dataset",
523 | )
524 |
525 | if training_args.do_eval:
526 | max_target_length = data_args.val_max_target_length
527 | if "validation" not in raw_datasets:
528 | raise ValueError("--do_eval requires a validation dataset")
529 | eval_dataset = raw_datasets["validation"]
530 | if data_args.max_eval_samples is not None:
531 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
532 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
533 | eval_dataset = eval_dataset.map(
534 | preprocess_function,
535 | batched=True,
536 | num_proc=data_args.preprocessing_num_workers,
537 | remove_columns=column_names,
538 | load_from_cache_file=not data_args.overwrite_cache,
539 | desc="Running tokenizer on validation dataset",
540 | )
541 |
542 | if training_args.do_predict:
543 | max_target_length = data_args.val_max_target_length
544 | if "test" not in raw_datasets:
545 | raise ValueError("--do_predict requires a test dataset")
546 | predict_dataset = raw_datasets["test"]
547 | if data_args.max_predict_samples is not None:
548 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
549 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
550 | predict_dataset = predict_dataset.map(
551 | preprocess_function,
552 | batched=True,
553 | num_proc=data_args.preprocessing_num_workers,
554 | remove_columns=column_names,
555 | load_from_cache_file=not data_args.overwrite_cache,
556 | desc="Running tokenizer on prediction dataset",
557 | )
558 |
559 | # Data collator
560 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
561 | data_collator = DataCollatorForSeq2Seq(
562 | tokenizer,
563 | model=model,
564 | label_pad_token_id=label_pad_token_id,
565 | pad_to_multiple_of=8 if training_args.fp16 else None,
566 | )
567 |
568 | # Metric
569 | metric = load_metric("rouge")
570 |
571 | def postprocess_text(preds, labels):
572 | preds = [pred.strip() for pred in preds]
573 | labels = [label.strip() for label in labels]
574 |
575 | # rougeLSum expects newline after each sentence
576 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
577 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
578 |
579 | return preds, labels
580 |
581 | def compute_metrics(eval_preds):
582 | preds, labels = eval_preds
583 | if isinstance(preds, tuple):
584 | preds = preds[0]
585 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
586 | if data_args.ignore_pad_token_for_loss:
587 | # Replace -100 in the labels as we can't decode them.
588 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
589 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
590 |
591 | # Some simple post-processing
592 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
593 |
594 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
595 | # Extract a few results from ROUGE
596 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
597 |
598 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
599 | result["gen_len"] = np.mean(prediction_lens)
600 | result = {k: round(v, 4) for k, v in result.items()}
601 | return result
602 |
603 | # Initialize our Trainer
604 | trainer = Seq2SeqTrainer(
605 | model=model,
606 | args=training_args,
607 | train_dataset=train_dataset if training_args.do_train else None,
608 | eval_dataset=eval_dataset if training_args.do_eval else None,
609 | tokenizer=tokenizer,
610 | data_collator=data_collator,
611 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
612 | )
613 |
614 | # Training
615 | if training_args.do_train:
616 | checkpoint = None
617 | if training_args.resume_from_checkpoint is not None:
618 | checkpoint = training_args.resume_from_checkpoint
619 | elif last_checkpoint is not None:
620 | checkpoint = last_checkpoint
621 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
622 | trainer.save_model() # Saves the tokenizer too for easy upload
623 |
624 | metrics = train_result.metrics
625 | max_train_samples = (
626 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
627 | )
628 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
629 |
630 | trainer.log_metrics("train", metrics)
631 | trainer.save_metrics("train", metrics)
632 | trainer.save_state()
633 |
634 | # Evaluation
635 | results = {}
636 | max_length = (
637 | training_args.generation_max_length
638 | if training_args.generation_max_length is not None
639 | else data_args.val_max_target_length
640 | )
641 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
642 | if training_args.do_eval:
643 | logger.info("*** Evaluate ***")
644 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
645 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
646 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
647 |
648 | trainer.log_metrics("eval", metrics)
649 | trainer.save_metrics("eval", metrics)
650 |
651 | if training_args.do_predict:
652 | logger.info("*** Predict ***")
653 |
654 | predict_results = trainer.predict(
655 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
656 | )
657 | metrics = predict_results.metrics
658 | max_predict_samples = (
659 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
660 | )
661 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
662 |
663 | trainer.log_metrics("predict", metrics)
664 | trainer.save_metrics("predict", metrics)
665 |
666 | if trainer.is_world_process_zero():
667 | if training_args.predict_with_generate:
668 | predictions = tokenizer.batch_decode(
669 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
670 | )
671 | predictions = [pred.strip() for pred in predictions]
672 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
673 | with open(output_prediction_file, "w") as writer:
674 | writer.write("\n".join(predictions))
675 |
676 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"}
677 | if data_args.dataset_name is not None:
678 | kwargs["dataset_tags"] = data_args.dataset_name
679 | if data_args.dataset_config_name is not None:
680 | kwargs["dataset_args"] = data_args.dataset_config_name
681 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
682 | else:
683 | kwargs["dataset"] = data_args.dataset_name
684 |
685 | if data_args.lang is not None:
686 | kwargs["language"] = data_args.lang
687 |
688 | if training_args.push_to_hub:
689 | trainer.push_to_hub(**kwargs)
690 | else:
691 | trainer.create_model_card(**kwargs)
692 |
693 | return results
694 |
695 |
696 | def _mp_fn(index):
697 | # For xla_spawn (TPUs)
698 | main()
699 |
700 |
701 | if __name__ == "__main__":
702 | main()
--------------------------------------------------------------------------------
/qfid/qfid.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import random
4 | import warnings
5 | from typing import List, Optional, Tuple, Union
6 |
7 | import torch
8 | import torch.utils.checkpoint
9 | from torch import nn
10 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11 |
12 | from transformers.activations import ACT2FN
13 | from transformers.modeling_outputs import (
14 | BaseModelOutput,
15 | BaseModelOutputWithPastAndCrossAttentions,
16 | CausalLMOutputWithCrossAttentions,
17 | Seq2SeqLMOutput,
18 | Seq2SeqModelOutput,
19 | Seq2SeqQuestionAnsweringModelOutput,
20 | Seq2SeqSequenceClassifierOutput,
21 | )
22 | from transformers.modeling_utils import PreTrainedModel
23 | from transformers.models.bart.configuration_bart import BartConfig
24 |
25 |
26 | _CHECKPOINT_FOR_DOC = "facebook/bart-base"
27 | _CONFIG_FOR_DOC = "BartConfig"
28 | _TOKENIZER_FOR_DOC = "BartTokenizer"
29 |
30 | # Base model docstring
31 | _EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
32 |
33 | # SequenceClassification docstring
34 | _CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "valhalla/bart-large-sst2"
35 | _SEQ_CLASS_EXPECTED_LOSS = 0.0
36 | _SEQ_CLASS_EXPECTED_OUTPUT = "'POSITIVE'"
37 |
38 | # QuestionAnswering docstring
39 | _CHECKPOINT_FOR_QA = "valhalla/bart-large-finetuned-squadv1"
40 | _QA_EXPECTED_LOSS = 0.59
41 | _QA_EXPECTED_OUTPUT = "' nice puppet'"
42 |
43 |
44 | BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
45 | "facebook/bart-large",
46 | # see all BART models at https://huggingface.co/models?filter=bart
47 | ]
48 |
49 |
50 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
51 | """
52 | Shift input ids one token to the right.
53 | """
54 | shifted_input_ids = input_ids.new_zeros(input_ids.shape)
55 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
56 | shifted_input_ids[:, 0] = decoder_start_token_id
57 |
58 | if pad_token_id is None:
59 | raise ValueError("self.model.config.pad_token_id has to be defined.")
60 | # replace possible -100 values in labels by `pad_token_id`
61 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
62 |
63 | return shifted_input_ids
64 |
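    | # Illustrative example with hypothetical ids: with pad_token_id=1 and decoder_start_token_id=2,
    | # shift_tokens_right(torch.tensor([[0, 8774, 2]]), 1, 2) returns tensor([[2, 0, 8774]]),
    | # i.e. the labels shifted one position to the right with the decoder start token prepended.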
65 |
66 | def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
67 | """
68 |     Make causal mask used for uni-directional (decoder) self-attention.
69 | """
70 | bsz, tgt_len = input_ids_shape
71 | mask = torch.full((tgt_len, tgt_len), float("-inf"))
72 | mask_cond = torch.arange(mask.size(-1))
73 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
74 | mask = mask.to(dtype)
75 |
76 | if past_key_values_length > 0:
77 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
78 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
79 |
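    | # Illustrative example: for tgt_len=3 and no cached keys, the mask (before broadcasting) is
    | #   [[0., -inf, -inf],
    | #    [0.,   0., -inf],
    | #    [0.,   0.,   0.]]
    | # so each position can attend only to itself and earlier positions.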
80 |
81 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
82 | """
83 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
84 | """
85 | bsz, src_len = mask.size()
86 | tgt_len = tgt_len if tgt_len is not None else src_len
87 |
88 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
89 |
90 | inverted_mask = 1.0 - expanded_mask
91 |
92 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
93 |
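    | # Illustrative example: an attention_mask row [1, 1, 0] becomes the additive mask [0., 0., dtype-min]
    | # along the source dimension, so padded positions are effectively removed from the softmax.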
94 |
95 | class BartLearnedPositionalEmbedding(nn.Embedding):
96 | """
97 | This module learns positional embeddings up to a fixed maximum size.
98 | """
99 |
100 | def __init__(self, num_embeddings: int, embedding_dim: int):
101 |         # Bart is set up so that if padding_idx is specified, the embedding ids are offset by 2
102 |         # and num_embeddings is adjusted accordingly. Other models don't have this hack.
103 | self.offset = 2
104 | super().__init__(num_embeddings + self.offset, embedding_dim)
105 |
106 | def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
107 | """`input_ids_shape` is expected to be [bsz x seqlen]."""
108 | bsz, seq_len = input_ids_shape[:2]
109 | positions = torch.arange(
110 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
111 | )
112 | return super().forward(positions + self.offset)
113 |
114 |
115 | class BartAttention(nn.Module):
116 | """Multi-headed attention from 'Attention Is All You Need' paper"""
117 |
118 | def __init__(
119 | self,
120 | embed_dim: int,
121 | num_heads: int,
122 | dropout: float = 0.0,
123 | is_decoder: bool = False,
124 | bias: bool = True,
125 | ):
126 | super().__init__()
127 | self.embed_dim = embed_dim
128 | self.num_heads = num_heads
129 | self.dropout = dropout
130 | self.head_dim = embed_dim // num_heads
131 |
132 | if (self.head_dim * num_heads) != self.embed_dim:
133 | raise ValueError(
134 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
135 | f" and `num_heads`: {num_heads})."
136 | )
137 | self.scaling = self.head_dim**-0.5
138 | self.is_decoder = is_decoder
139 |
140 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
141 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
142 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
143 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
144 |
145 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
146 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
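    |     # _shape reshapes (bsz, seq_len, embed_dim) into (bsz, num_heads, seq_len, head_dim) so attention
    |     # scores can be computed per head with batched matrix multiplications.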
147 |
148 | def forward(
149 | self,
150 | hidden_states: torch.Tensor,
151 | key_value_states: Optional[torch.Tensor] = None,
152 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
153 | attention_mask: Optional[torch.Tensor] = None,
154 | layer_head_mask: Optional[torch.Tensor] = None,
155 | output_attentions: bool = False,
156 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
157 | """Input shape: Batch x Time x Channel"""
158 |
159 | # if key_value_states are provided this layer is used as a cross-attention layer
160 | # for the decoder
161 | is_cross_attention = key_value_states is not None
162 |
163 | bsz, tgt_len, _ = hidden_states.size()
164 |
165 | # get query proj
166 | query_states = self.q_proj(hidden_states) * self.scaling
167 | # get key, value proj
168 | if is_cross_attention and past_key_value is not None:
169 | # reuse k,v, cross_attentions
170 | key_states = past_key_value[0]
171 | value_states = past_key_value[1]
172 | elif is_cross_attention:
173 | # cross_attentions
174 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
175 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
176 | elif past_key_value is not None:
177 | # reuse k, v, self_attention
178 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
179 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
180 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
181 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
182 | else:
183 | # self_attention
184 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
185 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
186 |
187 | if self.is_decoder:
188 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
189 | # Further calls to cross_attention layer can then reuse all cross-attention
190 | # key/value_states (first "if" case)
191 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
192 | # all previous decoder key/value_states. Further calls to uni-directional self-attention
193 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
194 | # if encoder bi-directional self-attention `past_key_value` is always `None`
195 | past_key_value = (key_states, value_states)
196 |
197 | proj_shape = (bsz * self.num_heads, -1, self.head_dim)
198 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
199 | key_states = key_states.view(*proj_shape)
200 | value_states = value_states.view(*proj_shape)
201 |
202 | src_len = key_states.size(1)
203 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
204 |
205 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
206 | raise ValueError(
207 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}"
208 | )
209 |
210 | if attention_mask is not None:
211 | if attention_mask.size() != (bsz, 1, tgt_len, src_len):
212 | raise ValueError(
213 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
214 | )
215 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
216 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
217 |
218 | attn_weights = nn.functional.softmax(attn_weights, dim=-1)
219 |
220 | if layer_head_mask is not None:
221 | if layer_head_mask.size() != (self.num_heads,):
222 | raise ValueError(
223 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is {layer_head_mask.size()}"
224 | )
225 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
226 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
227 |
228 | if output_attentions:
229 | # this operation is a bit awkward, but it's required to
230 | # make sure that attn_weights keeps its gradient.
231 | # In order to do so, attn_weights have to be reshaped
232 | # twice and have to be reused in the following
233 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
234 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
235 | else:
236 | attn_weights_reshaped = None
237 |
238 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
239 |
240 | attn_output = torch.bmm(attn_probs, value_states)
241 |
242 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
243 | raise ValueError(
244 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
245 | )
246 |
247 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
248 | attn_output = attn_output.transpose(1, 2)
249 |
250 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
251 |         # partitioned across GPUs when using tensor-parallelism.
252 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
253 |
254 | attn_output = self.out_proj(attn_output)
255 |
256 | return attn_output, attn_weights_reshaped, past_key_value
257 |
258 |
259 | class BartEncoderLayer(nn.Module):
260 | def __init__(self, config: BartConfig):
261 | super().__init__()
262 | self.embed_dim = config.d_model
263 | self.self_attn = BartAttention(
264 | embed_dim=self.embed_dim,
265 | num_heads=config.encoder_attention_heads,
266 | dropout=config.attention_dropout,
267 | )
268 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
269 | self.dropout = config.dropout
270 | self.activation_fn = ACT2FN[config.activation_function]
271 | self.activation_dropout = config.activation_dropout
272 | self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
273 | self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
274 | self.final_layer_norm = nn.LayerNorm(self.embed_dim)
275 |
276 | def forward(
277 | self,
278 | hidden_states: torch.FloatTensor,
279 | attention_mask: torch.FloatTensor,
280 | layer_head_mask: torch.FloatTensor,
281 | output_attentions: Optional[bool] = False,
282 | ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
283 | """
284 | Args:
285 |             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
286 | attention_mask (`torch.FloatTensor`): attention mask of size
287 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
288 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
289 | `(encoder_attention_heads,)`.
290 | output_attentions (`bool`, *optional*):
291 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
292 | returned tensors for more detail.
293 | """
294 | residual = hidden_states
295 | hidden_states, attn_weights, _ = self.self_attn(
296 | hidden_states=hidden_states,
297 | attention_mask=attention_mask,
298 | layer_head_mask=layer_head_mask,
299 | output_attentions=output_attentions,
300 | )
301 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
302 | hidden_states = residual + hidden_states
303 | hidden_states = self.self_attn_layer_norm(hidden_states)
304 |
305 | residual = hidden_states
306 | hidden_states = self.activation_fn(self.fc1(hidden_states))
307 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
308 | hidden_states = self.fc2(hidden_states)
309 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
310 | hidden_states = residual + hidden_states
311 | hidden_states = self.final_layer_norm(hidden_states)
312 |
313 | if hidden_states.dtype == torch.float16 and (
314 | torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
315 | ):
316 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000
317 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
318 |
319 | outputs = (hidden_states,)
320 |
321 | if output_attentions:
322 | outputs += (attn_weights,)
323 |
324 | return outputs
325 |
326 |
327 | class BartDecoderLayer(nn.Module):
328 | def __init__(self, config: BartConfig):
329 | super().__init__()
330 | self.embed_dim = config.d_model
331 |
332 | self.self_attn = BartAttention(
333 | embed_dim=self.embed_dim,
334 | num_heads=config.decoder_attention_heads,
335 | dropout=config.attention_dropout,
336 | is_decoder=True,
337 | )
338 | self.dropout = config.dropout
339 | self.activation_fn = ACT2FN[config.activation_function]
340 | self.activation_dropout = config.activation_dropout
341 |
342 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
343 | self.encoder_attn = BartAttention(
344 | self.embed_dim,
345 | config.decoder_attention_heads,
346 | dropout=config.attention_dropout,
347 | is_decoder=True,
348 | )
349 | self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
350 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
351 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
352 | self.final_layer_norm = nn.LayerNorm(self.embed_dim)
353 |
354 | def forward(
355 | self,
356 | hidden_states: torch.Tensor,
357 | attention_mask: Optional[torch.Tensor] = None,
358 | encoder_hidden_states: Optional[torch.Tensor] = None,
359 | encoder_attention_mask: Optional[torch.Tensor] = None,
360 | layer_head_mask: Optional[torch.Tensor] = None,
361 | cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
362 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
363 | output_attentions: Optional[bool] = False,
364 | use_cache: Optional[bool] = True,
365 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
366 | """
367 | Args:
368 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
369 | attention_mask (`torch.FloatTensor`): attention mask of size
370 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
371 | encoder_hidden_states (`torch.FloatTensor`):
372 | cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
373 | encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
374 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
375 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
376 |                 `(decoder_attention_heads,)`.
377 | cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
378 | size `(decoder_attention_heads,)`.
379 | past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
380 | output_attentions (`bool`, *optional*):
381 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
382 | returned tensors for more detail.
383 | """
384 | residual = hidden_states
385 | # Self Attention
386 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
387 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
388 | # add present self-attn cache to positions 1,2 of present_key_value tuple
389 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
390 | hidden_states=hidden_states,
391 | past_key_value=self_attn_past_key_value,
392 | attention_mask=attention_mask,
393 | layer_head_mask=layer_head_mask,
394 | output_attentions=output_attentions,
395 | )
396 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
397 | hidden_states = residual + hidden_states
398 | hidden_states = self.self_attn_layer_norm(hidden_states)
399 |
400 | # Cross-Attention Block
401 | cross_attn_present_key_value = None
402 | cross_attn_weights = None
403 | if encoder_hidden_states is not None:
404 | residual = hidden_states
405 |
406 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
407 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
408 |             #print("BartDecoderLayer: {}".format(encoder_attention_mask.shape))  # debug note: make this go from tensor([], size=(1, 1, 512, 0)) to size=(1, 1, 512, 1024)
409 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
410 | hidden_states=hidden_states,
411 | key_value_states=encoder_hidden_states,
412 | attention_mask=encoder_attention_mask,
413 | layer_head_mask=cross_attn_layer_head_mask,
414 | past_key_value=cross_attn_past_key_value,
415 | output_attentions=output_attentions,
416 | )
417 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
418 | hidden_states = residual + hidden_states
419 | hidden_states = self.encoder_attn_layer_norm(hidden_states)
420 |
421 | # add cross-attn to positions 3,4 of present_key_value tuple
422 | present_key_value = present_key_value + cross_attn_present_key_value
423 |
424 | # Fully Connected
425 | residual = hidden_states
426 | hidden_states = self.activation_fn(self.fc1(hidden_states))
427 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
428 | hidden_states = self.fc2(hidden_states)
429 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
430 | hidden_states = residual + hidden_states
431 | hidden_states = self.final_layer_norm(hidden_states)
432 |
433 | outputs = (hidden_states,)
434 |
435 | if output_attentions:
436 | outputs += (self_attn_weights, cross_attn_weights)
437 |
438 | if use_cache:
439 | outputs += (present_key_value,)
440 |
441 | return outputs
442 |
443 |
444 | class BartClassificationHead(nn.Module):
445 | """Head for sentence-level classification tasks."""
446 |
447 | def __init__(
448 | self,
449 | input_dim: int,
450 | inner_dim: int,
451 | num_classes: int,
452 | pooler_dropout: float,
453 | ):
454 | super().__init__()
455 | self.dense = nn.Linear(input_dim, inner_dim)
456 | self.dropout = nn.Dropout(p=pooler_dropout)
457 | self.out_proj = nn.Linear(inner_dim, num_classes)
458 |
459 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
460 | hidden_states = self.dropout(hidden_states)
461 | hidden_states = self.dense(hidden_states)
462 | hidden_states = torch.tanh(hidden_states)
463 | hidden_states = self.dropout(hidden_states)
464 | hidden_states = self.out_proj(hidden_states)
465 | return hidden_states
466 |
467 |
468 | class BartPretrainedModel(PreTrainedModel):
469 | config_class = BartConfig
470 | base_model_prefix = "model"
471 | supports_gradient_checkpointing = True
472 | _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
473 |
474 | def _init_weights(self, module):
475 | std = self.config.init_std
476 | if isinstance(module, nn.Linear):
477 | module.weight.data.normal_(mean=0.0, std=std)
478 | if module.bias is not None:
479 | module.bias.data.zero_()
480 | elif isinstance(module, nn.Embedding):
481 | module.weight.data.normal_(mean=0.0, std=std)
482 | if module.padding_idx is not None:
483 | module.weight.data[module.padding_idx].zero_()
484 |
485 | def _set_gradient_checkpointing(self, module, value=False):
486 | if isinstance(module, (BartDecoder, BartEncoder)):
487 | module.gradient_checkpointing = value
488 |
489 | @property
490 | def dummy_inputs(self):
491 | pad_token = self.config.pad_token_id
492 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
493 | dummy_inputs = {
494 | "attention_mask": input_ids.ne(pad_token),
495 | "input_ids": input_ids,
496 | }
497 | return dummy_inputs
498 |
499 |
500 | class PretrainedBartModel(BartPretrainedModel):
501 | def __init_subclass__(self):
502 | warnings.warn(
503 | "The class `PretrainedBartModel` has been depreciated, please use `BartPretrainedModel` instead.",
504 | FutureWarning,
505 | )
506 |
507 |
508 | BART_START_DOCSTRING = r"""
509 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
510 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
511 | etc.)
512 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
513 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
514 | and behavior.
515 | Parameters:
516 | config ([`BartConfig`]):
517 | Model configuration class with all the parameters of the model. Initializing with a config file does not
518 | load the weights associated with the model, only the configuration. Check out the
519 | [`~PreTrainedModel.from_pretrained`] method to load the model weights.
520 | """
521 |
522 | BART_GENERATION_EXAMPLE = r"""
523 | Summarization example:
524 | ```python
525 | >>> from transformers import BartTokenizer, BartForConditionalGeneration
526 | >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
527 | >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
528 | >>> ARTICLE_TO_SUMMARIZE = (
529 | ... "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
530 | ... "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
531 | ... "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
532 | ... )
533 | >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")
534 | >>> # Generate Summary
535 | >>> summary_ids = model.generate(inputs["input_ids"], num_beams=2, max_length=20)
536 | >>> tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
537 | 'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'
538 | ```
539 | Mask filling example:
540 | ```python
541 | >>> from transformers import BartTokenizer, BartForConditionalGeneration
542 | >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
543 | >>> model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
544 |     >>> TXT = "My friends are <mask> but they eat too many carbs."
545 | >>> input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
546 | >>> logits = model(input_ids).logits
547 | >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
548 | >>> probs = logits[0, masked_index].softmax(dim=0)
549 | >>> values, predictions = probs.topk(5)
550 | >>> tokenizer.decode(predictions).split()
551 | ['not', 'good', 'healthy', 'great', 'very']
552 | ```
553 | """
554 |
555 | BART_INPUTS_DOCSTRING = r"""
556 | Args:
557 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
558 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
559 | it.
560 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
561 | [`PreTrainedTokenizer.__call__`] for details.
562 | [What are input IDs?](../glossary#input-ids)
563 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
564 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
565 | - 1 for tokens that are **not masked**,
566 | - 0 for tokens that are **masked**.
567 | [What are attention masks?](../glossary#attention-mask)
568 | decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
569 | Indices of decoder input sequence tokens in the vocabulary.
570 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
571 | [`PreTrainedTokenizer.__call__`] for details.
572 | [What are decoder input IDs?](../glossary#decoder-input-ids)
573 | Bart uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
574 | is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
575 | For translation and summarization training, `decoder_input_ids` should be provided. If no
576 | `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
577 | for denoising pre-training following the paper.
578 | decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
579 | Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
580 | be used by default.
581 | If you want to change padding behavior, you should read [`modeling_bart._prepare_decoder_inputs`] and
582 | modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information
583 | on the default strategy.
584 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
585 | Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
586 | - 1 indicates the head is **not masked**,
587 | - 0 indicates the head is **masked**.
588 | decoder_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
589 | Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:
590 | - 1 indicates the head is **not masked**,
591 | - 0 indicates the head is **masked**.
592 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
593 | Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0,
594 | 1]`:
595 | - 1 indicates the head is **not masked**,
596 | - 0 indicates the head is **masked**.
597 |         encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
598 | Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
599 | `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
600 | hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
601 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
602 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
603 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
604 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
605 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
606 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
607 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
608 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
609 |             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
610 |         inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
611 |             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is
612 |             useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
613 | decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
614 | Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
615 | representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
616 | input (see `past_key_values`). This is useful if you want more control over how to convert
617 | `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
618 | If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
619 | of `inputs_embeds`.
620 | use_cache (`bool`, *optional*):
621 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
622 | `past_key_values`).
623 | output_attentions (`bool`, *optional*):
624 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
625 | tensors for more detail.
626 | output_hidden_states (`bool`, *optional*):
627 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
628 | more detail.
629 | return_dict (`bool`, *optional*):
630 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
631 | """
632 |
633 |
634 | class BartEncoder(BartPretrainedModel):
635 | """
636 | Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
637 | [`BartEncoderLayer`].
638 | Args:
639 | config: BartConfig
640 | embed_tokens (nn.Embedding): output embedding
641 | """
642 |
643 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
644 | super().__init__(config)
645 |
646 | self.dropout = config.dropout
647 | self.layerdrop = config.encoder_layerdrop
648 |
649 | embed_dim = config.d_model
650 | self.padding_idx = config.pad_token_id
651 | self.max_source_positions = config.max_position_embeddings
652 | self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
653 |
654 | if embed_tokens is not None:
655 | self.embed_tokens = embed_tokens
656 | else:
657 | self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
658 |
659 | self.embed_positions = BartLearnedPositionalEmbedding(
660 | config.max_position_embeddings,
661 | embed_dim,
662 | )
663 | self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
664 | self.layernorm_embedding = nn.LayerNorm(embed_dim)
665 |
666 | self.gradient_checkpointing = False
667 | # Initialize weights and apply final processing
668 | self.post_init()
669 |
670 | def get_input_embeddings(self):
671 | return self.embed_tokens
672 |
673 | def set_input_embeddings(self, value):
674 | self.embed_tokens = value
675 |
676 | def forward(
677 | self,
678 | input_ids: torch.LongTensor = None,
679 | attention_mask: Optional[torch.Tensor] = None,
680 | head_mask: Optional[torch.Tensor] = None,
681 | inputs_embeds: Optional[torch.FloatTensor] = None,
682 | output_attentions: Optional[bool] = None,
683 | output_hidden_states: Optional[bool] = None,
684 | return_dict: Optional[bool] = None,
685 | ) -> Union[Tuple, BaseModelOutput]:
686 | r"""
687 | Args:
688 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
689 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
690 | provide it.
691 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
692 | [`PreTrainedTokenizer.__call__`] for details.
693 | [What are input IDs?](../glossary#input-ids)
694 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
695 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
696 | - 1 for tokens that are **not masked**,
697 | - 0 for tokens that are **masked**.
698 | [What are attention masks?](../glossary#attention-mask)
699 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
700 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
701 | - 1 indicates the head is **not masked**,
702 | - 0 indicates the head is **masked**.
703 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
704 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
705 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors
706 | than the model's internal embedding lookup matrix.
707 | output_attentions (`bool`, *optional*):
708 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
709 | returned tensors for more detail.
710 | output_hidden_states (`bool`, *optional*):
711 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
712 | for more detail.
713 | return_dict (`bool`, *optional*):
714 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
715 | """
716 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
717 | output_hidden_states = (
718 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
719 | )
720 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
721 |
722 | # retrieve input_ids and inputs_embeds
723 | if input_ids is not None and inputs_embeds is not None:
724 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
725 | elif input_ids is not None:
726 | input_shape = input_ids.size()
727 | input_ids = input_ids.view(-1, input_shape[-1])
728 | elif inputs_embeds is not None:
729 | raise ValueError("You cannot specify inputs_embeds for now")
730 | else:
731 | raise ValueError("You have to specify either input_ids or inputs_embeds")
732 |
733 |         # split input_ids and attention_mask into chunks: a boundary is taken where eos (token id 2) immediately follows token id 1437 (the '' separator)
734 | input_ids_list = []
735 | attention_mask_list = []
736 | prev_token = 0
737 | prev_i = 1
738 | n_sep = 0
739 | for i, token in enumerate(input_ids[0]):
740 | if token == 2 and prev_token == 1437:
741 | input_ids_list.append(torch.cat((torch.cat((torch.tensor([0], device="cuda"), torch.tensor(input_ids[0][prev_i:i-1])))[:min(len(input_ids[0][prev_i-1:i-1]), self.config.max_position_embeddings-1)], torch.tensor([2], device="cuda"))).unsqueeze(dim=0))
742 | attention_mask_list.append(torch.tensor(attention_mask[0][prev_i-1:i])[:self.config.max_position_embeddings].unsqueeze(dim=0))
743 | prev_i = i + 1
744 | n_sep += 1
745 | prev_token = token
746 |
747 | input_ids_list.append(torch.cat((torch.cat((torch.tensor([0], device="cuda"), torch.tensor(input_ids[0][prev_i:])))[:min(len(input_ids[0][prev_i:]), self.config.max_position_embeddings-1)], torch.tensor([2], device="cuda"))).unsqueeze(dim=0))
748 | attention_mask_list.append((torch.tensor(attention_mask[0][prev_i-1:]))[:self.config.max_position_embeddings].unsqueeze(dim=0))
749 |
750 | l = max(list(map(lambda x: len(x[0]), input_ids_list)))
751 | input_ids = torch.ones(len(input_ids_list), l, dtype=torch.int32, device="cuda")
752 | attention_mask = torch.zeros(len(attention_mask_list), l, dtype=torch.int32, device="cuda")
753 | for i, (_input_ids, _attention_mask) in enumerate(zip(input_ids_list, attention_mask_list)):
754 | input_ids[i, :len(_input_ids[0])] = _input_ids[0]
755 | attention_mask[i, :len(_attention_mask[0])] = _attention_mask[0]
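    |         # After the split, the batch dimension is repurposed as the document dimension (Fusion-in-Decoder
    |         # style): input_ids and attention_mask now have shape (num_chunks, max_chunk_len), padded with the
    |         # pad id 1 / mask value 0, so each chunk is encoded independently by the layers below.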
756 |
757 | if inputs_embeds is None:
758 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
759 |
760 | input_shape = input_ids.size()
761 | embed_pos = self.embed_positions(input_shape)
762 |
763 | hidden_states = inputs_embeds + embed_pos
764 | hidden_states = self.layernorm_embedding(hidden_states)
765 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
766 |
767 | # expand attention_mask
768 | if attention_mask is not None:
769 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
770 | original_attention_mask = attention_mask
771 | attention_mask = _expand_mask(attention_mask, inputs_embeds.dtype)
772 |
773 | encoder_states = () if output_hidden_states else None
774 | all_attentions = () if output_attentions else None
775 |
776 | # check if head_mask has a correct number of layers specified if desired
777 | if head_mask is not None:
778 | if head_mask.size()[0] != (len(self.layers)):
779 | raise ValueError(
780 | f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
781 | )
782 |
783 | for idx, encoder_layer in enumerate(self.layers):
784 | if output_hidden_states:
785 | encoder_states = encoder_states + (hidden_states,)
786 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
787 | dropout_probability = random.uniform(0, 1)
788 | if self.training and (dropout_probability < self.layerdrop): # skip the layer
789 | layer_outputs = (None, None)
790 | else:
791 | if self.gradient_checkpointing and self.training:
792 |
793 | def create_custom_forward(module):
794 | def custom_forward(*inputs):
795 | return module(*inputs, output_attentions)
796 |
797 | return custom_forward
798 |
799 | layer_outputs = torch.utils.checkpoint.checkpoint(
800 | create_custom_forward(encoder_layer),
801 | hidden_states,
802 | attention_mask,
803 | (head_mask[idx] if head_mask is not None else None),
804 | )
805 | else:
806 | layer_outputs = encoder_layer(
807 | hidden_states,
808 | attention_mask,
809 | layer_head_mask=(head_mask[idx] if head_mask is not None else None),
810 | output_attentions=output_attentions,
811 | )
812 |
813 | hidden_states = layer_outputs[0]
814 |
815 | if output_attentions:
816 | all_attentions = all_attentions + (layer_outputs[1],)
817 |
818 | if output_hidden_states:
819 | encoder_states = encoder_states + (hidden_states,)
820 |
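    |         # Weighting step (QFiD): the variable names below suggest the first chunk is the query (chapter
    |         # title) and the remaining chunks are cited abstracts. Mean-pooled hidden states of each abstract
    |         # chunk are compared to the query chunk by dot product, a softmax over chunks turns the scores
    |         # into weights in (0, 1), and every abstract token state is scaled by (1 + weight) before all
    |         # unmasked token states are concatenated into one sequence for the decoder's cross-attention.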
821 | matching_list = []
822 | for i, (_hidden_states, _attention_mask) in enumerate(zip(hidden_states, original_attention_mask)):
823 | if i == 0:
824 | title_hidden_states = torch.mean(_hidden_states.index_select(0, torch.tensor([i for i, mask in enumerate(_attention_mask) if mask == 1], device="cuda")), dim=0)
825 | else:
826 | abst_hidden_states = torch.mean(_hidden_states.index_select(0, torch.tensor([i for i, mask in enumerate(_attention_mask) if mask == 1], device="cuda")), dim=0)
827 | matching_list.append(torch.dot(title_hidden_states, abst_hidden_states).item())
828 |
829 | f = torch.nn.Softmax(dim=0)
830 | matching_list = f(torch.tensor(matching_list, dtype=torch.float32))
831 | try:
832 | matching_list = (1 + matching_list).tolist()
833 |         except Exception:  # fall back to uniform weights if the scores cannot be converted
834 |             matching_list = [1 for i in range(len(matching_list))]
835 |
836 | for i, (_hidden_states, _attention_mask) in enumerate(zip(hidden_states, original_attention_mask)):
837 | if i == 0:
838 | return_hidden_states = _hidden_states.index_select(0, torch.tensor([i for i, mask in enumerate(_attention_mask) if mask == 1], device="cuda"))
839 | else:
840 | return_hidden_states = torch.cat((return_hidden_states, matching_list[i-1] * _hidden_states.index_select(0, torch.tensor([i for i, mask in enumerate(_attention_mask) if mask == 1], device="cuda"))), 0)
841 |
842 | if not return_dict:
843 |             return tuple(v for v in [return_hidden_states.unsqueeze(0), encoder_states, all_attentions] if v is not None)
844 | return BaseModelOutput(
845 | last_hidden_state=return_hidden_states.unsqueeze(0), hidden_states=encoder_states, attentions=all_attentions
846 | )
847 |
848 |
849 | class BartDecoder(BartPretrainedModel):
850 | """
851 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`BartDecoderLayer`]
852 | Args:
853 | config: BartConfig
854 | embed_tokens (nn.Embedding): output embedding
855 | """
856 |
857 | def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
858 | super().__init__(config)
859 | self.dropout = config.dropout
860 | self.layerdrop = config.decoder_layerdrop
861 | self.padding_idx = config.pad_token_id
862 | self.max_target_positions = config.max_position_embeddings
863 | self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
864 |
865 | if embed_tokens is not None:
866 | self.embed_tokens = embed_tokens
867 | else:
868 | self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
869 |
870 | self.embed_positions = BartLearnedPositionalEmbedding(
871 | config.max_position_embeddings,
872 | config.d_model,
873 | )
874 | self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
875 | self.layernorm_embedding = nn.LayerNorm(config.d_model)
876 |
877 | self.gradient_checkpointing = False
878 | # Initialize weights and apply final processing
879 | self.post_init()
880 |
881 | def get_input_embeddings(self):
882 | return self.embed_tokens
883 |
884 | def set_input_embeddings(self, value):
885 | self.embed_tokens = value
886 |
887 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
888 | # create causal mask
889 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
890 | combined_attention_mask = None
891 | if input_shape[-1] > 1:
892 | combined_attention_mask = _make_causal_mask(
893 | input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
894 | ).to(self.device)
895 |
896 | if attention_mask is not None:
897 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
898 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
899 | combined_attention_mask = (
900 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
901 | )
902 |
903 | return combined_attention_mask
904 |
905 | def forward(
906 | self,
907 | input_ids: torch.LongTensor = None,
908 | attention_mask: Optional[torch.Tensor] = None,
909 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
910 | encoder_attention_mask: Optional[torch.LongTensor] = None,
911 | head_mask: Optional[torch.Tensor] = None,
912 | cross_attn_head_mask: Optional[torch.Tensor] = None,
913 | past_key_values: Optional[List[torch.FloatTensor]] = None,
914 | inputs_embeds: Optional[torch.FloatTensor] = None,
915 | use_cache: Optional[bool] = None,
916 | output_attentions: Optional[bool] = None,
917 | output_hidden_states: Optional[bool] = None,
918 | return_dict: Optional[bool] = None,
919 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
920 | r"""
921 | Args:
922 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
923 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
924 | provide it.
925 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
926 | [`PreTrainedTokenizer.__call__`] for details.
927 | [What are input IDs?](../glossary#input-ids)
928 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
929 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
930 | - 1 for tokens that are **not masked**,
931 | - 0 for tokens that are **masked**.
932 | [What are attention masks?](../glossary#attention-mask)
933 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
934 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
935 | of the decoder.
936 | encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
937 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
938 | selected in `[0, 1]`:
939 | - 1 for tokens that are **not masked**,
940 | - 0 for tokens that are **masked**.
941 | [What are attention masks?](../glossary#attention-mask)
942 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
943 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
944 | - 1 indicates the head is **not masked**,
945 | - 0 indicates the head is **masked**.
946 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
947 | Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
948 | cross-attention on hidden heads. Mask values selected in `[0, 1]`:
949 | - 1 indicates the head is **not masked**,
950 | - 0 indicates the head is **masked**.
951 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
952 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
953 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
954 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
955 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
956 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
957 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
958 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
959 |                 all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
960 |             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
961 |                 Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
962 |                 This is useful if you want more control over how to convert `input_ids` indices into associated vectors
963 |                 than the model's internal embedding lookup matrix.
964 | output_attentions (`bool`, *optional*):
965 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
966 | returned tensors for more detail.
967 | output_hidden_states (`bool`, *optional*):
968 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
969 | for more detail.
970 | return_dict (`bool`, *optional*):
971 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
972 | """
973 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
974 | output_hidden_states = (
975 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
976 | )
977 | use_cache = use_cache if use_cache is not None else self.config.use_cache
978 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
979 |
980 | # retrieve input_ids and inputs_embeds
981 | if input_ids is not None and inputs_embeds is not None:
982 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
983 | elif input_ids is not None:
984 | input_shape = input_ids.size()
985 | input_ids = input_ids.view(-1, input_shape[-1])
986 | elif inputs_embeds is not None:
987 | input_shape = inputs_embeds.size()[:-1]
988 | else:
989 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
990 |
991 | # past_key_values_length
992 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
993 |
994 | if inputs_embeds is None:
995 | inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
996 |
997 | attention_mask = self._prepare_decoder_attention_mask(
998 | attention_mask, input_shape, inputs_embeds, past_key_values_length
999 | )
1000 |
1001 |         # make encoder_attention_mask match the encoder output length (the commented-out line below would simply fill it with ones)
1002 | #encoder_attention_mask = torch.ones(encoder_hidden_states.shape[0], encoder_hidden_states.shape[1], device="cuda")
1003 | encoder_attention_mask = encoder_attention_mask.narrow(dim=1, start=0, length=encoder_hidden_states.shape[1])
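    |         # The encoder returns a single fused sequence of the selected chunk tokens, so the incoming mask
    |         # (sized for the original flat input) is narrowed to that fused length; since the fused sequence
    |         # carries no padding, this appears equivalent to the all-ones mask the commented-out line would build.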
1004 | # expand encoder attention mask
1005 | if encoder_hidden_states is not None and encoder_attention_mask is not None:
1006 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1007 | encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
1008 |
1009 | # embed positions
1010 | positions = self.embed_positions(input_shape, past_key_values_length)
1011 |
1012 | hidden_states = inputs_embeds + positions
1013 | hidden_states = self.layernorm_embedding(hidden_states)
1014 |
1015 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1016 |
1017 | # decoder layers
1018 | all_hidden_states = () if output_hidden_states else None
1019 | all_self_attns = () if output_attentions else None
1020 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
1021 | next_decoder_cache = () if use_cache else None
1022 |
1023 | # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
1024 | for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
1025 | if attn_mask is not None:
1026 | if attn_mask.size()[0] != (len(self.layers)):
1027 | raise ValueError(
1028 | "The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
1029 | )
1030 |
1031 | for idx, decoder_layer in enumerate(self.layers):
1032 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1033 | if output_hidden_states:
1034 | all_hidden_states += (hidden_states,)
1035 | dropout_probability = random.uniform(0, 1)
1036 | if self.training and (dropout_probability < self.layerdrop):
1037 | continue
1038 |
1039 | past_key_value = past_key_values[idx] if past_key_values is not None else None
1040 |
1041 | if self.gradient_checkpointing and self.training:
1042 |
1043 | if use_cache:
1044 |                     warnings.warn(
1045 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1046 | )
1047 | use_cache = False
1048 |
1049 | def create_custom_forward(module):
1050 | def custom_forward(*inputs):
1051 | # None for past_key_value
1052 | return module(*inputs, output_attentions, use_cache)
1053 |
1054 | return custom_forward
1055 |
1056 | layer_outputs = torch.utils.checkpoint.checkpoint(
1057 | create_custom_forward(decoder_layer),
1058 | hidden_states,
1059 | attention_mask,
1060 | encoder_hidden_states,
1061 | encoder_attention_mask,
1062 | head_mask[idx] if head_mask is not None else None,
1063 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
1064 | None,
1065 | )
1066 | else:
1067 | layer_outputs = decoder_layer(
1068 | hidden_states,
1069 | attention_mask=attention_mask,
1070 | encoder_hidden_states=encoder_hidden_states,
1071 | encoder_attention_mask=encoder_attention_mask,
1072 | layer_head_mask=(head_mask[idx] if head_mask is not None else None),
1073 | cross_attn_layer_head_mask=(
1074 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
1075 | ),
1076 | past_key_value=past_key_value,
1077 | output_attentions=output_attentions,
1078 | use_cache=use_cache,
1079 | )
1080 | hidden_states = layer_outputs[0]
1081 |
1082 | if use_cache:
1083 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
1084 |
1085 | if output_attentions:
1086 | all_self_attns += (layer_outputs[1],)
1087 |
1088 | if encoder_hidden_states is not None:
1089 | all_cross_attentions += (layer_outputs[2],)
1090 |
1091 | # add hidden states from the last decoder layer
1092 | if output_hidden_states:
1093 | all_hidden_states += (hidden_states,)
1094 |
1095 | next_cache = next_decoder_cache if use_cache else None
1096 | if not return_dict:
1097 | return tuple(
1098 | v
1099 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
1100 | if v is not None
1101 | )
1102 | return BaseModelOutputWithPastAndCrossAttentions(
1103 | last_hidden_state=hidden_states,
1104 | past_key_values=next_cache,
1105 | hidden_states=all_hidden_states,
1106 | attentions=all_self_attns,
1107 | cross_attentions=all_cross_attentions,
1108 | )
1109 |
1110 |
1111 | class BartModel(BartPretrainedModel):
1112 | def __init__(self, config: BartConfig):
1113 | super().__init__(config)
1114 |
1115 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size
1116 | self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
1117 |
1118 | self.encoder = BartEncoder(config, self.shared)
1119 | self.decoder = BartDecoder(config, self.shared)
1120 |
1121 | # Initialize weights and apply final processing
1122 | self.post_init()
1123 |
1124 | def get_input_embeddings(self):
1125 | return self.shared
1126 |
1127 | def set_input_embeddings(self, value):
1128 | self.shared = value
1129 | self.encoder.embed_tokens = self.shared
1130 | self.decoder.embed_tokens = self.shared
1131 |
1132 | def get_encoder(self):
1133 | return self.encoder
1134 |
1135 | def get_decoder(self):
1136 | return self.decoder
1137 |
1138 | def forward(
1139 | self,
1140 | input_ids: torch.LongTensor = None,
1141 | attention_mask: Optional[torch.Tensor] = None,
1142 | decoder_input_ids: Optional[torch.LongTensor] = None,
1143 | decoder_attention_mask: Optional[torch.LongTensor] = None,
1144 | head_mask: Optional[torch.Tensor] = None,
1145 | decoder_head_mask: Optional[torch.Tensor] = None,
1146 | cross_attn_head_mask: Optional[torch.Tensor] = None,
1147 | encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1148 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1149 | inputs_embeds: Optional[torch.FloatTensor] = None,
1150 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1151 | use_cache: Optional[bool] = None,
1152 | output_attentions: Optional[bool] = None,
1153 | output_hidden_states: Optional[bool] = None,
1154 | return_dict: Optional[bool] = None,
1155 | ) -> Union[Tuple, Seq2SeqModelOutput]:
1156 |
1157 |         # unlike other models, Bart automatically creates decoder_input_ids from
1158 |         # input_ids if no decoder_input_ids are provided
1159 | if decoder_input_ids is None and decoder_inputs_embeds is None:
1160 | if input_ids is None:
1161 | raise ValueError(
1162 | "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
1163 | "passed, `input_ids` cannot be `None`. Please pass either "
1164 | "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
1165 | )
1166 |
1167 | decoder_input_ids = shift_tokens_right(
1168 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id
1169 | )
1170 |
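        # Schematic illustration (ids only) of what shift_tokens_right produces here:
        #   input_ids         = [[<s>, x1, x2, </s>]]
        #   decoder_input_ids = [[decoder_start, <s>, x1, x2]]
        # i.e. the sequence is shifted one position to the right and the first slot is
        # filled with decoder_start_token_id.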
1171 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1172 | output_hidden_states = (
1173 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1174 | )
1175 | use_cache = use_cache if use_cache is not None else self.config.use_cache
1176 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1177 |
1178 | if encoder_outputs is None:
1179 | encoder_outputs = self.encoder(
1180 | input_ids=input_ids,
1181 | attention_mask=attention_mask,
1182 | head_mask=head_mask,
1183 | inputs_embeds=inputs_embeds,
1184 | output_attentions=output_attentions,
1185 | output_hidden_states=output_hidden_states,
1186 | return_dict=return_dict,
1187 | )
1188 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
1189 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1190 | encoder_outputs = BaseModelOutput(
1191 | last_hidden_state=encoder_outputs[0],
1192 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1193 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1194 | )
1195 |
1196 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
1197 | decoder_outputs = self.decoder(
1198 | input_ids=decoder_input_ids,
1199 | attention_mask=decoder_attention_mask,
1200 | encoder_hidden_states=encoder_outputs[0],
1201 | encoder_attention_mask=attention_mask,
1202 | head_mask=decoder_head_mask,
1203 | cross_attn_head_mask=cross_attn_head_mask,
1204 | past_key_values=past_key_values,
1205 | inputs_embeds=decoder_inputs_embeds,
1206 | use_cache=use_cache,
1207 | output_attentions=output_attentions,
1208 | output_hidden_states=output_hidden_states,
1209 | return_dict=return_dict,
1210 | )
1211 |
1212 | if not return_dict:
1213 | return decoder_outputs + encoder_outputs
1214 |
1215 | return Seq2SeqModelOutput(
1216 | last_hidden_state=decoder_outputs.last_hidden_state,
1217 | past_key_values=decoder_outputs.past_key_values,
1218 | decoder_hidden_states=decoder_outputs.hidden_states,
1219 | decoder_attentions=decoder_outputs.attentions,
1220 | cross_attentions=decoder_outputs.cross_attentions,
1221 | encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1222 | encoder_hidden_states=encoder_outputs.hidden_states,
1223 | encoder_attentions=encoder_outputs.attentions,
1224 | )
1225 |
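# Illustrative usage of the bare seq2seq backbone above (a hedged sketch; the checkpoint
# name below is only an example, any BART-compatible weights would do):
#
#     from transformers import BartTokenizer
#     tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
#     model = BartModel.from_pretrained("facebook/bart-base")
#     inputs = tokenizer("Some input text.", return_tensors="pt")
#     outputs = model(**inputs)               # decoder_input_ids derived from input_ids
#     hidden = outputs.last_hidden_state      # shape: (batch, seq_len, d_model)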
1226 |
1227 | class BartForConditionalGeneration(BartPretrainedModel):
1228 | base_model_prefix = "model"
1229 | _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head\.weight"]
1230 |
1231 | def __init__(self, config: BartConfig):
1232 | super().__init__(config)
1233 | self.model = BartModel(config)
1234 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
1235 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
1236 |
1237 | # Initialize weights and apply final processing
1238 | self.post_init()
1239 |
1240 | def get_encoder(self):
1241 | return self.model.get_encoder()
1242 |
1243 | def get_decoder(self):
1244 | return self.model.get_decoder()
1245 |
1246 | def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
1247 | new_embeddings = super().resize_token_embeddings(new_num_tokens)
1248 | self._resize_final_logits_bias(new_num_tokens)
1249 | return new_embeddings
1250 |
1251 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
1252 | old_num_tokens = self.final_logits_bias.shape[-1]
1253 | if new_num_tokens <= old_num_tokens:
1254 | new_bias = self.final_logits_bias[:, :new_num_tokens]
1255 | else:
1256 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
1257 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
1258 | self.register_buffer("final_logits_bias", new_bias)
1259 |
1260 | def get_output_embeddings(self):
1261 | return self.lm_head
1262 |
1263 | def set_output_embeddings(self, new_embeddings):
1264 | self.lm_head = new_embeddings
1265 |
1266 | def forward(
1267 | self,
1268 | input_ids: torch.LongTensor = None,
1269 | attention_mask: Optional[torch.Tensor] = None,
1270 | decoder_input_ids: Optional[torch.LongTensor] = None,
1271 | decoder_attention_mask: Optional[torch.LongTensor] = None,
1272 | head_mask: Optional[torch.Tensor] = None,
1273 | decoder_head_mask: Optional[torch.Tensor] = None,
1274 | cross_attn_head_mask: Optional[torch.Tensor] = None,
1275 | encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1276 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1277 | inputs_embeds: Optional[torch.FloatTensor] = None,
1278 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1279 | labels: Optional[torch.LongTensor] = None,
1280 | use_cache: Optional[bool] = None,
1281 | output_attentions: Optional[bool] = None,
1282 | output_hidden_states: Optional[bool] = None,
1283 | return_dict: Optional[bool] = None,
1284 | ) -> Union[Tuple, Seq2SeqLMOutput]:
1285 | r"""
1286 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1287 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1288 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1289 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1290 | Returns:
1291 | """
1292 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1293 |
1294 | if labels is not None:
1295 | if use_cache:
1296 | print("The `use_cache` argument is changed to `False` since `labels` is provided.")
1297 | use_cache = False
1298 | if decoder_input_ids is None and decoder_inputs_embeds is None:
1299 | decoder_input_ids = shift_tokens_right(
1300 | labels, self.config.pad_token_id, self.config.decoder_start_token_id
1301 | )
1302 |
1303 | outputs = self.model(
1304 | input_ids,
1305 | attention_mask=attention_mask,
1306 | decoder_input_ids=decoder_input_ids,
1307 | encoder_outputs=encoder_outputs,
1308 | decoder_attention_mask=decoder_attention_mask,
1309 | head_mask=head_mask,
1310 | decoder_head_mask=decoder_head_mask,
1311 | cross_attn_head_mask=cross_attn_head_mask,
1312 | past_key_values=past_key_values,
1313 | inputs_embeds=inputs_embeds,
1314 | decoder_inputs_embeds=decoder_inputs_embeds,
1315 | use_cache=use_cache,
1316 | output_attentions=output_attentions,
1317 | output_hidden_states=output_hidden_states,
1318 | return_dict=return_dict,
1319 | )
1320 | lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
1321 |
1322 | masked_lm_loss = None
1323 | if labels is not None:
1324 | loss_fct = CrossEntropyLoss()
1325 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
1326 |
1327 | if not return_dict:
1328 | output = (lm_logits,) + outputs[1:]
1329 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1330 |
1331 | return Seq2SeqLMOutput(
1332 | loss=masked_lm_loss,
1333 | logits=lm_logits,
1334 | past_key_values=outputs.past_key_values,
1335 | decoder_hidden_states=outputs.decoder_hidden_states,
1336 | decoder_attentions=outputs.decoder_attentions,
1337 | cross_attentions=outputs.cross_attentions,
1338 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1339 | encoder_hidden_states=outputs.encoder_hidden_states,
1340 | encoder_attentions=outputs.encoder_attentions,
1341 | )
1342 |
1343 | def prepare_inputs_for_generation(
1344 | self,
1345 | decoder_input_ids,
1346 | past=None,
1347 | attention_mask=None,
1348 | head_mask=None,
1349 | decoder_head_mask=None,
1350 | cross_attn_head_mask=None,
1351 | use_cache=None,
1352 | encoder_outputs=None,
1353 | **kwargs
1354 | ):
1355 | # cut decoder_input_ids if past is used
1356 | if past is not None:
1357 | decoder_input_ids = decoder_input_ids[:, -1:]
1358 |
1359 | return {
1360 | "input_ids": None, # encoder_outputs is defined. input_ids not needed
1361 | "encoder_outputs": encoder_outputs,
1362 | "past_key_values": past,
1363 | "decoder_input_ids": decoder_input_ids,
1364 | "attention_mask": attention_mask,
1365 | "head_mask": head_mask,
1366 | "decoder_head_mask": decoder_head_mask,
1367 | "cross_attn_head_mask": cross_attn_head_mask,
1368 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
1369 | }
1370 |
1371 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1372 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
1373 |
1374 | @staticmethod
1375 | def _reorder_cache(past, beam_idx):
1376 | reordered_past = ()
1377 | for layer_past in past:
1378 | # cached cross_attention states don't have to be reordered -> they are always the same
1379 | reordered_past += (
1380 | tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
1381 | )
1382 | return reordered_past
1383 |
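# Illustrative generation sketch for the LM-head model above (hedged; the checkpoint
# name and the input text are placeholders only):
#
#     from transformers import BartTokenizer
#     tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
#     model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
#     batch = tokenizer(["text to summarize ..."], return_tensors="pt", truncation=True)
#     ids = model.generate(batch["input_ids"], num_beams=4, max_length=256)
#     print(tokenizer.batch_decode(ids, skip_special_tokens=True))
#
# When `labels` are passed to forward(), they are shifted into decoder_input_ids and a
# token-level cross-entropy loss is returned as `outputs.loss`.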
1384 |
1385 | class BartForSequenceClassification(BartPretrainedModel):
1386 | def __init__(self, config: BartConfig, **kwargs):
1387 | super().__init__(config, **kwargs)
1388 | self.model = BartModel(config)
1389 | self.classification_head = BartClassificationHead(
1390 | config.d_model,
1391 | config.d_model,
1392 | config.num_labels,
1393 | config.classifier_dropout,
1394 | )
1395 | self.model._init_weights(self.classification_head.dense)
1396 | self.model._init_weights(self.classification_head.out_proj)
1397 |
1398 | def forward(
1399 | self,
1400 | input_ids: torch.LongTensor = None,
1401 | attention_mask: Optional[torch.Tensor] = None,
1402 | decoder_input_ids: Optional[torch.LongTensor] = None,
1403 | decoder_attention_mask: Optional[torch.LongTensor] = None,
1404 | head_mask: Optional[torch.Tensor] = None,
1405 | decoder_head_mask: Optional[torch.Tensor] = None,
1406 | cross_attn_head_mask: Optional[torch.Tensor] = None,
1407 | encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1408 | inputs_embeds: Optional[torch.FloatTensor] = None,
1409 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1410 | labels: Optional[torch.LongTensor] = None,
1411 | use_cache: Optional[bool] = None,
1412 | output_attentions: Optional[bool] = None,
1413 | output_hidden_states: Optional[bool] = None,
1414 | return_dict: Optional[bool] = None,
1415 | ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
1416 | r"""
1417 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1418 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1419 | config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1420 | """
1421 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1422 | if labels is not None:
1423 | use_cache = False
1424 |
1425 | if input_ids is None and inputs_embeds is not None:
1426 | raise NotImplementedError(
1427 | f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
1428 | )
1429 |
1430 | outputs = self.model(
1431 | input_ids,
1432 | attention_mask=attention_mask,
1433 | decoder_input_ids=decoder_input_ids,
1434 | decoder_attention_mask=decoder_attention_mask,
1435 | head_mask=head_mask,
1436 | decoder_head_mask=decoder_head_mask,
1437 | cross_attn_head_mask=cross_attn_head_mask,
1438 | encoder_outputs=encoder_outputs,
1439 | inputs_embeds=inputs_embeds,
1440 | decoder_inputs_embeds=decoder_inputs_embeds,
1441 | use_cache=use_cache,
1442 | output_attentions=output_attentions,
1443 | output_hidden_states=output_hidden_states,
1444 | return_dict=return_dict,
1445 | )
1446 | hidden_states = outputs[0] # last hidden state
1447 |
1448 | eos_mask = input_ids.eq(self.config.eos_token_id)
1449 |
1450 | if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
1451 |             raise ValueError("All examples must have the same number of <eos> tokens.")
1452 | sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
1453 | :, -1, :
1454 | ]
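        # hidden_states[eos_mask, :] keeps only the positions of <eos> tokens; after the
        # reshape to (batch, num_eos, d_model), [:, -1, :] selects the hidden state of the
        # final <eos> token in each sequence as the sentence representation.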
1455 | logits = self.classification_head(sentence_representation)
1456 |
1457 | loss = None
1458 | if labels is not None:
1459 | if self.config.problem_type is None:
1460 | if self.config.num_labels == 1:
1461 | self.config.problem_type = "regression"
1462 | elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1463 | self.config.problem_type = "single_label_classification"
1464 | else:
1465 | self.config.problem_type = "multi_label_classification"
1466 |
1467 | if self.config.problem_type == "regression":
1468 | loss_fct = MSELoss()
1469 | if self.config.num_labels == 1:
1470 | loss = loss_fct(logits.squeeze(), labels.squeeze())
1471 | else:
1472 | loss = loss_fct(logits, labels)
1473 | elif self.config.problem_type == "single_label_classification":
1474 | loss_fct = CrossEntropyLoss()
1475 | loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
1476 | elif self.config.problem_type == "multi_label_classification":
1477 | loss_fct = BCEWithLogitsLoss()
1478 | loss = loss_fct(logits, labels)
1479 | if not return_dict:
1480 | output = (logits,) + outputs[1:]
1481 | return ((loss,) + output) if loss is not None else output
1482 |
1483 | return Seq2SeqSequenceClassifierOutput(
1484 | loss=loss,
1485 | logits=logits,
1486 | past_key_values=outputs.past_key_values,
1487 | decoder_hidden_states=outputs.decoder_hidden_states,
1488 | decoder_attentions=outputs.decoder_attentions,
1489 | cross_attentions=outputs.cross_attentions,
1490 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1491 | encoder_hidden_states=outputs.encoder_hidden_states,
1492 | encoder_attentions=outputs.encoder_attentions,
1493 | )
1494 |
1495 |
1496 | class BartForQuestionAnswering(BartPretrainedModel):
1497 | def __init__(self, config):
1498 | super().__init__(config)
1499 |
1500 | config.num_labels = 2
1501 | self.num_labels = config.num_labels
1502 |
1503 | self.model = BartModel(config)
1504 | self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1505 |
1506 | self.model._init_weights(self.qa_outputs)
1507 |
1508 | def forward(
1509 | self,
1510 | input_ids: torch.Tensor = None,
1511 | attention_mask: Optional[torch.Tensor] = None,
1512 | decoder_input_ids: Optional[torch.LongTensor] = None,
1513 | decoder_attention_mask: Optional[torch.LongTensor] = None,
1514 | head_mask: Optional[torch.Tensor] = None,
1515 | decoder_head_mask: Optional[torch.Tensor] = None,
1516 | cross_attn_head_mask: Optional[torch.Tensor] = None,
1517 | encoder_outputs: Optional[List[torch.FloatTensor]] = None,
1518 | start_positions: Optional[torch.LongTensor] = None,
1519 | end_positions: Optional[torch.LongTensor] = None,
1520 | inputs_embeds: Optional[torch.FloatTensor] = None,
1521 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1522 | use_cache: Optional[bool] = None,
1523 | output_attentions: Optional[bool] = None,
1524 | output_hidden_states: Optional[bool] = None,
1525 | return_dict: Optional[bool] = None,
1526 | ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]:
1527 | r"""
1528 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1529 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
1530 |             Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the
1531 |             sequence are not taken into account for computing the loss.
1532 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1533 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
1534 |             Positions are clamped to the length of the sequence (*sequence_length*). Positions outside of the
1535 |             sequence are not taken into account for computing the loss.
1536 | """
1537 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1538 | if start_positions is not None and end_positions is not None:
1539 | use_cache = False
1540 |
1541 | outputs = self.model(
1542 | input_ids,
1543 | attention_mask=attention_mask,
1544 | decoder_input_ids=decoder_input_ids,
1545 | decoder_attention_mask=decoder_attention_mask,
1546 | head_mask=head_mask,
1547 | decoder_head_mask=decoder_head_mask,
1548 | cross_attn_head_mask=cross_attn_head_mask,
1549 | encoder_outputs=encoder_outputs,
1550 | inputs_embeds=inputs_embeds,
1551 | decoder_inputs_embeds=decoder_inputs_embeds,
1552 | use_cache=use_cache,
1553 | output_attentions=output_attentions,
1554 | output_hidden_states=output_hidden_states,
1555 | return_dict=return_dict,
1556 | )
1557 |
1558 | sequence_output = outputs[0]
1559 |
1560 | logits = self.qa_outputs(sequence_output)
1561 | start_logits, end_logits = logits.split(1, dim=-1)
1562 | start_logits = start_logits.squeeze(-1).contiguous()
1563 | end_logits = end_logits.squeeze(-1).contiguous()
1564 |
1565 | total_loss = None
1566 | if start_positions is not None and end_positions is not None:
1567 | # If we are on multi-GPU, split add a dimension
1568 | if len(start_positions.size()) > 1:
1569 | start_positions = start_positions.squeeze(-1)
1570 | if len(end_positions.size()) > 1:
1571 | end_positions = end_positions.squeeze(-1)
1572 |             # sometimes the start/end positions are outside our model inputs; we ignore these terms
1573 | ignored_index = start_logits.size(1)
1574 | start_positions = start_positions.clamp(0, ignored_index)
1575 | end_positions = end_positions.clamp(0, ignored_index)
1576 |
1577 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1578 | start_loss = loss_fct(start_logits, start_positions)
1579 | end_loss = loss_fct(end_logits, end_positions)
1580 | total_loss = (start_loss + end_loss) / 2
1581 |
1582 | if not return_dict:
1583 | output = (
1584 | start_logits,
1585 | end_logits,
1586 | ) + outputs[1:]
1587 | return ((total_loss,) + output) if total_loss is not None else output
1588 |
1589 | return Seq2SeqQuestionAnsweringModelOutput(
1590 | loss=total_loss,
1591 | start_logits=start_logits,
1592 | end_logits=end_logits,
1593 | past_key_values=outputs.past_key_values,
1594 | decoder_hidden_states=outputs.decoder_hidden_states,
1595 | decoder_attentions=outputs.decoder_attentions,
1596 | cross_attentions=outputs.cross_attentions,
1597 | encoder_last_hidden_state=outputs.encoder_last_hidden_state,
1598 | encoder_hidden_states=outputs.encoder_hidden_states,
1599 | encoder_attentions=outputs.encoder_attentions,
1600 | )
1601 |
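# Illustrative span decoding for the QA head above (a hedged sketch of typical usage):
#
#     start = start_logits.argmax(dim=-1)              # (batch,)
#     end = end_logits.argmax(dim=-1)                  # (batch,)
#     answer_ids = input_ids[0, start[0] : end[0] + 1]
#     # answer_ids can then be decoded back to text with the tokenizer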
1602 |
1603 | class BartDecoderWrapper(BartPretrainedModel):
1604 | """
1605 | This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
1606 | used in combination with the [`EncoderDecoderModel`] framework.
1607 | """
1608 |
1609 | def __init__(self, config):
1610 | super().__init__(config)
1611 | self.decoder = BartDecoder(config)
1612 |
1613 | def forward(self, *args, **kwargs):
1614 | return self.decoder(*args, **kwargs)
1615 |
1616 |
1617 | class BartForCausalLM(BartPretrainedModel):
1618 | def __init__(self, config):
1619 | config = copy.deepcopy(config)
1620 | config.is_decoder = True
1621 | config.is_encoder_decoder = False
1622 | super().__init__(config)
1623 | self.model = BartDecoderWrapper(config)
1624 |
1625 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1626 |
1627 | # Initialize weights and apply final processing
1628 | self.post_init()
1629 |
1630 | def get_input_embeddings(self):
1631 | return self.model.decoder.embed_tokens
1632 |
1633 | def set_input_embeddings(self, value):
1634 | self.model.decoder.embed_tokens = value
1635 |
1636 | def get_output_embeddings(self):
1637 | return self.lm_head
1638 |
1639 | def set_output_embeddings(self, new_embeddings):
1640 | self.lm_head = new_embeddings
1641 |
1642 | def set_decoder(self, decoder):
1643 | self.model.decoder = decoder
1644 |
1645 | def get_decoder(self):
1646 | return self.model.decoder
1647 |
1648 | def forward(
1649 | self,
1650 | input_ids: torch.LongTensor = None,
1651 | attention_mask: Optional[torch.Tensor] = None,
1652 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
1653 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
1654 | head_mask: Optional[torch.Tensor] = None,
1655 | cross_attn_head_mask: Optional[torch.Tensor] = None,
1656 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1657 | inputs_embeds: Optional[torch.FloatTensor] = None,
1658 | labels: Optional[torch.LongTensor] = None,
1659 | use_cache: Optional[bool] = None,
1660 | output_attentions: Optional[bool] = None,
1661 | output_hidden_states: Optional[bool] = None,
1662 | return_dict: Optional[bool] = None,
1663 | ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1664 | r"""
1665 | Args:
1666 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1667 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
1668 | provide it.
1669 | Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1670 | [`PreTrainedTokenizer.__call__`] for details.
1671 | [What are input IDs?](../glossary#input-ids)
1672 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1673 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1674 | - 1 for tokens that are **not masked**,
1675 | - 0 for tokens that are **masked**.
1676 | [What are attention masks?](../glossary#attention-mask)
1677 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1678 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1679 | if the model is configured as a decoder.
1680 | encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1681 | Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
1682 | in the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1683 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1684 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
1685 | - 1 indicates the head is **not masked**,
1686 | - 0 indicates the head is **masked**.
1687 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
1688 | Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
1689 | - 1 indicates the head is **not masked**,
1690 | - 0 indicates the head is **masked**.
1691 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1692 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
1693 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
1694 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
1695 | tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
1696 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
1697 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
1698 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
1699 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
1700 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1701 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1702 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1703 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1704 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1705 | use_cache (`bool`, *optional*):
1706 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
1707 | (see `past_key_values`).
1710 | output_attentions (`bool`, *optional*):
1711 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1712 | returned tensors for more detail.
1713 | output_hidden_states (`bool`, *optional*):
1714 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1715 | for more detail.
1716 | return_dict (`bool`, *optional*):
1717 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1718 | Returns:
1719 | Example:
1720 | ```python
1721 | >>> from transformers import BartTokenizer, BartForCausalLM
1722 | >>> tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
1723 | >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
1724 | >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
1725 | >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1726 | >>> outputs = model(**inputs)
1727 | >>> logits = outputs.logits
1728 | >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
1729 | >>> list(logits.shape) == expected_shape
1730 | True
1731 | ```"""
1732 |
1733 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1734 | output_hidden_states = (
1735 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1736 | )
1737 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1738 |
1739 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1740 | outputs = self.model.decoder(
1741 | input_ids=input_ids,
1742 | attention_mask=attention_mask,
1743 | encoder_hidden_states=encoder_hidden_states,
1744 | encoder_attention_mask=encoder_attention_mask,
1745 | head_mask=head_mask,
1746 | cross_attn_head_mask=cross_attn_head_mask,
1747 | past_key_values=past_key_values,
1748 | inputs_embeds=inputs_embeds,
1749 | use_cache=use_cache,
1750 | output_attentions=output_attentions,
1751 | output_hidden_states=output_hidden_states,
1752 | return_dict=return_dict,
1753 | )
1754 |
1755 | logits = self.lm_head(outputs[0])
1756 |
1757 | loss = None
1758 | if labels is not None:
1759 | loss_fct = CrossEntropyLoss()
1760 | loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
1761 |
1762 | if not return_dict:
1763 | output = (logits,) + outputs[1:]
1764 | return (loss,) + output if loss is not None else output
1765 |
1766 | return CausalLMOutputWithCrossAttentions(
1767 | loss=loss,
1768 | logits=logits,
1769 | past_key_values=outputs.past_key_values,
1770 | hidden_states=outputs.hidden_states,
1771 | attentions=outputs.attentions,
1772 | cross_attentions=outputs.cross_attentions,
1773 | )
1774 |
1775 | def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
1776 | # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1777 | if attention_mask is None:
1778 | attention_mask = input_ids.new_ones(input_ids.shape)
1779 |
1780 | if past:
1781 | input_ids = input_ids[:, -1:]
1782 |         # on the first generation step `past` is empty, so the full input_ids are kept
1783 | return {
1784 |             "input_ids": input_ids,
1785 | "attention_mask": attention_mask,
1786 | "past_key_values": past,
1787 | "use_cache": use_cache,
1788 | }
1789 |
1790 | @staticmethod
1791 | def _reorder_cache(past, beam_idx):
1792 | reordered_past = ()
1793 | for layer_past in past:
1794 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1795 | return reordered_past
--------------------------------------------------------------------------------