├── .gitignore ├── README.md ├── data ├── test.json └── train.json ├── examples ├── evaluation.ipynb ├── evaluation.py ├── evaluation.sh ├── knn_classifier.pkl ├── train_knn.py └── train_knn_classifier.ipynb ├── models └── README.md ├── requirements.txt ├── rl_training ├── rl_utils.py ├── train.py ├── train_rl.sh └── training_config.yaml ├── setup.py └── src └── EntFA ├── model.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | *.py[cod] 6 | *$py.class 7 | *.out 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | 150 | # PyCharm 151 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 152 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 153 | # and can be added to the global gitignore or merged into this file. For a more nuclear 154 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 155 | #.idea/ 156 | 157 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Inspecting the Factuality of Hallucinations in Abstractive Summarization 2 | 3 | This directory contains code necessary to replicate the training and evaluation for the ACL 2022 paper ["Hallucinated but Factual! Inspecting the Factuality of Hallucinations in Abstractive Summarization"](https://arxiv.org/pdf/2109.09784.pdf) by [Meng Cao](https://mcao516.github.io/), [Yue Dong](https://www.cs.mcgill.ca/~ydong26/) and [Jackie Chi Kit Cheung](https://www.cs.mcgill.ca/~jcheung/). 4 | 5 | ## Dependencies and Setup 6 | The code is based on Huggingface's [Transformers](https://github.com/huggingface/transformers) library. 7 | ``` 8 | git clone https://github.com/mcao516/EntFA.git 9 | cd ./EntFA 10 | pip install -r requirements.txt 11 | python setup.py install 12 | ``` 13 | 14 | ## How to Run 15 | Conditional masked language model (CMLM) checkpoint can be found [here](https://drive.google.com/drive/folders/10ibVc5R7q4Gc0TH1AIRo7IaLCV83SkpF?usp=sharing). For masked language model (MLM), download `bart.large` at Fairseq's [BART](https://github.com/pytorch/fairseq/tree/main/examples/bart) repository. Download CMLM and MLM, put them in the models directory. 16 | 17 | ### Train KNN Classifier 18 | ```bash 19 | OUTPUT_DIR=knn_checkpoint 20 | mkdir $OUTPUT_DIR 21 | 22 | python examples/train_knn.py \ 23 | --train-path data/train.json \ 24 | --test-path data/test.json \ 25 | --cmlm-model-path models \ 26 | --data-name-or-path models/xsum-bin \ 27 | --mlm-path models/bart.large \ 28 | --output-dir $OUTPUT_DIR; 29 | ``` 30 | You can also find an example at `examples/train_knn_classifier.ipynb`. 31 | 32 | ### Evaluation 33 | Evalute the entity-level factuality of generated summaries. Input file format: one document/summary per line. 34 | 35 | ```bash 36 | SOURCE_PATH=test.source 37 | TARGET_PATH=test.hypothesis 38 | 39 | python examples/evaluation.py \ 40 | --source-path $SOURCE_PATH \ 41 | --target-path $TARGET_PATH \ 42 | --cmlm-model-path models \ 43 | --data-name-or-path models/xsum-bin \ 44 | --mlm-path models/bart.large \ 45 | --knn-model-path models/knn_classifier.pkl; 46 | ``` 47 | Also check `examples/evaluation.ipynb`. 
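If you only need the final classification step, below is a minimal sketch of how the saved classifier is applied in `examples/evaluation.py`: each extracted entity is represented by a `[prior, posterior, overlap]` feature row, each column is scaled by its standard deviation, and the KNN classifier predicts whether the entity is non-factual (label 1). The feature values here are illustrative; in practice they come from the CMLM/MLM probabilities computed as above.

```python
import pickle
import numpy as np

# One [prior, posterior, overlap] row per extracted entity (illustrative values).
features = np.array([
    [0.0037, 0.9468, 1.0],
    [0.1165, 0.3259, 0.0],
])

# Scale each column by its standard deviation, mirroring inference in evaluation.py.
x_mat = features / features.std(axis=0)

classifier = pickle.load(open("models/knn_classifier.pkl", "rb"))
predictions = classifier.predict(x_mat)  # 1 = non-factual entity
print(predictions)
```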
48 | -------------------------------------------------------------------------------- /examples/evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "appointed-reconstruction", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import json\n", 11 | "import torch\n", 12 | "\n", 13 | "from fairseq.models.bart import BARTModel" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 4, 19 | "id": "ac01838a-e509-4b17-91ed-d50e7c4ead07", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "CMLM_MODEL_PATH = 'BART_models/xsum_cedar_cmlm'\n", 24 | "MLM_MODEL_PATH = 'BART_models/bart.large'\n", 25 | "DATA_NAME_OR_PATH = 'summarization/XSum/fairseq_files/xsum-bin'" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 5, 31 | "id": "assigned-exhibition", 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stderr", 36 | "output_type": "stream", 37 | "text": [ 38 | "2022-03-15 14:57:32 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/xsum_cedar_cmlm\n", 39 | "2022-03-15 14:57:32 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/summarization/XSum/fairseq_files/xsum-bin\n", 40 | "2022-03-15 14:57:41 | INFO | fairseq.tasks.translation | [source] dictionary: 50264 types\n", 41 | "2022-03-15 14:57:41 | INFO | fairseq.tasks.translation | [target] dictionary: 50264 types\n" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "bart = BARTModel.from_pretrained(CMLM_MODEL_PATH,\n", 47 | " checkpoint_file='checkpoint_best.pt',\n", 48 | " data_name_or_path=DATA_NAME_OR_PATH)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 6, 54 | "id": "elementary-dutch", 55 | "metadata": {}, 56 | "outputs": [ 57 | { 58 | "name": "stderr", 59 | "output_type": "stream", 60 | "text": [ 61 | "2022-03-15 14:57:52 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/bart.large\n", 62 | "2022-03-15 14:57:52 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/bart.large\n", 63 | "2022-03-15 14:57:59 | INFO | fairseq.tasks.denoising | dictionary: 50264 types\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "prior_bart = BARTModel.from_pretrained(MLM_MODEL_PATH,\n", 69 | " checkpoint_file='model.pt',\n", 70 | " data_name_or_path=MLM_MODEL_PATH)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "id": "adjusted-raise", 76 | "metadata": {}, 77 | "source": [ 78 | "#### Build Prior & Posterior Model" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 7, 84 | "id": "trained-township", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "from EntFA.model import ConditionalSequenceGenerator\n", 89 | "from EntFA.utils import prepare_cmlm_inputs, prepare_mlm_inputs, get_probability_parallel" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 8, 95 | "id": "varied-renewal", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "model = ConditionalSequenceGenerator(bart)\n", 100 | "prior_model = ConditionalSequenceGenerator(prior_bart)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "markdown", 105 | "id": "authorized-platform", 106 | "metadata": {}, 107 | "source": [ 108 | "#### Test on One Sample" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 9, 114 | "id": 
"c884b778-43a2-4236-87d0-873732bd7dd6", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "import spacy\n", 119 | "\n", 120 | "nlp = spacy.load('en_core_web_sm')" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 10, 126 | "id": "7f9a8087-2c69-4dca-a3d1-2f524d6901ee", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "source = 'The city was brought to a standstill on 15 December last year when a gunman held 18 hostages for 17 hours. Family members of victims Tori Johnson and Katrina Dawson were in attendance. Images of the floral tributes that filled the city centre in the wake of the siege were projected on to the cafe and surrounding buildings in an emotional twilight ceremony. Prime Minister Malcolm Turnbull gave an address saying a \"whole nation resolved to answer hatred with love\". \"Testament to the spirit of Australians is that with such unnecessary, thoughtless tragedy, an amazing birth of mateship, unity and love occurs. Proud to be Australian,\" he said. How the Sydney siege unfolded. New South Wales Premier Mike Baird has also announced plans for a permanent memorial to be built into the pavement in Martin Place. Clear cubes containing flowers will be embedded into the concrete and will shine with specialised lighting. It is a project inspired by the massive floral tributes that were left in the days after the siege. \"Something remarkable happened here. As a city we were drawn to Martin Place. We came in shock and in sorrow but every step we took was with purpose,\" he said on Tuesday.'\n", 131 | "prediction = 'Sydney has marked the first anniversary of the siege at the Waverley cafe in which two women were killed by a gunman in the Australian city.'" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 11, 137 | "id": "indoor-shield", 138 | "metadata": {}, 139 | "outputs": [ 140 | { 141 | "name": "stdout", 142 | "output_type": "stream", 143 | "text": [ 144 | "['Sydney', 'first', 'Waverley', 'two', 'Australian']\n" 145 | ] 146 | } 147 | ], 148 | "source": [ 149 | "entities = nlp(prediction).to_json()['ents']\n", 150 | "ent_text = [prediction[e['start']: e['end']] for e in entities]\n", 151 | "print(ent_text)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 12, 157 | "id": "occupational-encoding", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "inputs = prepare_cmlm_inputs(source, prediction, ent_parts=entities)\n", 162 | "posteriors = get_probability_parallel(model, inputs[0], inputs[1], inputs[2], inputs[3])" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 13, 168 | "id": "lesser-ontario", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "inputs = prepare_mlm_inputs(source, prediction, ent_parts=entities)\n", 173 | "priors = get_probability_parallel(prior_model, inputs[0], inputs[1], inputs[2], inputs[3], mask_filling=True)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 14, 179 | "id": "0d649230-2fee-466d-b451-32c8e664dae5", 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | " \tPrior\t\tPosterior\n", 187 | "Sydney \t0.00366783\t0.946777\n", 188 | "first \t0.116516\t0.325928\n", 189 | "Waverley\t0.0179596\t0.00888062\n", 190 | "two \t0.0629272\t0.858887\n", 191 | "Australian\t0.00283623\t0.911133\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "print('{:<8}\\t{:}\\t\\t{:}'.format('', 
'Prior', 'Posterior'))\n", 197 | "for e, pri, pos in zip(ent_text, priors, posteriors):\n", 198 | " print('{:<8}\\t{:.6}\\t{:.6}'.format(e, pri, pos))" 199 | ] 200 | } 201 | ], 202 | "metadata": { 203 | "kernelspec": { 204 | "display_name": "Python 3 (ipykernel)", 205 | "language": "python", 206 | "name": "python3" 207 | }, 208 | "language_info": { 209 | "codemirror_mode": { 210 | "name": "ipython", 211 | "version": 3 212 | }, 213 | "file_extension": ".py", 214 | "mimetype": "text/x-python", 215 | "name": "python", 216 | "nbconvert_exporter": "python", 217 | "pygments_lexer": "ipython3", 218 | "version": "3.8.2" 219 | } 220 | }, 221 | "nbformat": 4, 222 | "nbformat_minor": 5 223 | } 224 | -------------------------------------------------------------------------------- /examples/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | Factuality evaluation using the trained KNN classifier. 5 | """ 6 | 7 | import argparse 8 | import os 9 | import json 10 | import torch 11 | import spacy 12 | import pickle 13 | import numpy as np 14 | 15 | from os.path import join 16 | from tqdm import tqdm 17 | from fairseq.models.bart import BARTModel 18 | 19 | from EntFA.model import ConditionalSequenceGenerator 20 | from EntFA.utils import prepare_cmlm_inputs, prepare_mlm_inputs, get_probability_parallel 21 | from EntFA.utils import read_lines 22 | 23 | nlp = spacy.load('en_core_web_sm') 24 | 25 | 26 | def build_models(args): 27 | prior_bart = BARTModel.from_pretrained(args.mlm_path, 28 | checkpoint_file='model.pt', 29 | data_name_or_path=args.mlm_path) 30 | prior_model = ConditionalSequenceGenerator(prior_bart) 31 | 32 | bart = BARTModel.from_pretrained(args.cmlm_model_path, 33 | checkpoint_file='checkpoint_best.pt', 34 | data_name_or_path=args.data_name_or_path) 35 | model = ConditionalSequenceGenerator(bart) 36 | 37 | return prior_model, model 38 | 39 | 40 | def extract_features(source, hypothesis, prior_model, model): 41 | features = [] 42 | empty, error_count = 0, 0 43 | 44 | for index in tqdm(range(len(hypothesis))): 45 | source_doc, target_doc = source[index], hypothesis[index] 46 | target_doc = target_doc.replace("“", '"').replace("”", '"').replace("’", "'") 47 | target_doc = target_doc.replace("%.", "% .") 48 | target_doc = target_doc.replace("%,", "% ,") 49 | target_doc = target_doc.replace("%)", "% )") 50 | 51 | # extract entities 52 | ent_parts = nlp(target_doc).to_json()['ents'] 53 | entities = [target_doc[e['start']: e['end']] for e in ent_parts] 54 | 55 | if len(ent_parts) > 0: 56 | pri_inputs = prepare_mlm_inputs(source, target_doc, ent_parts=ent_parts) 57 | pos_inputs = prepare_cmlm_inputs(source_doc, target_doc, ent_parts=ent_parts) 58 | 59 | # calculate probability features 60 | try: 61 | pri_probs = get_probability_parallel(prior_model, pri_inputs[0], pri_inputs[1], pri_inputs[2], pri_inputs[3], mask_filling=True) 62 | pos_probs = get_probability_parallel(model, pos_inputs[0], pos_inputs[1], pos_inputs[2], pos_inputs[3]) 63 | 64 | # overlapping feature 65 | source_doc = source_doc.lower() 66 | overlap = [] 67 | for e in entities: 68 | if e[:4] == 'the ': e = e[4:] 69 | if e.lower() in source_doc: 70 | overlap.append(1) 71 | else: 72 | overlap.append(0) 73 | 74 | assert len(pri_probs) == len(pos_probs) == len(pri_inputs[2]) == len(pos_inputs[3]) 75 | features.append((pos_inputs[3], pos_inputs[2], pri_probs, pos_probs, overlap)) 76 | except AssertionError as err: 77 | print("{}: 
{}".format(index, err)) 78 | error_count += 1 79 | 80 | else: 81 | empty += 1 82 | features.append(([], [], [], [], [])) 83 | 84 | return features 85 | 86 | 87 | def infernece(test_features, classifier): 88 | """ 89 | Args: 90 | test_features (List[List]): [[prior, posterior, overlap_feature], ...] 91 | classifier: KNN classifier 92 | """ 93 | x_mat = np.array(test_features) 94 | stds = [np.std(x_mat[:, 0]), np.std(x_mat[:, 1]), np.std(x_mat[:, 2])] 95 | x_mat = np.vstack([x_mat[:, 0]/stds[0], x_mat[:, 1]/stds[1], x_mat[:, 2]/stds[2]]).transpose() 96 | Z = classifier.predict(x_mat) 97 | return Z 98 | 99 | 100 | def main(args): 101 | print('- Build prior/posterior models...') 102 | prior_model, posterior_model = build_models(args) 103 | print('- Done.') 104 | 105 | print('- Read source documents and summaries...') 106 | source = read_lines(args.source_path) 107 | hypothesis = read_lines(args.target_path) 108 | print('- Done. {} summaries to be evaluated.'.format(len(hypothesis))) 109 | 110 | print('- Extract features...') 111 | features = extract_features(source, hypothesis, prior_model, posterior_model) 112 | print('- Done.') 113 | 114 | test_features = [] 115 | for sample in features: 116 | for pri, pos, ovrp in zip(sample[2], sample[3], sample[4]): 117 | test_features.append([pri, pos, ovrp]) 118 | 119 | print('- Start inference...') 120 | classifier = pickle.load(open(args.knn_model_path, 'rb')) 121 | Z = infernece(test_features, classifier) 122 | print('- Done.') 123 | 124 | print('- Total extracted entities: ', Z.shape[0]) 125 | print('- Non-factual entities: {:.2f}%'.format((Z.sum() / Z.shape[0]) * 100)) 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument( 131 | "--source_path", 132 | type=str, 133 | default=None, 134 | required=True, 135 | help="The path of the source articles.", 136 | ) 137 | parser.add_argument( 138 | "--target_path", 139 | type=str, 140 | default=None, 141 | required=True, 142 | help="The path of the summaries to be evaluated.", 143 | ) 144 | parser.add_argument( 145 | "--cmlm_model_path", 146 | type=str, 147 | required=True, 148 | ) 149 | parser.add_argument( 150 | "--data_name_or_path", 151 | type=str, 152 | required=True, 153 | ) 154 | parser.add_argument( 155 | "--mlm_path", 156 | type=str, 157 | required=True, 158 | ) 159 | parser.add_argument( 160 | "--knn_model_path", 161 | type=str, 162 | required=True, 163 | ) 164 | 165 | args = parser.parse_args() 166 | main(args) -------------------------------------------------------------------------------- /examples/evaluation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | module load cuda/11.0 3 | module load python/3.8 4 | 5 | source $HOME/envFS/bin/activate 6 | 7 | 8 | SOURCE_PATH=$SCRATCH/summarization/XSum/fairseq_files/test.source 9 | TARGET_PATH=$SCRATCH/summarization/XSum/fairseq_files/test.target 10 | 11 | python evaluation.py \ 12 | --source_path $SOURCE_PATH \ 13 | --target_path $TARGET_PATH \ 14 | --cmlm_model_path $SCRATCH/BART_models/xsum_cedar_cmlm \ 15 | --data_name_or_path $SCRATCH/summarization/XSum/fairseq_files/xsum-bin \ 16 | --mlm_path $SCRATCH/BART_models/bart.large \ 17 | --knn_model_path $HOME/EntFA/examples/knn_classifier.pkl; -------------------------------------------------------------------------------- /examples/knn_classifier.pkl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mcao516/EntFA/e6f857ddf0ba43a7bc5e6e0b40cfaada99f79583/examples/knn_classifier.pkl -------------------------------------------------------------------------------- /examples/train_knn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | """ 4 | Train a KNN classifier for factuality evaluation. 5 | """ 6 | 7 | import argparse 8 | import os 9 | import json 10 | import torch 11 | import pickle 12 | import numpy as np 13 | 14 | from tqdm import tqdm 15 | from sklearn import neighbors 16 | from sklearn.metrics import classification_report, f1_score, accuracy_score 17 | 18 | from fairseq.models.bart import BARTModel 19 | 20 | from EntFA.model import ConditionalSequenceGenerator 21 | from EntFA.utils import prepare_cmlm_inputs, prepare_mlm_inputs, get_probability_parallel 22 | 23 | 24 | def build_classifier(train_features, train_labels, n=30): 25 | classifier = neighbors.KNeighborsClassifier(n_neighbors=30, algorithm='auto') 26 | 27 | x_mat = np.array(train_features) 28 | stds = [np.std(x_mat[:, 0]), np.std(x_mat[:, 1]), np.std(x_mat[:, 2])] 29 | x_mat = np.vstack([x_mat[:, 0]/stds[0], x_mat[:, 1]/stds[1], x_mat[:, 2]/stds[2]]).transpose() 30 | y_vec = np.array(train_labels) 31 | classifier.fit(x_mat, y_vec) 32 | 33 | return classifier 34 | 35 | 36 | def infernece(test_features, classifier): 37 | """ 38 | Args: 39 | test_features (List[List]): [[prior, posterior, overlap_feature], ...] 40 | classifier: KNN classifier 41 | """ 42 | x_mat = np.array(test_features) 43 | stds = [np.std(x_mat[:, 0]), np.std(x_mat[:, 1]), np.std(x_mat[:, 2])] 44 | x_mat = np.vstack([x_mat[:, 0]/stds[0], x_mat[:, 1]/stds[1], x_mat[:, 2]/stds[2]]).transpose() 45 | Z = classifier.predict(x_mat) 46 | return Z 47 | 48 | 49 | def get_features(data_set, prior_model, model): 50 | label_mapping = { 51 | 'Non-hallucinated': 0, 52 | 'Factual Hallucination': 0, 53 | 'Non-factual Hallucination': 1 54 | } 55 | 56 | features, labels = [], [] 57 | for t in tqdm(data_set): 58 | source, prediction, entities = t['source'], t['prediction'], t['entities'] 59 | 60 | inputs = prepare_mlm_inputs(source, prediction, ent_parts=entities) 61 | priors = get_probability_parallel(prior_model, inputs[0], inputs[1], inputs[2], inputs[3], mask_filling=True) 62 | 63 | inputs = prepare_cmlm_inputs(source, prediction, ent_parts=entities) 64 | posteriors = get_probability_parallel(model, inputs[0], inputs[1], inputs[2], inputs[3]) 65 | 66 | overlaps = [1. if e['ent'].lower() in source.lower() else 0. for e in entities] 67 | assert len(priors) == len(posteriors) == len(overlaps) 68 | 69 | for i, e in enumerate(entities): 70 | if label_mapping.get(e['label'], -1) != -1: 71 | features.append((priors[i], posteriors[i], overlaps[i])) 72 | labels.append(label_mapping[e['label']]) 73 | 74 | return features, labels 75 | 76 | 77 | def main(args): 78 | # 1. load training & test dataset 79 | train_set = json.load(open(args.train_path, 'r')) 80 | 81 | # 2. load weights 82 | bart = BARTModel.from_pretrained(args.cmlm_model_path, 83 | checkpoint_file='checkpoint_best.pt', 84 | data_name_or_path=args.data_name_or_path) 85 | prior_bart = BARTModel.from_pretrained(args.mlm_path, 86 | checkpoint_file='model.pt', 87 | data_name_or_path=args.mlm_path) 88 | 89 | # 3. build model 90 | model = ConditionalSequenceGenerator(bart) 91 | prior_model = ConditionalSequenceGenerator(prior_bart) 92 | 93 | # 4. 
training 94 | train_features, train_labels = get_features(train_set, prior_model, model) 95 | classifier = build_classifier(train_features, train_labels, n=30) 96 | 97 | # 5. evaluation 98 | if args.test_path: 99 | test_set = json.load(open(args.test_path, 'r')) 100 | 101 | test_features, test_labels = get_features(test_set, prior_model, model) 102 | Z = infernece(test_features, classifier) 103 | 104 | print('accuracy: {:.4}\n\n'.format(accuracy_score(test_labels, Z))) 105 | print(classification_report(test_labels, Z, target_names=['Factual', 'Non-Factual'], digits=4)) 106 | 107 | # 6. save 108 | save_path = os.path.join(args.output_dir, 'knn_classifier.pkl') 109 | pickle.dump(classifier, open(save_path, 'wb')) 110 | print('- model is saved at: ', save_path) 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument( 116 | "--train_path", 117 | type=str, 118 | default=None, 119 | required=True, 120 | help="The path of the training set.", 121 | ) 122 | parser.add_argument( 123 | "--test_path", 124 | type=str, 125 | default=None, 126 | help="The path of the test set.", 127 | ) 128 | parser.add_argument( 129 | "--cmlm_model_path", 130 | type=str, 131 | required=True, 132 | ) 133 | parser.add_argument( 134 | "--data_name_or_path", 135 | type=str, 136 | required=True, 137 | ) 138 | parser.add_argument( 139 | "--mlm_path", 140 | type=str, 141 | required=True, 142 | ) 143 | parser.add_argument( 144 | "--output_dir", 145 | type=str, 146 | default='.', 147 | ) 148 | 149 | args = parser.parse_args() 150 | main(args) -------------------------------------------------------------------------------- /examples/train_knn_classifier.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "appointed-reconstruction", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import json\n", 11 | "import torch\n", 12 | "\n", 13 | "from tqdm import tqdm\n", 14 | "from fairseq.models.bart import BARTModel" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "id": "e459c6de-c22a-4077-a3f9-2570c12757c9", 20 | "metadata": {}, 21 | "source": [ 22 | "#### Load Annotated Dataset" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "id": "59897d15-4735-4cd4-9dab-5becdc587ffb", 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "train_set = json.load(open('../data/train.json', 'r'))\n", 33 | "test_set = json.load(open('../data/test.json', 'r'))" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "20d2f8aa-755a-42f7-b5f9-4e262b6cf120", 39 | "metadata": {}, 40 | "source": [ 41 | "#### Load Weights" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "id": "ac01838a-e509-4b17-91ed-d50e7c4ead07", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "CMLM_MODEL_PATH = 'BART_models/xsum_cedar_cmlm'\n", 52 | "MLM_MODEL_PATH = 'BART_models/bart.large'\n", 53 | "\n", 54 | "DATA_NAME_OR_PATH = 'summarization/XSum/fairseq_files/xsum-bin'" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "id": "assigned-exhibition", 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "2022-04-07 01:25:13 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/xsum_cedar_cmlm\n", 68 | "2022-04-07 01:25:13 | INFO | fairseq.file_utils | loading archive 
file /home/mila/c/caomeng/scratch/summarization/XSum/fairseq_files/xsum-bin\n", 69 | "2022-04-07 01:25:22 | INFO | fairseq.tasks.translation | [source] dictionary: 50264 types\n", 70 | "2022-04-07 01:25:22 | INFO | fairseq.tasks.translation | [target] dictionary: 50264 types\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "bart = BARTModel.from_pretrained(CMLM_MODEL_PATH,\n", 76 | " checkpoint_file='checkpoint_best.pt',\n", 77 | " data_name_or_path=DATA_NAME_OR_PATH)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "id": "elementary-dutch", 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stderr", 88 | "output_type": "stream", 89 | "text": [ 90 | "2022-04-07 01:25:32 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/bart.large\n", 91 | "2022-04-07 01:25:32 | INFO | fairseq.file_utils | loading archive file /home/mila/c/caomeng/scratch/BART_models/bart.large\n", 92 | "2022-04-07 01:25:39 | INFO | fairseq.tasks.denoising | dictionary: 50264 types\n" 93 | ] 94 | } 95 | ], 96 | "source": [ 97 | "prior_bart = BARTModel.from_pretrained(MLM_MODEL_PATH,\n", 98 | " checkpoint_file='model.pt',\n", 99 | " data_name_or_path=MLM_MODEL_PATH)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "markdown", 104 | "id": "adjusted-raise", 105 | "metadata": {}, 106 | "source": [ 107 | "#### Build Prior & Posterior Model" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 6, 113 | "id": "trained-township", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "from EntFA.model import ConditionalSequenceGenerator\n", 118 | "from EntFA.utils import prepare_cmlm_inputs, prepare_mlm_inputs, get_probability_parallel" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 7, 124 | "id": "varied-renewal", 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "model = ConditionalSequenceGenerator(bart)\n", 129 | "prior_model = ConditionalSequenceGenerator(prior_bart)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "id": "94283abf-8d88-4d55-b938-a8eafd86cc5f", 135 | "metadata": {}, 136 | "source": [ 137 | "#### Training" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 8, 143 | "id": "3a100d3c-65ed-4c04-98d4-4bca7509887a", 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "import numpy as np\n", 148 | "\n", 149 | "from sklearn import neighbors\n", 150 | "from sklearn.metrics import classification_report, f1_score, accuracy_score" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 9, 156 | "id": "f7614572-dff0-428b-9550-e11727986f16", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "def build_classifier(train_features, train_labels, n=30):\n", 161 | " classifier = neighbors.KNeighborsClassifier(n_neighbors=30, algorithm='auto')\n", 162 | " \n", 163 | " x_mat = np.array(train_features)\n", 164 | " stds = [np.std(x_mat[:, 0]), np.std(x_mat[:, 1]), np.std(x_mat[:, 2])]\n", 165 | " x_mat = np.vstack([x_mat[:, 0]/stds[0], x_mat[:, 1]/stds[1], x_mat[:, 2]/stds[2]]).transpose()\n", 166 | " y_vec = np.array(train_labels)\n", 167 | " classifier.fit(x_mat, y_vec)\n", 168 | " \n", 169 | " return classifier\n", 170 | "\n", 171 | "def infernece(test_features, classifier):\n", 172 | " \"\"\"\n", 173 | " Args:\n", 174 | " test_features (List[List]): [[prior, posterior, overlap_feature], ...]\n", 175 | " classifier: KNN classifier\n", 176 | " \"\"\"\n", 177 | " x_mat = 
np.array(test_features)\n", 178 | " stds = [np.std(x_mat[:, 0]), np.std(x_mat[:, 1]), np.std(x_mat[:, 2])]\n", 179 | " x_mat = np.vstack([x_mat[:, 0]/stds[0], x_mat[:, 1]/stds[1], x_mat[:, 2]/stds[2]]).transpose()\n", 180 | " Z = classifier.predict(x_mat)\n", 181 | " return Z" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": 10, 187 | "id": "618e8509-2e27-443a-bef0-b0f011873041", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "def get_features(data_set, prior_model, model):\n", 192 | " label_mapping = {\n", 193 | " 'Non-hallucinated': 0,\n", 194 | " 'Factual Hallucination': 0,\n", 195 | " 'Non-factual Hallucination': 1\n", 196 | " }\n", 197 | "\n", 198 | " features, labels = [], []\n", 199 | " for t in tqdm(data_set):\n", 200 | " source, prediction, entities = t['source'], t['prediction'], t['entities']\n", 201 | "\n", 202 | " inputs = prepare_mlm_inputs(source, prediction, ent_parts=entities)\n", 203 | " priors = get_probability_parallel(prior_model, inputs[0], inputs[1], inputs[2], inputs[3], mask_filling=True)\n", 204 | "\n", 205 | " inputs = prepare_cmlm_inputs(source, prediction, ent_parts=entities)\n", 206 | " posteriors = get_probability_parallel(model, inputs[0], inputs[1], inputs[2], inputs[3])\n", 207 | "\n", 208 | " overlaps = [1. if e['ent'].lower() in source.lower() else 0. for e in entities]\n", 209 | " assert len(priors) == len(posteriors) == len(overlaps)\n", 210 | "\n", 211 | " for i, e in enumerate(entities):\n", 212 | " if label_mapping.get(e['label'], -1) != -1:\n", 213 | " features.append((priors[i], posteriors[i], overlaps[i]))\n", 214 | " labels.append(label_mapping[e['label']])\n", 215 | "\n", 216 | " return features, labels" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 11, 222 | "id": "904efb35-da2b-4c56-add4-e4e618722a10", 223 | "metadata": {}, 224 | "outputs": [ 225 | { 226 | "name": "stderr", 227 | "output_type": "stream", 228 | "text": [ 229 | "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 460/460 [00:49<00:00, 9.22it/s]\n" 230 | ] 231 | } 232 | ], 233 | "source": [ 234 | "train_features, train_labels = get_features(train_set, prior_model, model)\n", 235 | "classifier = build_classifier(train_features, train_labels, n=30)" 236 | ] 237 | }, 238 | { 239 | "cell_type": "markdown", 240 | "id": "e77027fa-c1ca-4a42-a441-e60d70a3fbef", 241 | "metadata": {}, 242 | "source": [ 243 | "#### Evaluation" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 12, 249 | "id": "a938bec0-738e-4321-a821-a01696256a59", 250 | "metadata": {}, 251 | "outputs": [ 252 | { 253 | "name": "stderr", 254 | "output_type": "stream", 255 | "text": [ 256 | "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [00:25<00:00, 9.52it/s]\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "test_features, test_labels = get_features(test_set, prior_model, model)\n", 262 | "Z = infernece(test_features, classifier)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 13, 268 | "id": "a4b1d32c-fcab-4b0d-9128-f6d4cd1a1964", 269 | "metadata": {}, 270 | "outputs": [ 271 | { 272 | "name": "stdout", 273 | "output_type": "stream", 274 | "text": [ 275 | "accuracy: 0.9102\n", 276 | "\n", 277 | "\n", 278 | " precision recall 
f1-score support\n", 279 | "\n", 280 | " Factual 0.9323 0.9629 0.9474 701\n", 281 | " Non-Factual 0.7658 0.6343 0.6939 134\n", 282 | "\n", 283 | " accuracy 0.9102 835\n", 284 | " macro avg 0.8490 0.7986 0.8206 835\n", 285 | "weighted avg 0.9056 0.9102 0.9067 835\n", 286 | "\n" 287 | ] 288 | } 289 | ], 290 | "source": [ 291 | "print('accuracy: {:.4}\\n\\n'.format(accuracy_score(test_labels, Z)))\n", 292 | "print(classification_report(test_labels, Z, target_names=['Factual', 'Non-Factual'], digits=4))" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "id": "b8fe2f43-19bb-48b3-9bbe-b6b778cb802a", 298 | "metadata": {}, 299 | "source": [ 300 | "#### Save Classifier" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": 14, 306 | "id": "e3747c5e-4175-493b-a395-6b7ab3d77d10", 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "import pickle\n", 311 | "\n", 312 | "pickle.dump(classifier, open('knn_classifier.pkl', 'wb'))" 313 | ] 314 | } 315 | ], 316 | "metadata": { 317 | "kernelspec": { 318 | "display_name": "Python 3 (ipykernel)", 319 | "language": "python", 320 | "name": "python3" 321 | }, 322 | "language_info": { 323 | "codemirror_mode": { 324 | "name": "ipython", 325 | "version": 3 326 | }, 327 | "file_extension": ".py", 328 | "mimetype": "text/x-python", 329 | "name": "python", 330 | "nbconvert_exporter": "python", 331 | "pygments_lexer": "ipython3", 332 | "version": "3.8.2" 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 5 337 | } 338 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | Download the model checkpoints and put them in this directory. CMLM and $K$-NN classifier checkpoint can be found [here](https://drive.google.com/drive/folders/10ibVc5R7q4Gc0TH1AIRo7IaLCV83SkpF?usp=sharing). For MLM, download `bart.large` at Fairseq's [BART](https://github.com/pytorch/fairseq/tree/main/examples/bart) repository. 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.1 2 | fairseq==0.10.2 3 | transformers>=4.17.0 4 | spacy 5 | -------------------------------------------------------------------------------- /rl_training/rl_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from dataclasses import dataclass 5 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union 6 | 7 | from transformers.file_utils import PaddingStrategy 8 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 9 | 10 | 11 | def sequence_mask(lengths, max_len=None, dtype=None, device=None) : 12 | r"""Return a mask tensor representing the first N positions of each cell. 
13 | If ``lengths`` has shape ``[d_1, d_2, ..., d_n]`` the resulting tensor 14 | ``mask`` has dtype ``dtype`` and shape ``[d_1, d_2, ..., d_n, maxlen]``, 15 | with 16 | ``` 17 | mask[i_1, i_2, ..., i_n, j] = (j < lengths[i_1, i_2, ..., i_n]) 18 | ``` 19 | Examples: 20 | ```python 21 | sequence_mask([1, 3, 2], 5) # [[True, False, False, False, False], 22 | # [True, True, True, False, False], 23 | # [True, True, False, False, False]] 24 | sequence_mask([[1, 3],[2,0]]) # [[[ True, False, False], 25 | # [ True, True, True]], 26 | # [[ True, True, False], 27 | # [False, False, False]]] 28 | ``` 29 | Args: 30 | lengths: integer tensor or list of int, all its values <= max_len. 31 | max_len: scalar integer tensor, size of last dimension of returned 32 | tensor. Default is the maximum value in ``lengths``. 33 | dtype: the desired data type of returned tensor. Default: if None, 34 | returns :torch:`ByteTensor`. 35 | device: the desired device of returned tensor. Default: if None, uses 36 | the current device for the default tensor type. 37 | Returns: 38 | A mask tensor of shape :python:`lengths.shape + (max_len,)`, cast to 39 | specified dtype. 40 | Raises: 41 | ValueError: if ``max_len`` is not a scalar. 42 | """ 43 | if not isinstance(lengths, torch.Tensor): 44 | lengths = torch.tensor(lengths, device=device) 45 | elif device is None: 46 | device = lengths.device 47 | lengths: torch.LongTensor 48 | if max_len is None: 49 | max_len = torch.max(lengths).item() 50 | 51 | size = lengths.size() 52 | row_vector = torch.arange(max_len, device=device, dtype=lengths.dtype).view( 53 | *([1] * len(size)), -1).expand(*size, max_len) 54 | mask = (row_vector < lengths.unsqueeze(-1)).to(device=device) 55 | if dtype is not None: 56 | mask = mask.to(dtype=dtype) 57 | 58 | return mask 59 | 60 | 61 | def masked_reverse_cumsum(X, lengths, dim): 62 | """ 63 | Args: 64 | X (Tensor): [batch_size, max_tgt_len] 65 | lengths (Tensor): [batch_size] 66 | dim (int): -1 67 | gamma (float): the discount factor 68 | 69 | """ 70 | masked_X = X * sequence_mask(lengths, max_len=X.shape[1]) 71 | return (masked_X 72 | .flip(dims=[dim]) 73 | .cumsum(dim=dim) 74 | .flip(dims=[dim])) 75 | 76 | 77 | def discounted_future_sum(values, lengths, num_steps=None, gamma=1.0): 78 | """ 79 | Args: 80 | values (Tensor): reward values with size [batch_size, max_tgt_len] 81 | lengths (Tensor): target sequence length with size [batch_size] 82 | num_steps (int): number of future steps to sum over. 83 | gamma (float): discount value. 84 | 85 | Return: 86 | output (Tensor): [batch_size, max_tgt_len] 87 | """ 88 | assert values.dim() == 2 89 | 90 | batch_size, total_steps = values.shape 91 | values = values * sequence_mask(lengths, max_len=values.shape[1]) 92 | 93 | num_steps = total_steps if num_steps is None else num_steps 94 | num_steps = min(num_steps, total_steps) 95 | 96 | padding = torch.zeros([batch_size, num_steps - 1]).to(values) 97 | padded_values = torch.cat([values, padding], 1) 98 | discount_filter = gamma ** torch.arange(num_steps).to(values).reshape(1, 1, -1) 99 | 100 | output = F.conv1d(padded_values.unsqueeze(-2), discount_filter).squeeze(1) 101 | return output 102 | 103 | 104 | def polyak_update(model, tgt_model, target_lr): 105 | for param_, param in zip(tgt_model.parameters(), model.parameters()): 106 | param_.data.copy_((1 - target_lr) * param_ + target_lr * param) 107 | 108 | 109 | @dataclass 110 | class DataCollatorForSeq2Seq: 111 | """ 112 | Data collator that will dynamically pad the inputs received, as well as the labels. 
113 | Args: 114 | tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]): 115 | The tokenizer used for encoding the data. 116 | model ([`PreTrainedModel`]): 117 | The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to 118 | prepare the *decoder_input_ids* 119 | This is useful when using *label_smoothing* to avoid calculating loss twice. 120 | padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `True`): 121 | Select a strategy to pad the returned sequences (according to the model's padding side and padding index) 122 | among: 123 | - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence 124 | is provided). 125 | - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum 126 | acceptable input length for the model if that argument is not provided. 127 | - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different 128 | lengths). 129 | max_length (`int`, *optional*): 130 | Maximum length of the returned list and optionally padding length (see above). 131 | pad_to_multiple_of (`int`, *optional*): 132 | If set will pad the sequence to a multiple of the provided value. 133 | This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >= 134 | 7.5 (Volta). 135 | label_pad_token_id (`int`, *optional*, defaults to -100): 136 | The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions). 137 | return_tensors (`str`): 138 | The type of Tensor to return. Allowable values are "np", "pt" and "tf". 139 | """ 140 | 141 | tokenizer: PreTrainedTokenizerBase 142 | model: Optional[Any] = None 143 | padding: Union[bool, str, PaddingStrategy] = True 144 | max_length: Optional[int] = None 145 | pad_to_multiple_of: Optional[int] = None 146 | label_pad_token_id: int = -100 147 | reward_pad_value: float = 0.0 148 | return_tensors: str = "pt" 149 | 150 | def __call__(self, features, return_tensors=None): 151 | if return_tensors is None: 152 | return_tensors = self.return_tensors 153 | labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None 154 | # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the 155 | # same length to return tensors. 
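# Illustrative example (assumed inputs): with padding_side="right" and
# label_pad_token_id=-100, label lists [[5, 6], [7, 8, 9]] are padded to
# [[5, 6, -100], [7, 8, 9]]; the per-example "rewards" lists are padded the
# same way below, using reward_pad_value (0.0 by default).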
156 | if labels is not None: 157 | max_label_length = max(len(l) for l in labels) 158 | if self.pad_to_multiple_of is not None: 159 | max_label_length = ( 160 | (max_label_length + self.pad_to_multiple_of - 1) 161 | // self.pad_to_multiple_of 162 | * self.pad_to_multiple_of 163 | ) 164 | 165 | padding_side = self.tokenizer.padding_side 166 | for feature in features: 167 | remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) 168 | if isinstance(feature["labels"], list): 169 | feature["labels"] = ( 170 | feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] 171 | ) 172 | elif padding_side == "right": 173 | feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64) 174 | else: 175 | feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) 176 | 177 | rewards = [feature["rewards"] for feature in features] if "rewards" in features[0].keys() else None 178 | if rewards is not None: 179 | max_reward_length = max(len(l) for l in rewards) 180 | if self.pad_to_multiple_of is not None: 181 | max_reward_length = ( 182 | (max_reward_length + self.pad_to_multiple_of - 1) 183 | // self.pad_to_multiple_of 184 | * self.pad_to_multiple_of 185 | ) 186 | 187 | padding_side = self.tokenizer.padding_side 188 | for feature in features: 189 | remainder = [self.reward_pad_value] * (max_reward_length - len(feature["rewards"])) 190 | if isinstance(feature["rewards"], list): 191 | feature["rewards"] = ( 192 | feature["rewards"] + remainder if padding_side == "right" else remainder + feature["rewards"] 193 | ) 194 | elif padding_side == "right": 195 | feature["rewards"] = np.concatenate([feature["rewards"], remainder]) 196 | else: 197 | feature["rewards"] = np.concatenate([remainder, feature["rewards"]]) 198 | 199 | features = self.tokenizer.pad( 200 | features, 201 | padding=self.padding, 202 | max_length=self.max_length, 203 | pad_to_multiple_of=self.pad_to_multiple_of, 204 | return_tensors=return_tensors, 205 | ) 206 | 207 | # prepare decoder_input_ids 208 | if ( 209 | labels is not None 210 | and self.model is not None 211 | and hasattr(self.model, "prepare_decoder_input_ids_from_labels") 212 | ): 213 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"]) 214 | features["decoder_input_ids"] = decoder_input_ids 215 | 216 | return features -------------------------------------------------------------------------------- /rl_training/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning a 🤗 Transformers model on summarization. 18 | """ 19 | # You can also adapt this script on your own summarization task. Pointers for this are left as comments. 
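# Illustrative invocation (paths and model name are placeholders; see train_rl.sh
# and training_config.yaml in this directory for the settings actually used):
#
#   python train.py \
#       --model_name_or_path facebook/bart-large \
#       --train_file data/train_with_rewards.json \
#       --validation_file data/valid_with_rewards.json \
#       --per_device_train_batch_size 8 \
#       --num_steps 5 --gamma 0.99 \
#       --output_dir outputs/rl_run
#
# The CSV/JSON data files are expected to provide the document/summary text
# columns plus a "reward" column, which preprocess_function below carries
# through to the data collator.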
20 | 21 | import argparse 22 | import logging 23 | import math 24 | import os 25 | import random 26 | from pathlib import Path 27 | from copy import deepcopy 28 | 29 | import datasets 30 | import nltk 31 | import numpy as np 32 | import torch 33 | import torch.nn.functional as F 34 | from datasets import load_dataset, load_metric 35 | from torch.utils.data import DataLoader 36 | from torch.nn import CrossEntropyLoss 37 | from tqdm.auto import tqdm 38 | 39 | import transformers 40 | from accelerate import Accelerator, DeepSpeedPlugin 41 | from filelock import FileLock 42 | from huggingface_hub import Repository 43 | from transformers import ( 44 | CONFIG_MAPPING, 45 | MODEL_MAPPING, 46 | AdamW, 47 | AutoConfig, 48 | AutoModelForSeq2SeqLM, 49 | AutoTokenizer, 50 | SchedulerType, 51 | get_scheduler, 52 | set_seed, 53 | ) 54 | from transformers.file_utils import is_offline_mode 55 | from transformers.utils.versions import require_version 56 | 57 | from rl_utils import DataCollatorForSeq2Seq, discounted_future_sum, polyak_update 58 | 59 | logger = logging.getLogger(__name__) 60 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 61 | 62 | # You should update this to your particular problem to have better documentation of `model_type` 63 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 64 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 65 | 66 | try: 67 | nltk.data.find("tokenizers/punkt") 68 | except (LookupError, OSError): 69 | if is_offline_mode(): 70 | raise LookupError( 71 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 72 | ) 73 | with FileLock(".lock") as lock: 74 | nltk.download("punkt", quiet=True) 75 | 76 | summarization_name_mapping = { 77 | "amazon_reviews_multi": ("review_body", "review_title"), 78 | "big_patent": ("description", "abstract"), 79 | "cnn_dailymail": ("article", "highlights"), 80 | "orange_sum": ("text", "summary"), 81 | "pn_summary": ("article", "summary"), 82 | "psc": ("extract_text", "summary_text"), 83 | "samsum": ("dialogue", "summary"), 84 | "thaisum": ("body", "summary"), 85 | "xglue": ("news_body", "news_title"), 86 | "xsum": ("document", "summary"), 87 | "wiki_summary": ("article", "highlights"), 88 | } 89 | 90 | 91 | def parse_args(): 92 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a summarization task") 93 | parser.add_argument( 94 | "--dataset_name", 95 | type=str, 96 | default=None, 97 | help="The name of the dataset to use (via the datasets library).", 98 | ) 99 | parser.add_argument( 100 | "--dataset_config_name", 101 | type=str, 102 | default=None, 103 | help="The configuration name of the dataset to use (via the datasets library).", 104 | ) 105 | parser.add_argument( 106 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data." 107 | ) 108 | parser.add_argument( 109 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 
110 | ) 111 | parser.add_argument( 112 | "--ignore_pad_token_for_loss", 113 | type=bool, 114 | default=True, 115 | help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.", 116 | ) 117 | parser.add_argument( 118 | "--max_source_length", 119 | type=int, 120 | default=1024, 121 | help="The maximum total input sequence length after " 122 | "tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.", 123 | ) 124 | parser.add_argument( 125 | "--source_prefix", 126 | type=str, 127 | default=None, 128 | help="A prefix to add before every source text " "(useful for T5 models).", 129 | ) 130 | parser.add_argument( 131 | "--target_prefix", 132 | type=str, 133 | default=None, 134 | help="A prefix to add before every target text", 135 | ) 136 | parser.add_argument( 137 | "--preprocessing_num_workers", 138 | type=int, 139 | default=None, 140 | help="The number of processes to use for the preprocessing.", 141 | ) 142 | parser.add_argument( 143 | "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets" 144 | ) 145 | parser.add_argument( 146 | "--max_target_length", 147 | type=int, 148 | default=128, 149 | help="The maximum total sequence length for target text after " 150 | "tokenization. Sequences longer than this will be truncated, sequences shorter will be padded." 151 | "during ``evaluate`` and ``predict``.", 152 | ) 153 | parser.add_argument( 154 | "--val_max_target_length", 155 | type=int, 156 | default=None, 157 | help="The maximum total sequence length for validation " 158 | "target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be " 159 | "padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` " 160 | "param of ``model.generate``, which is used during ``evaluate`` and ``predict``.", 161 | ) 162 | parser.add_argument( 163 | "--max_length", 164 | type=int, 165 | default=128, 166 | help=( 167 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 168 | " sequences shorter will be padded if `--pad_to_max_lengh` is passed." 169 | ), 170 | ) 171 | parser.add_argument( 172 | "--num_beams", 173 | type=int, 174 | default=None, 175 | help="Number of beams to use for evaluation. This argument will be " 176 | "passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.", 177 | ) 178 | parser.add_argument( 179 | "--pad_to_max_length", 180 | action="store_true", 181 | help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", 182 | ) 183 | parser.add_argument( 184 | "--model_name_or_path", 185 | type=str, 186 | help="Path to pretrained model or model identifier from huggingface.co/models.", 187 | required=True, 188 | ) 189 | parser.add_argument( 190 | "--config_name", 191 | type=str, 192 | default=None, 193 | help="Pretrained config name or path if not the same as model_name", 194 | ) 195 | parser.add_argument( 196 | "--tokenizer_name", 197 | type=str, 198 | default=None, 199 | help="Pretrained tokenizer name or path if not the same as model_name", 200 | ) 201 | parser.add_argument( 202 | "--text_column", 203 | type=str, 204 | default=None, 205 | help="The name of the column in the datasets containing the full texts (for summarization).", 206 | ) 207 | parser.add_argument( 208 | "--summary_column", 209 | type=str, 210 | default=None, 211 | help="The name of the column in the datasets containing the summaries (for summarization).", 212 | ) 213 | parser.add_argument( 214 | "--use_slow_tokenizer", 215 | action="store_true", 216 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 217 | ) 218 | parser.add_argument( 219 | "--per_device_train_batch_size", 220 | type=int, 221 | default=8, 222 | help="Batch size (per device) for the training dataloader.", 223 | ) 224 | parser.add_argument( 225 | "--per_device_eval_batch_size", 226 | type=int, 227 | default=8, 228 | help="Batch size (per device) for the evaluation dataloader.", 229 | ) 230 | parser.add_argument( 231 | "--learning_rate", 232 | type=float, 233 | default=5e-5, 234 | help="Initial learning rate (after the potential warmup period) to use.", 235 | ) 236 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 237 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 238 | parser.add_argument( 239 | "--max_train_steps", 240 | type=int, 241 | default=None, 242 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 243 | ) 244 | parser.add_argument( 245 | "--gradient_accumulation_steps", 246 | type=int, 247 | default=1, 248 | help="Number of updates steps to accumulate before performing a backward/update pass.", 249 | ) 250 | parser.add_argument( 251 | "--lr_scheduler_type", 252 | type=SchedulerType, 253 | default="linear", 254 | help="The scheduler type to use.", 255 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 256 | ) 257 | parser.add_argument( 258 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 259 | ) 260 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 261 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 262 | parser.add_argument( 263 | "--model_type", 264 | type=str, 265 | default=None, 266 | help="Model type to use if training from scratch.", 267 | choices=MODEL_TYPES, 268 | ) 269 | parser.add_argument( 270 | "--cache_dir", 271 | type=str, 272 | default="~/.cache/huggingface/datasets", 273 | help="Cache directory for datasets." 274 | ) 275 | parser.add_argument( 276 | "--eval", 277 | action="store_true", 278 | help="Only run evaluation." 279 | ) 280 | parser.add_argument( 281 | "--num_steps", 282 | type=int, 283 | default=5, 284 | help="Number of steps for estimating Q function." 
285 | ) 286 | parser.add_argument( 287 | "--gamma", 288 | type=float, 289 | default=0.99, 290 | help="The discount factor for Q.", 291 | ) 292 | parser.add_argument( 293 | "--polyak_update_lr", 294 | type=float, 295 | default=0.01, 296 | help="The polyak updating rate for target model.", 297 | ) 298 | args = parser.parse_args() 299 | 300 | # Sanity checks 301 | if args.dataset_name is None and args.train_file is None and args.validation_file is None: 302 | raise ValueError("Need either a dataset name or a training/validation file.") 303 | else: 304 | if args.train_file is not None: 305 | extension = args.train_file.split(".")[-1] 306 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 307 | if args.validation_file is not None: 308 | extension = args.validation_file.split(".")[-1] 309 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 310 | 311 | return args 312 | 313 | 314 | def get_raw_dataset(args): 315 | """ 316 | Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 317 | or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 318 | (the dataset will be downloaded automatically from the datasets Hub). 319 | 320 | For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 321 | 'text' is found. You can easily tweak this behavior (see below). 322 | 323 | In distributed training, the load_dataset function guarantee that only one local process can concurrently 324 | download the dataset. 325 | 326 | Returns: 327 | :class:`Dataset` 328 | """ 329 | if args.dataset_name is not None: 330 | # Downloading and loading a dataset from the hub. 331 | raw_datasets = load_dataset( 332 | args.dataset_name, 333 | args.dataset_config_name, 334 | cache_dir=args.cache_dir 335 | ) 336 | else: 337 | data_files = {} 338 | if args.train_file is not None: 339 | data_files["train"] = args.train_file 340 | if args.validation_file is not None: 341 | data_files["validation"] = args.validation_file 342 | extension = args.train_file.split(".")[-1] 343 | raw_datasets = load_dataset(extension, data_files=data_files) 344 | 345 | return raw_datasets 346 | 347 | 348 | def load_pretrained_model_and_tokenizer( 349 | model_name_or_path, 350 | config_name, 351 | tokenizer_name, 352 | model_type=None, 353 | use_slow_tokenizer=False 354 | ): 355 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 356 | # download model & vocab. 357 | if config_name: 358 | config = AutoConfig.from_pretrained(config_name) 359 | elif model_name_or_path: 360 | config = AutoConfig.from_pretrained(model_name_or_path) 361 | else: 362 | config = CONFIG_MAPPING[model_type]() 363 | logger.warning("You are instantiating a new config instance from scratch.") 364 | 365 | if tokenizer_name: 366 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=not use_slow_tokenizer) 367 | elif model_name_or_path: 368 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=not use_slow_tokenizer) 369 | else: 370 | raise ValueError( 371 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 372 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
373 | ) 374 | 375 | if model_name_or_path: 376 | model = AutoModelForSeq2SeqLM.from_pretrained( 377 | model_name_or_path, 378 | from_tf=bool(".ckpt" in model_name_or_path), 379 | config=config, 380 | ) 381 | else: 382 | logger.info("Training new model from scratch") 383 | model = AutoModelForSeq2SeqLM.from_config(config) 384 | 385 | return config, tokenizer, model 386 | 387 | 388 | def get_column_names(args, column_names): 389 | """Get the column names for input/target.""" 390 | dataset_columns = summarization_name_mapping.get(args.dataset_name, None) 391 | if args.text_column is None: 392 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 393 | else: 394 | text_column = args.text_column 395 | if text_column not in column_names: 396 | raise ValueError( 397 | f"--text_column' value '{args.text_column}' needs to be one of: {', '.join(column_names)}" 398 | ) 399 | if args.summary_column is None: 400 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 401 | else: 402 | summary_column = args.summary_column 403 | if summary_column not in column_names: 404 | raise ValueError( 405 | f"--summary_column' value '{args.summary_column}' needs to be one of: {', '.join(column_names)}" 406 | ) 407 | return text_column, summary_column 408 | 409 | 410 | def process_raw_dataset(args, accelerator, raw_datasets, tokenizer): 411 | # First we tokenize all the texts. 412 | column_names = raw_datasets["train"].column_names # xsum: ['document', 'summary', 'reward'] 413 | text_column, summary_column = get_column_names(args, column_names) 414 | reward_column = 'reward' 415 | 416 | # Temporarily set max_target_length for training. 417 | max_target_length = args.max_target_length 418 | padding = "max_length" if args.pad_to_max_length else False 419 | 420 | prefix = args.source_prefix if args.source_prefix is not None else "" 421 | target_prefix = args.target_prefix if args.target_prefix is not None else "" 422 | 423 | def preprocess_function(examples): 424 | inputs = examples[text_column] 425 | targets = examples[summary_column] 426 | rewards = examples[reward_column] if reward_column in examples.keys() else None 427 | 428 | inputs = [prefix + inp for inp in inputs] 429 | model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True) 430 | 431 | # Setup the tokenizer for targets 432 | targets = [target_prefix + inp for inp in targets] 433 | with tokenizer.as_target_tokenizer(): 434 | labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True) 435 | 436 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 437 | # padding in the loss. 
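# Hypothetical illustration of the label masking below (BART-style ids, assuming pad_token_id == 1):
#   labels [0, 9226, 16, 1437, 2, 1, 1] -> [0, 9226, 16, 1437, 2, -100, -100]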
438 | if padding == "max_length" and args.ignore_pad_token_for_loss: 439 | labels["input_ids"] = [ 440 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 441 | ] 442 | 443 | model_inputs["labels"] = labels["input_ids"] 444 | model_inputs["rewards"] = rewards 445 | 446 | return model_inputs 447 | 448 | with accelerator.main_process_first(): 449 | processed_datasets = raw_datasets.map( 450 | preprocess_function, 451 | batched=True, 452 | num_proc=args.preprocessing_num_workers, 453 | remove_columns=column_names, 454 | load_from_cache_file=not args.overwrite_cache, 455 | desc="Running tokenizer on dataset", 456 | ) 457 | 458 | return processed_datasets 459 | 460 | 461 | def setup_optimizer(args, model): 462 | # Split weights in two groups, one with weight decay and the other not. 463 | no_decay = ["bias", "LayerNorm.weight"] 464 | optimizer_grouped_parameters = [ 465 | { 466 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 467 | "weight_decay": args.weight_decay, 468 | }, 469 | { 470 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 471 | "weight_decay": 0.0, 472 | }, 473 | ] 474 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 475 | return optimizer 476 | 477 | 478 | def postprocess_text(preds, labels): 479 | preds = [pred.strip() for pred in preds] 480 | labels = [label.strip() for label in labels] 481 | 482 | # rougeLSum expects newline after each sentence 483 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 484 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 485 | 486 | return preds, labels 487 | 488 | 489 | def eval(args, accelerator, model, tokenizer, eval_dataloader, metric): 490 | model.eval() 491 | if args.val_max_target_length is None: 492 | args.val_max_target_length = args.max_target_length 493 | 494 | gen_kwargs = { 495 | "max_length": args.val_max_target_length if args is not None else config.max_length, 496 | "num_beams": args.num_beams, 497 | } 498 | for step, batch in enumerate(eval_dataloader): 499 | with torch.no_grad(): 500 | generated_tokens = accelerator.unwrap_model(model).generate( 501 | batch["input_ids"], 502 | attention_mask=batch["attention_mask"], 503 | **gen_kwargs, 504 | ) 505 | 506 | generated_tokens = accelerator.pad_across_processes( 507 | generated_tokens, dim=1, pad_index=tokenizer.pad_token_id 508 | ) 509 | labels = batch["labels"] 510 | if not args.pad_to_max_length: 511 | # If we did not pad to max length, we need to pad the labels too 512 | labels = accelerator.pad_across_processes( 513 | batch["labels"], dim=1, pad_index=tokenizer.pad_token_id 514 | ) 515 | 516 | generated_tokens = accelerator.gather(generated_tokens).cpu().numpy() 517 | labels = accelerator.gather(labels).cpu().numpy() 518 | 519 | if args.ignore_pad_token_for_loss: 520 | # Replace -100 in the labels as we can't decode them. 
521 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 522 | if isinstance(generated_tokens, tuple): 523 | generated_tokens = generated_tokens[0] 524 | 525 | decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) 526 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 527 | 528 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 529 | 530 | metric.add_batch(predictions=decoded_preds, references=decoded_labels) 531 | 532 | 533 | def main(): 534 | args = parse_args() 535 | 536 | if args.source_prefix is None and args.model_name_or_path in [ 537 | "t5-small", 538 | "t5-base", 539 | "t5-large", 540 | "t5-3b", 541 | "t5-11b", 542 | ]: 543 | logger.warning( 544 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 545 | "`--source_prefix 'summarize: ' `" 546 | ) 547 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 548 | accelerator = Accelerator() 549 | # deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=args.gradient_accumulation_steps) 550 | # accelerator = Accelerator(fp16=True, deepspeed_plugin=deepspeed_plugin) 551 | 552 | # Make one log on every process with the configuration for debugging. 553 | logging.basicConfig( 554 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 555 | datefmt="%m/%d/%Y %H:%M:%S", 556 | level=logging.INFO, 557 | ) 558 | logger.info(accelerator.state) 559 | 560 | # Setup logging, we only want one process per machine to log things on the screen. 561 | # accelerator.is_local_main_process is only True for one process per machine. 562 | logger.setLevel(logging.INFO if accelerator.is_local_main_process else logging.ERROR) 563 | if accelerator.is_local_main_process: 564 | datasets.utils.logging.set_verbosity_warning() 565 | transformers.utils.logging.set_verbosity_info() 566 | else: 567 | datasets.utils.logging.set_verbosity_error() 568 | transformers.utils.logging.set_verbosity_error() 569 | 570 | # If passed along, set the training seed now. 571 | if args.seed is not None: 572 | set_seed(args.seed) 573 | 574 | # Handle the repository creation 575 | if accelerator.is_main_process: 576 | if args.output_dir is not None: 577 | os.makedirs(args.output_dir, exist_ok=True) 578 | accelerator.wait_for_everyone() 579 | 580 | # Load pretrained model and tokenizer 581 | config, tokenizer, model = load_pretrained_model_and_tokenizer( 582 | args.model_name_or_path, 583 | args.config_name, 584 | args.tokenizer_name, 585 | model_type=args.model_type, 586 | use_slow_tokenizer=args.use_slow_tokenizer 587 | ) 588 | 589 | model.resize_token_embeddings(len(tokenizer)) 590 | if model.config.decoder_start_token_id is None: 591 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 592 | 593 | tgt_model = deepcopy(model) 594 | 595 | # Get the raw dataset 596 | raw_datasets = get_raw_dataset(args) 597 | 598 | # Preprocessing the datasets. 
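# NOTE: besides the usual text/summary columns, preprocessing expects a 'reward' column
# (see process_raw_dataset / preprocess_function above). A hypothetical train.json record:
#   {"document": "Full article text ...", "summary": "Reference summary.", "reward": [0.0, 0.0, 1.0, ...]}
# where the reward list is assumed to be aligned with the target (summary) tokens, since it is
# later combined with the per-token loss in the training loop.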
599 | processed_datasets = process_raw_dataset(args, accelerator, raw_datasets, tokenizer) 600 | train_dataset = processed_datasets["train"] 601 | eval_dataset = processed_datasets["validation"] 602 | 603 | # Log a few random samples from the training set: 604 | for index in random.sample(range(len(train_dataset)), 1): 605 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 606 | 607 | label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id 608 | data_collator = DataCollatorForSeq2Seq( 609 | tokenizer, 610 | model=model, 611 | label_pad_token_id=label_pad_token_id, 612 | pad_to_multiple_of=8 if accelerator.use_fp16 else None, 613 | ) 614 | 615 | train_dataloader = DataLoader( 616 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 617 | ) 618 | eval_dataloader = DataLoader( 619 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size 620 | ) 621 | 622 | # Prepare optimizer 623 | optimizer = setup_optimizer(args, model) 624 | 625 | # Prepare everything with our `accelerator`. 626 | model, tgt_model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare( 627 | model, tgt_model, optimizer, train_dataloader, eval_dataloader 628 | ) 629 | 630 | # Note -> the training dataloader needs to be prepared before we grab his length below (cause its length will be 631 | # shorter in multiprocess) 632 | 633 | # Scheduler and math around the number of training steps. 634 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 635 | if args.max_train_steps is None: 636 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 637 | else: 638 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 639 | 640 | lr_scheduler = get_scheduler( 641 | name=args.lr_scheduler_type, 642 | optimizer=optimizer, 643 | num_warmup_steps=args.num_warmup_steps, 644 | num_training_steps=args.max_train_steps, 645 | ) 646 | 647 | # Metric 648 | metric = load_metric("rouge") 649 | 650 | if args.eval: 651 | # Evaluation only 652 | logger.info("***** Running evaluation *****") 653 | eval(args, accelerator, model, tokenizer, eval_dataloader, metric) 654 | 655 | # Extract a few results from ROUGE 656 | result = metric.compute(use_stemmer=True) 657 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 658 | result = {k: round(v, 4) for k, v in result.items()} 659 | 660 | logger.info(result) 661 | else: 662 | # Train! 663 | total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 664 | 665 | logger.info("***** Running training *****") 666 | logger.info(f" Num examples = {len(train_dataset)}") 667 | logger.info(f" Num Epochs = {args.num_train_epochs}") 668 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 669 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 670 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 671 | logger.info(f" Total optimization steps = {args.max_train_steps}") 672 | # Only show the progress bar once on each machine. 
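# The loop below implements a reward-weighted, off-policy cross-entropy update:
#   per-token loss_t = p_tgt(y_t) * Q_t * CE_t(model),
# where CE_t is the token-level cross-entropy of the trained model, p_tgt(y_t) = exp(-NLL_tgt)
# is the detached probability of the reference token under the frozen target network, and Q_t
# is a discounted return estimated from the per-token rewards (discounted_future_sum in
# rl_utils.py -- assumed here to sum the next `num_steps` rewards with discount `gamma`).
# The summed loss is normalised by the number of non-ignored target tokens, and the target
# network is refreshed with a Polyak/EMA step (polyak_update) after every optimizer update.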
673 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 674 | completed_steps = 0 675 | 676 | for epoch in range(args.num_train_epochs): 677 | model.train() 678 | tgt_model.eval() 679 | for step, batch in enumerate(train_dataloader): 680 | # batch: ['input_ids', 'attention_mask', 'labels', 'rewards', 'decoder_input_ids'] 681 | batch_size, tgt_len = batch['labels'].shape 682 | outputs = model(**{ 683 | 'input_ids': batch['input_ids'], 684 | 'attention_mask': batch['attention_mask'], 685 | 'decoder_input_ids': batch['decoder_input_ids'] 686 | }) 687 | 688 | loss_fct = CrossEntropyLoss(reduction='none') 689 | loss = loss_fct(outputs.logits.view(-1, config.vocab_size), batch['labels'].view(-1)) 690 | loss = loss.view(batch_size, tgt_len) 691 | 692 | # calculate estimated Q value 693 | seq_lens = torch.sum(batch['labels'] != -100, dim=-1) 694 | Q = discounted_future_sum(batch['rewards'], seq_lens, num_steps=args.num_steps, gamma=args.gamma).detach() 695 | 696 | # calculate importance sampling ratio 697 | with torch.no_grad(): 698 | tgt_outputs = tgt_model(**{ 699 | 'input_ids': batch['input_ids'], 700 | 'attention_mask': batch['attention_mask'], 701 | 'decoder_input_ids': batch['decoder_input_ids'] 702 | }) 703 | tgt_logits = tgt_outputs.logits.view(-1, config.vocab_size) # [batch * tgt_len, vocab_size] 704 | 705 | tgt_nll = loss_fct(tgt_logits, batch['labels'].view(-1)) 706 | sampling_ratio = torch.exp(-tgt_nll.view(batch_size, tgt_len)) # [batch, tgt_len] 707 | 708 | assert sampling_ratio.shape == Q.shape == loss.shape 709 | loss = (sampling_ratio * Q) * loss 710 | loss = loss.sum() / seq_lens.sum() 711 | 712 | loss = loss / args.gradient_accumulation_steps 713 | accelerator.backward(loss) 714 | if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: 715 | optimizer.step() 716 | lr_scheduler.step() 717 | optimizer.zero_grad() 718 | progress_bar.update(1) 719 | completed_steps += 1 720 | 721 | # update target model 722 | polyak_update(model, tgt_model, args.polyak_update_lr) 723 | 724 | if completed_steps >= args.max_train_steps: 725 | break 726 | 727 | # Run evaluation 728 | logger.info("***** Running evaluation *****") 729 | eval(args, accelerator, model, tokenizer, eval_dataloader, metric) 730 | 731 | # Extract a few results from ROUGE 732 | result = metric.compute(use_stemmer=True) 733 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 734 | result = {k: round(v, 4) for k, v in result.items()} 735 | 736 | logger.info(result) 737 | 738 | # save the model after each epoch of training 739 | if args.output_dir is not None: 740 | epoch_output_dir = os.path.join(args.output_dir, '{}/'.format(epoch)) 741 | accelerator.wait_for_everyone() 742 | unwrapped_model = accelerator.unwrap_model(model) 743 | unwrapped_model.save_pretrained(epoch_output_dir, save_function=accelerator.save) 744 | if accelerator.is_main_process: 745 | tokenizer.save_pretrained(epoch_output_dir) 746 | 747 | 748 | if __name__ == "__main__": 749 | main() 750 | -------------------------------------------------------------------------------- /rl_training/train_rl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_NAME_OR_PATH=/huggingface/bart-large 4 | OUTPUT_DIR=/BART_HF_models/TEST/ 5 | 6 | accelerate launch --config_file training_config.yaml train.py \ 7 | --model_name_or_path $MODEL_NAME_OR_PATH \ 8 | --per_device_train_batch_size 5 \ 9 | 
--per_device_eval_batch_size 8 \ 10 | --preprocessing_num_workers 16 \ 11 | --num_warmup_steps 500 \ 12 | --train_file train.json \ 13 | --validation_file val.json \ 14 | --learning_rate 5e-5 \ 15 | --polyak_update_lr 0.001 \ 16 | --gradient_accumulation_steps 1 \ 17 | --num_beams 6 \ 18 | --num_train_epochs 12 \ 19 | --output_dir $OUTPUT_DIR; 20 | # --overwrite_cache true \ 21 | # --dataset_name xsum \ 22 | # --source_prefix "summarize: " \ 23 | # --dataset_config "3.0.0" \ 24 | -------------------------------------------------------------------------------- /rl_training/training_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: MULTI_GPU 3 | fp16: true 4 | machine_rank: 0 5 | main_process_ip: null 6 | main_process_port: 1234 7 | main_training_function: main 8 | num_machines: 1 9 | num_processes: 1 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name='EntFA', 5 | version='1.0', 6 | description="Check our ACL 2022 paper 'Hallucinated but Factual! Inspecting the Factuality of Hallucinations in Abstractive Summarization'", 7 | author='Meng Cao', 8 | author_email='meng.cao@mail.mcgill.ca', 9 | package_dir={"": "src"}, 10 | packages=find_packages("src"), 11 | install_requires=[], # external packages as dependencies 12 | ) 13 | -------------------------------------------------------------------------------- /src/EntFA/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from fairseq.data.data_utils import collate_tokens 6 | from transformers import BartTokenizer 7 | 8 | 9 | class ConditionalSequenceGenerator: 10 | """Conditional sequence generator for calculating prior and posterior probability.""" 11 | def __init__(self, bart): 12 | self.bart = bart 13 | self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large", local_files_only=False) 14 | 15 | self.encode_func = bart.encode 16 | self.decode_func = bart.decode 17 | self.max_positions = bart.max_positions 18 | if type(self.max_positions) == int: 19 | self.max_positions = [self.max_positions] 20 | self.encode_line = bart.task.source_dictionary.encode_line 21 | 22 | self._initialize() 23 | 24 | def _initialize(self): 25 | """Set BART model to evaluation mode.""" 26 | self.bart.cuda() 27 | self.bart.eval() 28 | self.bart.half() 29 | 30 | def tokenize_target(self, input_str, left_pad=False, append_eos=False): 31 | """BPE-encode a sentence (or multiple sentences). 32 | 33 | Args: 34 | input_str (str or List[str]): input sentence to be tokenized. 35 | left_pad (bool): self-explained. 
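append_eos (bool): whether to add an extra end-of-sequence marker to the target (see the method body).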
36 | 37 | Return: 38 | prev_output_tokens (torch.Tensor): [batch_size, length] 39 | target (torch.Tensor): [batch_size, length] 40 | tgt_lengths (torch.Tensor): [batch_size] 41 | 42 | """ 43 | if type(input_str) == type(''): 44 | input_str = [input_str] 45 | 46 | prev_ids, tgt_ids = [], [] 47 | 48 | for ins in input_str: 49 | tokens = self.bart.bpe.encode(ins) # : 1279 27932 29 50 | calibration = 1 51 | if len(tokens.split(" ")) > min(self.max_positions) - calibration: 52 | tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - calibration]) 53 | 54 | if append_eos: 55 | tokens = " " + tokens 56 | prev_tokens = " " + tokens 57 | tgt_tokens = tokens + " " 58 | 59 | tgt_ids.append(self.encode_line(tgt_tokens, append_eos=False).long()) 60 | prev_ids.append(self.encode_line(prev_tokens, append_eos=False).long()) 61 | 62 | prev_output_tokens = collate_tokens(prev_ids, pad_idx=1, left_pad=left_pad).cuda() 63 | target = collate_tokens(tgt_ids, pad_idx=1, left_pad=left_pad).cuda() 64 | tgt_lengths = torch.sum(target != 1, dim=1).cuda() 65 | 66 | return prev_output_tokens, target, tgt_lengths 67 | 68 | def tokenize(self, input_str, append_bos=False, append_eos=True, left_pad=True): 69 | """BPE-encode a sentence (or multiple sentences). 70 | 71 | Args: 72 | input_str (str or List[str]): input sentence to be tokenized. 73 | append_bos (bool): self-explained. 74 | append_eos (bool): self-explained. 75 | 76 | Return: 77 | input_ids (torch.Tensor): [batch_size, length] 78 | src_lengths (torch.Tensor): [batch_size] 79 | """ 80 | if type(input_str) == type(''): 81 | input_str = [input_str] 82 | 83 | input_ids = [] 84 | for ins in input_str: 85 | tokens = self.bart.bpe.encode(ins) # : 1279 27932 29 86 | calibration = sum([append_bos, append_eos]) 87 | if len(tokens.split(" ")) > min(self.max_positions) - calibration: 88 | tokens = " ".join(tokens.split(" ")[: min(self.max_positions) - calibration]) 89 | 90 | tokens = " " + tokens if append_bos else tokens 91 | tokens = tokens + " " if append_eos else tokens 92 | ids = self.encode_line(tokens, append_eos=False).long() 93 | input_ids.append(ids) 94 | 95 | input_ids = collate_tokens(input_ids, pad_idx=1, left_pad=left_pad).cuda() 96 | input_lengths = torch.sum(input_ids != 1, dim=1).cuda() 97 | 98 | return input_ids, input_lengths 99 | 100 | def tokenize_with_mask(self, input_str): 101 | """Tokenize sentence with a special token in it. 102 | 103 | Args: 104 | input_str (str or List[str]): input sentence to be tokenized. 
105 | 106 | Return: 107 | input_ids (torch.Tensor): [batch_size, length] 108 | src_lengths (torch.Tensor): [batch_size] 109 | """ 110 | input_ids = self.tokenizer(input_str, return_tensors='pt', padding=True)['input_ids'].cuda() 111 | input_lengths = torch.sum(input_ids != 1, dim=1).cuda() 112 | return input_ids, input_lengths 113 | 114 | def encode_decode(self, src_input, tgt_input, mask_filling=False): 115 | """ 116 | Args: 117 | src_input: (List[str]) 118 | tgt_input: (List[str]) 119 | 120 | """ 121 | if mask_filling: 122 | src_tokens, src_lengths = self.tokenize_with_mask(src_input) 123 | prev_output_tokens, target, tgt_lengths = self.tokenize_target(tgt_input, left_pad=False, append_eos=True) 124 | else: 125 | src_tokens, src_lengths = self.tokenize(src_input, append_bos=False) 126 | prev_output_tokens, target, tgt_lengths = self.tokenize_target(tgt_input, left_pad=False) 127 | 128 | with torch.no_grad(): 129 | encoder_out = self.bart.model.encoder(src_tokens, src_lengths=src_lengths) 130 | decoder_out = self.bart.model.decoder(prev_output_tokens, encoder_out=encoder_out, features_only=False) 131 | 132 | probs = nn.functional.softmax(decoder_out[0], dim=-1) 133 | tgt_token_probs = torch.gather(probs, 2, target.unsqueeze(-1)).squeeze(2) 134 | 135 | # mask with probability 1.0 136 | max_tgt_length = tgt_lengths.max().item() 137 | tgt_lengths = tgt_lengths - 1 138 | tgt_mask = torch.arange(max_tgt_length)[None, :].cuda() < tgt_lengths[:, None] 139 | tgt_token_probs.masked_fill_(tgt_mask == False, 1.0) 140 | 141 | return tgt_token_probs, target 142 | 143 | def generate(self, src_input, tgt_input=None): 144 | """Conditional generation. 145 | 146 | Args: 147 | src_input (str or List[str]): input source sentence to be tokenized. 148 | tgt_input (str or List[str]): input target sentence to be tokenized. 149 | """ 150 | input_ids, lengths = self.tokenize(src_input, append_bos=False) 151 | 152 | target_ids = None 153 | if tgt_input is not None: 154 | assert len(src_input) == len(tgt_input), "source & target length should match." 155 | target_ids, _ = self.tokenize(tgt_input, append_bos=False, left_pad=False) 156 | 157 | with torch.no_grad(): 158 | encoder_output = self.encode_sequence(input_ids, lengths) 159 | decoder_output = self.decode_sequence(encoder_output, 160 | target_ids=target_ids, 161 | prefix_tokens=[2]) 162 | return decoder_output 163 | 164 | def mask_filling(self, src_input, tgt_input=None): 165 | """ 166 | Filling the mask in sentence(s). 167 | """ 168 | input_ids, lengths = self.tokenize_with_mask(src_input) 169 | 170 | target_ids = None 171 | if tgt_input is not None: 172 | assert len(src_input) == len(tgt_input), "source & target length should match." 
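# In fairseq's BART dictionary the special ids are 0 = <s>, 1 = <pad>, 2 = </s>; decode_sequence
# below therefore seeds the decoder with prefix_tokens=[2, 0] (</s> <s>) for mask filling.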
173 | target_ids, _ = self.tokenize(tgt_input, left_pad=False) 174 | 175 | with torch.no_grad(): 176 | encoder_output = self.encode_sequence(input_ids, lengths) 177 | decoder_output = self.decode_sequence(encoder_output, 178 | target_ids=target_ids, 179 | prefix_tokens=[2, 0]) 180 | return decoder_output 181 | 182 | def encode_sequence(self, input_ids, lengths): 183 | return self.bart.model.encoder(input_ids, src_lengths=lengths) 184 | 185 | def decode_sequence( 186 | self, 187 | encoder_out, 188 | target_ids=None, 189 | min_decode_step=3, 190 | max_decode_step=100, 191 | pad_id=1, 192 | eos_id=2, 193 | prefix_tokens=[2, 0], 194 | ): 195 | batch_size = encoder_out['encoder_padding_mask'][0].shape[0] 196 | init_input = torch.tensor([prefix_tokens] * batch_size, dtype=torch.long).cuda() 197 | token_probs, tokens = None, [[] for i in range(batch_size)] 198 | end_mask = torch.tensor([False] * batch_size).cuda() 199 | 200 | softmax = nn.Softmax(dim=1) 201 | for step in range(max_decode_step): 202 | decoder_outputs = self.bart.model.decoder(init_input, encoder_out, features_only=False) 203 | logits = decoder_outputs[0][:, -1, :] # logits: [batch_size, vocab] 204 | attn = decoder_outputs[1]['attn'][0] # [batch_size, prev_token_len, src_token_len] 205 | 206 | if step + 1 < min_decode_step: 207 | logits[:, eos_id] = -math.inf # mask token when within minimal step 208 | logits[:, pad_id], logits[:, 0] = -math.inf, -math.inf # never select & token 209 | probs = softmax(logits) # probs: [batch_size, vocab] 210 | 211 | # select tokens 212 | if target_ids is not None: 213 | selected_token = target_ids[:, step] 214 | else: 215 | value, indices = torch.topk(probs, 5, dim=1) 216 | selected_token = indices[:, 0] 217 | 218 | selected_token = selected_token.masked_fill(end_mask, pad_id) 219 | init_input = torch.cat([init_input, selected_token.unsqueeze(1)], dim=-1) 220 | 221 | probs = torch.gather(probs, 1, selected_token.unsqueeze(1)).detach() 222 | probs = probs.masked_fill(end_mask.unsqueeze(1), 1.0) 223 | 224 | # str & probability 225 | token_probs = probs if token_probs is None else torch.cat([token_probs, probs], dim=-1) 226 | for t, s in zip(tokens, selected_token): 227 | t.append(self.decode_func(s.unsqueeze(0)) if s.item() != pad_id else '') 228 | 229 | # stop generation when all finished 230 | end_mask = torch.logical_or(end_mask, selected_token == eos_id) 231 | if end_mask.sum().item() == batch_size: 232 | break 233 | 234 | return init_input, tokens, token_probs -------------------------------------------------------------------------------- /src/EntFA/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | 5 | 6 | def read_lines(file_path): 7 | files = [] 8 | with open(file_path, 'r', encoding='utf-8') as f: 9 | for line in f: 10 | files.append(line.strip()) 11 | return files 12 | 13 | 14 | def read_jsonl(file_path): 15 | data = [] 16 | with open(file_path, 'r') as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | 22 | def get_probability(position, tokens, probs, entity): 23 | """Calculate the probability of a span. 24 | 25 | Args: 26 | position: (start, end) 27 | tokens: ['The', ' Archbishop', ' of', ...] 28 | probs: [0.50, 0.49, 0.88, ...] 29 | entity: Rodgers 30 | """ 31 | assert len(tokens) == len(probs), "Tokens and token probabilities does not match." 
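# The span `position` is given in character offsets of the detokenized target string, so the
# entity probability is recovered by walking cumulative token lengths backwards from the
# character index where the span ends, then multiplying the covered token probabilities.
# Hypothetical example: tokens = ['The', ' Arch', 'bishop'], probs = [0.9, 0.5, 0.8],
# position = (4, 14) ("Archbishop") -> covers tokens 1-2 -> probability 0.5 * 0.8 = 0.4.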
32 | 33 | end_pointer, end_pos = 0, [] 34 | for t in tokens: 35 | end_pointer += len(t) 36 | end_pos.append(end_pointer) 37 | 38 | assert position[1] in end_pos, "- {}\n- {}\n- {}\n- {}\n- {}\n".format(position, tokens, probs, entity, end_pos) 39 | last_index = end_pos.index(position[1]) 40 | indexes = [last_index] 41 | total_length = len(tokens[last_index]) 42 | 43 | while total_length < (position[1] - position[0]): 44 | last_index -= 1 45 | assert last_index >= 0 46 | indexes.append(last_index) 47 | total_length += len(tokens[last_index]) 48 | 49 | indexes.reverse() 50 | 51 | generated = ''.join([tokens[i] for i in indexes]) 52 | assert entity in generated, 'entity: {}; span: {}'.format(entity, generated) 53 | 54 | prob = 1.0 55 | for i in indexes: 56 | prob *= probs[i] 57 | return prob 58 | 59 | 60 | def get_cmlm_probability(generator, src_input, tgt_input, position, entity): 61 | outputs = generator.generate(src_input, tgt_input=tgt_input) 62 | init_input, tokens, token_probs = outputs 63 | 64 | probs = [] 65 | for p, tok, tokp, e in zip(position, tokens, token_probs, entity): 66 | probs.append(get_probability(p, tok, tokp, e).item()) 67 | 68 | return probs 69 | 70 | 71 | def get_probability_parallel(generator, src_input, tgt_input, position, entity, mask_filling=False): 72 | """Get entities probability in parallel decoding. 73 | 74 | Args: 75 | generator: model 76 | args*: outputs from prepare_cmlm_inputs() 77 | 78 | """ 79 | token_probs, target = generator.encode_decode(src_input, tgt_input=tgt_input, mask_filling=mask_filling) 80 | 81 | probs = [] 82 | for p, tok, tokp, e in zip(position, target, token_probs, entity): 83 | if mask_filling: 84 | assert tok[0].item() == 0 85 | tok, tokp = tok[1:], tokp[1:] 86 | 87 | tok_ = [] 88 | for t in tok: 89 | if t.item() == 1: 90 | tok_.append("") 91 | else: 92 | tok_.append(generator.decode_func(t.unsqueeze(0))) 93 | probs.append(get_probability(p, tok_, tokp, e).item()) 94 | 95 | return probs 96 | 97 | 98 | def get_prior_probability(generator, src_input, tgt_input, position, entity): 99 | """Tokenize input with a special token.""" 100 | assert len(src_input) == len(tgt_input), "source & target length should match." 101 | decoder_output = generator.mask_filling(src_input, tgt_input) 102 | init_input, tokens, token_probs = decoder_output 103 | 104 | probs = [] 105 | for p, tok, tokp, e in zip(position, tokens, token_probs, entity): 106 | probs.append(get_probability(p, tok, tokp, e).item()) 107 | return probs 108 | 109 | 110 | def prepare_clm_inputs(source, target, ent_parts=None): 111 | """For Conditional Language Model. For XSum BART only.""" 112 | if ent_parts is None: 113 | ent_parts = nlp(target).to_json()['ents'] 114 | 115 | entities, positions = [], [] 116 | inputs, targets = [], [] 117 | 118 | for e in ent_parts: 119 | inputs.append(source) 120 | targets.append(target) 121 | positions.append((e['start'], e['end'])) 122 | entities.append(target[e['start']: e['end']]) 123 | 124 | return inputs, targets, positions, entities 125 | 126 | 127 | def prepare_mlm_inputs(source, target, ent_parts=None): 128 | """For Masked Language Model. 
For BART only.""" 129 | if ent_parts is None: 130 | ent_parts = nlp(target).to_json()['ents'] 131 | 132 | inputs, targets = [], [] 133 | positions, entities = [], [] 134 | 135 | for e in ent_parts: 136 | inputs.append(target[0: e['start']] + '' + target[e['end']:]) 137 | targets.append(target) 138 | entities.append(target[e['start']: e['end']]) 139 | positions.append((e['start'], e['end'])) 140 | 141 | return inputs, targets, positions, entities 142 | 143 | 144 | def prepare_cmlm_inputs(source, target, ent_parts=None): 145 | """For Conditional Masked Language Model.""" 146 | if ent_parts is None: 147 | ent_parts = nlp(target).to_json()['ents'] 148 | 149 | inputs, targets = [], [] 150 | positions, entities = [], [] 151 | 152 | for e in ent_parts: 153 | masked_hypothesis = target[0: e['start']] + '###' + target[e['end']:] 154 | masked_hypothesis = ' ' + masked_hypothesis + ' <\s> ' + source 155 | inputs.append(masked_hypothesis) 156 | targets.append(' ' + target) 157 | 158 | entities.append(target[e['start']: e['end']]) 159 | positions.append((e['start'] + 4, e['end'] + 4)) 160 | 161 | return inputs, targets, positions, entities 162 | 163 | 164 | def prepare_cmlm_ent_inputs(source, target, ent_parts=None): 165 | """For Entity Conditional Masked Language Model.""" 166 | if ent_parts is None: 167 | ent_parts = nlp(target).to_json()['ents'] 168 | 169 | inputs, targets, entities = [], [], [] 170 | 171 | for e in ent_parts: 172 | masked_hypothesis = target[0: e['start']] + '###' + target[e['end']:] 173 | masked_hypothesis = ' ' + masked_hypothesis + ' <\s> ' + source 174 | inputs.append(masked_hypothesis) 175 | targets.append(' ' + target[e['start']: e['end']]) 176 | 177 | entities.append(target[e['start']: e['end']]) 178 | 179 | return inputs, targets, entities 180 | 181 | 182 | def process_document(raw_doc): 183 | TRIVIAL_SENTS = [ 184 | 'Share this with', 185 | 'Copy this link', 186 | 'These are external links and will open in a new window', 187 | ] 188 | 189 | raw_doc = raw_doc.strip() 190 | raw_doc_sents = raw_doc.split('\n') 191 | 192 | start_signal = False 193 | filtered_sentences = [] 194 | for s in raw_doc_sents: 195 | if start_signal: 196 | filtered_sentences.append(s) 197 | elif len(s.split()) > 1 and s not in TRIVIAL_SENTS: 198 | start_signal = True 199 | filtered_sentences.append(s) 200 | 201 | return ' '.join(filtered_sentences) 202 | 203 | 204 | def read_document(bbcid, folder): 205 | file_path = folder + '{}.document'.format(bbcid) 206 | if os.path.exists(file_path): 207 | with open(file_path, 'r') as f: 208 | return process_document(f.read()) 209 | else: 210 | return None 211 | --------------------------------------------------------------------------------