├── .gitignore ├── CODEOWNERS ├── CONTRIBUTING-ARCHIVED.md ├── LICENSE.txt ├── README.md ├── data_generation ├── __init__.py ├── augmentation_ops.py └── create_data.py ├── data_pairing └── pair_data.py ├── model.jpg ├── modeling ├── __init__.py ├── model.py ├── run.py ├── scripts │ ├── factcc-eval.sh │ ├── factcc-finetune.sh │ ├── factcc-train.sh │ ├── factccx-eval.sh │ ├── factccx-finetune.sh │ └── factccx-train.sh └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | # *.ipynb_checkpoints 77 | # */*.ipynb/* 78 | # **/*.ipynb/* 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | 127 | ### MacOS 128 | .DS_Store 129 | # 130 | # 131 | **/gen_data/* 132 | # Other files 133 | .word_vectors_cache/ 134 | **/wandb/* 135 | */wandb/* 136 | 137 | **/apex/* 138 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 
2 | #ECCN:Open Source
3 |
--------------------------------------------------------------------------------
/CONTRIBUTING-ARCHIVED.md:
--------------------------------------------------------------------------------
1 | # ARCHIVED
2 |
3 | This project is `Archived` and is no longer actively maintained;
4 | we are not accepting contributions or pull requests.
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, Salesforce.com, Inc.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Evaluating the Factual Consistency of Abstractive Text Summarization
2 | Authors: [Wojciech Kryściński](https://twitter.com/iam_wkr), [Bryan McCann](https://bmccann.github.io/), [Caiming Xiong](http://www.stat.ucla.edu/~caiming/), and [Richard Socher](https://www.socher.org/)
3 |
4 | ## Introduction
5 | Currently used metrics for assessing summarization algorithms do not account for whether summaries are factually consistent with source documents.
6 | We propose a weakly-supervised, model-based approach for verifying factual consistency and identifying conflicts between source documents and a generated summary.
7 | Training data is generated by applying a series of rule-based transformations to the sentences of source documents.
8 | The factual consistency model is then trained jointly for three tasks:
9 | 1) identify whether sentences remain factually consistent after transformation,
10 | 2) extract a span in the source documents to support the consistency prediction,
11 | 3) extract a span in the summary sentence that is inconsistent if one exists.
12 | Transferring this model to summaries generated by several state-of-the-art models reveals that this highly scalable approach substantially outperforms previous models,
13 | including those trained with strong supervision using standard datasets for natural language inference and fact checking.
14 | Additionally, human evaluation shows that the auxiliary span extraction tasks provide useful assistance in the process of verifying factual consistency.
15 |
16 | Paper link: https://arxiv.org/abs/1910.12840
17 |
18 | ![FactCC model overview](model.jpg)
19 |
20 |
21 | ## Table of Contents
22 |
23 | 1. [Updates](#updates)
24 | 2. [Citation](#citation)
25 | 3. [License](#license)
26 | 4. [Usage](#usage)
27 | 5. [Get Involved](#get-involved)
28 |
29 | ## Updates
30 | #### 1/27/2020
31 | Updated manually annotated data files - fixed `filepaths` in misaligned examples.
32 |
33 | Updated model checkpoint files - recomputed evaluation metrics for fixed examples.
34 |
35 | ## Citation
36 | ```
37 | @article{kryscinskiFactCC2019,
38 |   author = {Wojciech Kry{\'s}ci{\'n}ski and Bryan McCann and Caiming Xiong and Richard Socher},
39 |   title = {Evaluating the Factual Consistency of Abstractive Text Summarization},
40 |   journal = {arXiv preprint arXiv:1910.12840},
41 |   year = {2019},
42 | }
43 | ```
44 |
45 |
46 | ## License
47 | The code is released under the BSD-3 License (see `LICENSE.txt` for details), but we also ask that users respect the following:
48 |
49 | This software should not be used to promote or profit from violence, hate, and division, environmental destruction, abuse of human rights,
50 | or the destruction of people's physical and mental health.
51 |
52 |
53 | ## Usage
54 | The code repository uses `Python 3`.
55 | Prior to running any scripts, please make sure to install the required Python packages listed in the `requirements.txt` file.
56 |
57 | Example call:
58 | `pip3 install -r requirements.txt`
59 |
60 | ### Training and Evaluation Datasets
61 | Generated training data can be found [here](https://storage.googleapis.com/sfr-factcc-data-research/unpaired_generated_data.tar.gz).
62 |
63 | Manually annotated validation and test data can be found [here](https://storage.googleapis.com/sfr-factcc-data-research/unpaired_annotated_data.tar.gz).
64 |
65 | Both generated and manually annotated datasets require pairing with the original CNN/DailyMail articles.
66 |
67 | To recreate the datasets, follow these instructions:
68 | 1. Download CNN Stories and Daily Mail Stories from https://cs.nyu.edu/~kcho/DMQA/
69 | 2. Create a `cnndm` directory and unpack the downloaded files into the directory
70 | 3. Download and unpack the FactCC data _(do not rename the directory)_
71 | 4. Run the `pair_data.py` script to pair the data with the original articles
72 |
73 | Example call:
74 |
75 | `python3 data_pairing/pair_data.py <unpaired-data-dir> <story-files-dir>`
76 |
77 | ### Generating Data
78 |
79 | Synthetic training data can be generated using the code available in the `data_generation` directory.
80 |
81 | The data generation script expects the source documents as input in a single `jsonl` file, where each source document is embedded in a separate json object.
82 | The json object is required to contain an `id` key that stores an example id (uniqueness is not required) and a `text` field that stores the text of the source document.
83 |
84 | Certain transformations rely on NER tagging, so for best results use source documents with original (proper) casing.
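
For illustration, a minimal input file would contain one json object per line; the `id` and `text` values below are made up:

```json
{"id": "example-0", "text": "Police said the suspect was arrested on Tuesday in downtown Chicago. Officers recovered two stolen vehicles during the search."}
{"id": "example-1", "text": "The storm made landfall on Friday, and officials said more than 10,000 homes lost power across the region."}
```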
85 |
86 |
87 | The following claim augmentations (transformations) are available:
88 | - `backtranslation` - Paraphrasing the claim via backtranslation (requires a Google Translate API key; costs apply)
89 | - `pronoun_swap` - Swapping a random pronoun in the claim
90 | - `date_swap` - Swapping a random date/time found in the claim with one present in the source article
91 | - `number_swap` - Swapping a random number found in the claim with one present in the source article
92 | - `entity_swap` - Swapping a random entity name found in the claim with one present in the source article
93 | - `negation` - Negating the meaning of the claim
94 | - `noise` - Injecting noise into the claim sentence
95 |
96 | For a detailed description of the available transformations, please refer to `Section 3.1` of the paper.
97 |
98 | To authenticate with the Google Cloud API, follow [these](https://cloud.google.com/docs/authentication/getting-started) instructions.
99 |
100 | Example call:
101 |
102 | `python3 data_generation/create_data.py <data-file> [--augmentations list-of-augmentations]`
103 |
104 | ### Model Code
105 |
106 | The `FactCC` and `FactCCX` models can be trained or initialized from a checkpoint using the code available in the `modeling` directory.
107 |
108 | Quickstart training, fine-tuning, and evaluation scripts are shared in the `scripts` directory.
109 | Before use, make sure to update the `*_PATH` variables with appropriate absolute paths.
110 |
111 | To customize training or evaluation settings, please refer to the flags in the `run.py` file.
112 |
113 | To utilize Weights & Biases dashboards, log in to the service using the following command: `wandb login`.
114 |
115 | The trained `FactCC` model checkpoint can be found [here](https://storage.googleapis.com/sfr-factcc-data-research/factcc-checkpoint.tar.gz).
116 |
117 | The trained `FactCCX` model checkpoint can be found [here](https://storage.googleapis.com/sfr-factcc-data-research/factccx-checkpoint.tar.gz).
118 |
119 | *IMPORTANT:* Due to data pre-processing, the first run of training or evaluation code on a large dataset can take up to a few hours before the actual procedure starts.
120 |
121 | #### Running on other data
122 | To run the pretrained `FactCC` or `FactCCX` models on your own data, follow these instructions:
123 | 1. Download the pre-trained model checkpoint, linked above
124 | 2. Prepare your data in `jsonl` format. Each example should be a separate `json` object with `id`, `text`, and `claim` keys
125 | representing the example id, source document, and claim sentence, respectively. Name the file `data-dev.jsonl`
126 | 3. Update the corresponding `*-eval.sh` script
127 |
128 |
129 | ## Get Involved
130 |
131 | Please create a GitHub issue if you have any questions, suggestions, requests, or bug reports.
132 | We welcome PRs!
133 |
--------------------------------------------------------------------------------
/data_generation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/factCC/5daeba51dc63f9d3197df6945501a1153a77affd/data_generation/__init__.py
--------------------------------------------------------------------------------
/data_generation/augmentation_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | Data augmentation (transformations) operations used to generate
3 | synthetic training data for the `FactCC` and `FactCCX` models.
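
Each transformation's `transform` method takes an example dict and returns an
augmented copy, or None when the transformation is not applicable; the dict
layout is the one produced by `make_new_example` below.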
4 | """
5 |
6 | import random
7 |
8 | import spacy
9 |
10 | from google.cloud import translate
11 |
12 |
13 | LABEL_MAP = {True: "CORRECT", False: "INCORRECT"}
14 |
15 |
16 | def align_ws(old_token, new_token):
17 |     # Align trailing whitespaces between tokens
18 |     if old_token[-1] == new_token[-1] == " ":
19 |         return new_token
20 |     elif old_token[-1] == " ":
21 |         return new_token + " "
22 |     elif new_token[-1] == " ":
23 |         return new_token[:-1]
24 |     else:
25 |         return new_token
26 |
27 |
28 | def make_new_example(eid=None, text=None, claim=None, label=None, extraction_span=None,
29 |                      backtranslation=None, augmentation=None, augmentation_span=None, noise=None):
30 |     # Embed example information in a json object.
31 |     return {
32 |         "id": eid,
33 |         "text": text,
34 |         "claim": claim,
35 |         "label": label,
36 |         "extraction_span": extraction_span,
37 |         "backtranslation": backtranslation,
38 |         "augmentation": augmentation,
39 |         "augmentation_span": augmentation_span,
40 |         "noise": noise
41 |     }
42 |
43 |
44 | class Transformation():
45 |     # Base class for all data transformations
46 |
47 |     def __init__(self):
48 |         # Spacy toolkit used for all NLP-related substeps
49 |         self.spacy = spacy.load("en")
50 |
51 |     def transform(self, example):
52 |         # Apply the transformation to the passed example
53 |         pass
54 |
55 |
56 | class SampleSentences(Transformation):
57 |     # Embed document as Spacy object and sample one sentence as claim
58 |     def __init__(self, min_sent_len=8):
59 |         super().__init__()
60 |         self.min_sent_len = min_sent_len
61 |
62 |     def transform(self, example):
63 |         assert example["text"] is not None, "Text must be available"
64 |
65 |         # split into sentences
66 |         page_id = example["id"]
67 |         page_text = example["text"].replace("\n", " ")
68 |         page_doc = self.spacy(page_text, disable=["tagger"])
69 |         sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]
70 |
71 |         # sample claim
72 |         claim = random.choice(sents)
73 |         new_example = make_new_example(eid=page_id, text=page_doc,
74 |                                        claim=self.spacy(claim.text),
75 |                                        label=LABEL_MAP[True],
76 |                                        extraction_span=(claim.start, claim.end-1),
77 |                                        backtranslation=False, noise=False)
78 |         return new_example
79 |
80 |
81 | class NegateSentences(Transformation):
82 |     # Apply or remove negation from negatable tokens
83 |     def __init__(self):
84 |         super().__init__()
85 |         self.__negatable_tokens = ("are", "is", "was", "were", "have", "has", "had",
86 |                                    "do", "does", "did", "can", "ca", "could", "may",
87 |                                    "might", "must", "shall", "should", "will", "would")
88 |
89 |     def transform(self, example):
90 |         assert example["text"] is not None, "Text must be available"
91 |         assert example["claim"] is not None, "Claim must be available"
92 |
93 |         new_example = dict(example)
94 |         new_claim, aug_span = self.__negate_sentences(new_example["claim"])
95 |
96 |         if new_claim:
97 |             new_example["claim"] = new_claim
98 |             new_example["label"] = LABEL_MAP[False]
99 |             new_example["augmentation"] = self.__class__.__name__
100 |             new_example["augmentation_span"] = aug_span
101 |             return new_example
102 |         else:
103 |             return None
104 |
105 |     def __negate_sentences(self, claim):
106 |         # find negatable token, return None if no candidates found
107 |         candidate_tokens = [token for token in claim if token.text in self.__negatable_tokens]
108 |
109 |         if not candidate_tokens:
110 |             return None, None
111 |
112 |         # choose random token to negate
113 |         negated_token = random.choice(candidate_tokens)
114 |         negated_ix = negated_token.i
115 |         doc_len = len(claim)
116 |
117 |         if negated_ix > 0:
118 |             if claim[negated_ix - 1].text in self.__negatable_tokens:
119 |                 negated_token = claim[negated_ix - 1]
120 |                 negated_ix = negated_ix - 1
121 |
122 |         # check whether token is negative
123 |         is_negative = False
124 |         if (doc_len - 1) > negated_ix:
125 |             if claim[negated_ix + 1].text in ["not", "n't"]:
126 |                 is_negative = True
127 |             elif claim[negated_ix + 1].text == "no":
128 |                 return None, None
129 |
130 |
131 |         # negate token
132 |         claim_tokens = [token.text_with_ws for token in claim]
133 |         if is_negative:
134 |             if claim[negated_ix + 1].text.lower() == "n't":
135 |                 if claim[negated_ix].text.lower() == "ca":  # restore "can" when dropping the clitic from "ca n't"
136 |                     claim_tokens[negated_ix] = "can" if claim_tokens[negated_ix].islower() else "Can"
137 |                 claim_tokens[negated_ix] = claim_tokens[negated_ix] + " "
138 |             claim_tokens.pop(negated_ix + 1)
139 |         else:
140 |             if claim[negated_ix].text.lower() in ["am", "may", "might", "must", "shall", "will"]:
141 |                 negation = "not "
142 |             else:
143 |                 negation = random.choice(["not ", "n't "])
144 |
145 |             if negation == "n't ":
146 |                 if claim[negated_ix].text.lower() == "can":
147 |                     claim_tokens[negated_ix] = "ca" if claim_tokens[negated_ix].islower() else "Ca"
148 |                 else:
149 |                     claim_tokens[negated_ix] = claim_tokens[negated_ix][:-1]
150 |             claim_tokens.insert(negated_ix + 1, negation)
151 |
152 |         # create new claim object
153 |         new_claim = self.spacy("".join(claim_tokens))
154 |         augmentation_span = (negated_ix, negated_ix if is_negative else negated_ix + 1)
155 |
156 |         if new_claim.text == claim.text:
157 |             return None, None
158 |         else:
159 |             return new_claim, augmentation_span
160 |
161 |
162 | class Backtranslation(Transformation):
163 |     # Paraphrase sentence via backtranslation with Google Translate API
164 |     # Requires API Key for Google Cloud SDK, additional charges DO apply
165 |     def __init__(self, dst_lang=None):
166 |         super().__init__()
167 |
168 |         self.src_lang = "en"
169 |         self.dst_lang = dst_lang
170 |         self.accepted_langs = ["fr", "de", "zh-TW", "es", "ru"]
171 |         self.translator = translate.Client()
172 |
173 |     def transform(self, example):
174 |         assert example["text"] is not None, "Text must be available"
175 |         assert example["claim"] is not None, "Claim must be available"
176 |
177 |         new_example = dict(example)
178 |         new_claim, _ = self.__backtranslate(new_example["claim"])
179 |
180 |         if new_claim:
181 |             new_example["claim"] = new_claim
182 |             new_example["backtranslation"] = True
183 |             return new_example
184 |         else:
185 |             return None
186 |
187 |     def __backtranslate(self, claim):
188 |         # choose destination language: either passed in or sampled from the accepted list
189 |         dst_lang = self.dst_lang if self.dst_lang else random.choice(self.accepted_langs)
190 |
191 |         # translate to intermediate language and back
192 |         claim_trans = self.translator.translate(claim.text, target_language=dst_lang, format_="text")
193 |         claim_btrans = self.translator.translate(claim_trans["translatedText"], target_language=self.src_lang, format_="text")
194 |
195 |         # create new claim object
196 |         new_claim = self.spacy(claim_btrans["translatedText"])
197 |         augmentation_span = (new_claim[0].i, new_claim[-1].i)
198 |
199 |         if claim.text == new_claim.text:
200 |             return None, None
201 |         else:
202 |             return new_claim, augmentation_span
203 |
204 |
205 | class PronounSwap(Transformation):
206 |     # Swap randomly chosen pronoun
207 |     def __init__(self, prob_swap=0.5):
208 |         super().__init__()
209 |
210 |         self.class2pronoun_map = {
211 |             "SUBJECT": ["you", "he", "she", "we", "they"],
212 |             "OBJECT": ["me", "you", "him", "her", "us", "them"],
213 |             "POSSESSIVE": ["my", "your", "his", "her", "its", "our", "your", "their"],
214 |             "REFLEXIVE": ["myself", "yourself", "himself", "itself", "ourselves", "yourselves", "themselves"]
215 |         }
216 |
217 |         self.pronoun2class_map = {pronoun: key for (key, values) in self.class2pronoun_map.items() for pronoun in values}
218 |         self.pronouns = {pronoun for (key, values) in self.class2pronoun_map.items() for pronoun in values}
219 |
220 |     def transform(self, example):
221 |         assert example["text"] is not None, "Text must be available"
222 |         assert example["claim"] is not None, "Claim must be available"
223 |
224 |         new_example = dict(example)
225 |         new_claim, aug_span = self.__swap_pronouns(new_example["claim"])
226 |
227 |         if new_claim:
228 |             new_example["claim"] = new_claim
229 |             new_example["label"] = LABEL_MAP[False]
230 |             new_example["augmentation"] = self.__class__.__name__
231 |             new_example["augmentation_span"] = aug_span
232 |             return new_example
233 |         else:
234 |             return None
235 |
236 |     def __swap_pronouns(self, claim):
237 |         # find pronouns
238 |         claim_pronouns = [token for token in claim if token.text.lower() in self.pronouns]
239 |
240 |         if not claim_pronouns:
241 |             return None, None
242 |
243 |         # find pronoun replacement
244 |         chosen_token = random.choice(claim_pronouns)
245 |         chosen_ix = chosen_token.i
246 |         chosen_class = self.pronoun2class_map[chosen_token.text.lower()]
247 |
248 |         candidate_tokens = [token for token in self.class2pronoun_map[chosen_class] if token != chosen_token.text.lower()]
249 |
250 |         if not candidate_tokens:
251 |             return None, None
252 |
253 |         # swap pronoun and update indices
254 |         swapped_token = random.choice(candidate_tokens)
255 |         swapped_token = align_ws(chosen_token.text_with_ws, swapped_token)
256 |         swapped_token = swapped_token if chosen_token.text.islower() else swapped_token.capitalize()
257 |
258 |         claim_tokens = [token.text_with_ws for token in claim]
259 |         claim_tokens[chosen_ix] = swapped_token
260 |
261 |         # create new claim object
262 |         new_claim = self.spacy("".join(claim_tokens))
263 |         augmentation_span = (chosen_ix, chosen_ix)
264 |
265 |         if claim.text == new_claim.text:
266 |             return None, None
267 |         else:
268 |             return new_claim, augmentation_span
269 |
270 |
271 | class NERSwap(Transformation):
272 |     # Swap NER objects - parent class
273 |     def __init__(self):
274 |         super().__init__()
275 |         self.categories = ()
276 |
277 |     def transform(self, example):
278 |         assert example["text"] is not None, "Text must be available"
279 |         assert example["claim"] is not None, "Claim must be available"
280 |
281 |         new_example = dict(example)
282 |         new_claim, aug_span = self.__swap_entities(new_example["text"], new_example["claim"])
283 |
284 |         if new_claim:
285 |             new_example["claim"] = new_claim
286 |             new_example["label"] = LABEL_MAP[False]
287 |             new_example["augmentation"] = self.__class__.__name__
288 |             new_example["augmentation_span"] = aug_span
289 |             return new_example
290 |         else:
291 |             return None
292 |
293 |     def __swap_entities(self, text, claim):
294 |         # find entities in given category
295 |         text_ents = [ent for ent in text.ents if ent.label_ in self.categories]
296 |         claim_ents = [ent for ent in claim.ents if ent.label_ in self.categories]
297 |
298 |         if not claim_ents or not text_ents:
299 |             return None, None
300 |
301 |         # choose entity to replace and find possible replacement in source
302 |         replaced_ent = random.choice(claim_ents)
303 |         candidate_ents = [ent for ent in text_ents if ent.text != replaced_ent.text and ent.text
not in replaced_ent.text and replaced_ent.text not in ent.text] 304 | 305 | if not candidate_ents: 306 | return None, None 307 | 308 | # update claim and indices 309 | swapped_ent = random.choice(candidate_ents) 310 | claim_tokens = [token.text_with_ws for token in claim] 311 | swapped_token = align_ws(replaced_ent.text_with_ws, swapped_ent.text_with_ws) 312 | claim_swapped = claim_tokens[:replaced_ent.start] + [swapped_token] + claim_tokens[replaced_ent.end:] 313 | 314 | # create new claim object 315 | new_claim = self.spacy("".join(claim_swapped)) 316 | augmentation_span = (replaced_ent.start, replaced_ent.start + len(swapped_ent) - 1) 317 | 318 | if new_claim.text == claim.text: 319 | return None, None 320 | else: 321 | return new_claim, augmentation_span 322 | 323 | 324 | class EntitySwap(NERSwap): 325 | # NER swapping class specialized for entities (people, companies, locations, etc.) 326 | def __init__(self): 327 | super().__init__() 328 | self.categories = ("PERSON", "ORG", "NORP", "FAC", "GPE", "LOC", "PRODUCT", 329 | "WORK_OF_ART", "EVENT") 330 | 331 | 332 | class NumberSwap(NERSwap): 333 | # NER swapping class specialized for numbers (excluding dates) 334 | def __init__(self): 335 | super().__init__() 336 | 337 | self.categories = ("PERCENT", "MONEY", "QUANTITY", "CARDINAL") 338 | 339 | 340 | class DateSwap(NERSwap): 341 | # NER swapping class specialized for dates and time 342 | def __init__(self): 343 | super().__init__() 344 | 345 | self.categories = ("DATE", "TIME") 346 | 347 | 348 | class AddNoise(Transformation): 349 | # Inject noise into claims 350 | def __init__(self, noise_prob=0.05, delete_prob=0.8): 351 | super().__init__() 352 | 353 | self.noise_prob = noise_prob 354 | self.delete_prob = delete_prob 355 | self.spacy = spacy.load("en") 356 | 357 | def transform(self, example): 358 | assert example["text"] is not None, "Text must be available" 359 | assert example["claim"] is not None, "Claim must be available" 360 | 361 | new_example = dict(example) 362 | claim = new_example["claim"] 363 | aug_span = new_example["augmentation_span"] 364 | new_claim, aug_span = self.__add_noise(claim, aug_span) 365 | 366 | if new_claim: 367 | new_example["claim"] = new_claim 368 | new_example["augmentation_span"] = aug_span 369 | new_example["noise"] = True 370 | return new_example 371 | else: 372 | return None 373 | 374 | def __add_noise(self, claim, aug_span): 375 | claim_tokens = [token.text_with_ws for token in claim] 376 | 377 | new_claim = [] 378 | for ix, token in enumerate(claim_tokens): 379 | # don't modify text inside an augmented span 380 | apply_augmentation = True 381 | if aug_span: 382 | span_start, span_end = aug_span 383 | if span_start <= ix <= span_end: 384 | apply_augmentation = False 385 | 386 | # decide whether to add noise 387 | if apply_augmentation and random.random() < self.noise_prob: 388 | # decide whether to replicate or delete token 389 | if random.random() < self.delete_prob: 390 | # update spans and skip token 391 | if aug_span: 392 | span_start, span_end = aug_span 393 | if ix < span_start: 394 | span_start -= 1 395 | span_end -= 1 396 | aug_span = span_start, span_end 397 | if len(new_claim) > 0: 398 | if new_claim[-1][-1] != " ": 399 | new_claim[-1] = new_claim[-1] + " " 400 | continue 401 | else: 402 | if aug_span: 403 | span_start, span_end = aug_span 404 | if ix < span_start: 405 | span_start += 1 406 | span_end += 1 407 | aug_span = span_start, span_end 408 | new_claim.append(token) 409 | new_claim.append(token) 410 | new_claim = 
self.spacy("".join(new_claim)) 411 | 412 | if claim.text == new_claim.text: 413 | return None, None 414 | else: 415 | return new_claim, aug_span 416 | -------------------------------------------------------------------------------- /data_generation/create_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for generating synthetic data for FactCC training. 3 | 4 | Script expects source documents in `jsonl` format with each source document 5 | embedded in a separate json object. 6 | 7 | Json objects are required to contain `id` and `text` keys. 8 | """ 9 | 10 | import argparse 11 | import json 12 | import os 13 | 14 | from tqdm import tqdm 15 | 16 | import augmentation_ops as ops 17 | 18 | 19 | 20 | def load_source_docs(file_path, to_dict=False): 21 | with open(file_path, encoding="utf-8") as f: 22 | data = [json.loads(line) for line in f] 23 | 24 | if to_dict: 25 | data = {example["id"]: example for example in data} 26 | return data 27 | 28 | 29 | def save_data(args, data, name_suffix): 30 | output_file = os.path.splitext(args.data_file)[0] + "-" + name_suffix + ".jsonl" 31 | 32 | with open(output_file, "w", encoding="utf-8") as fd: 33 | for example in data: 34 | example = dict(example) 35 | example["text"] = example["text"].text 36 | example["claim"] = example["claim"].text 37 | fd.write(json.dumps(example, ensure_ascii=False) + "\n") 38 | 39 | 40 | def apply_transformation(data, operation): 41 | new_data = [] 42 | for example in tqdm(data): 43 | try: 44 | new_example = operation.transform(example) 45 | if new_example: 46 | new_data.append(new_example) 47 | except Exception as e: 48 | print("Caught exception:", e) 49 | return new_data 50 | 51 | 52 | def main(args): 53 | # load data 54 | source_docs = load_source_docs(args.data_file, to_dict=False) 55 | print("Loaded %d source documents." % len(source_docs)) 56 | 57 | # create or load positive examples 58 | print("Creating data examples") 59 | sclaims_op = ops.SampleSentences() 60 | data = apply_transformation(source_docs, sclaims_op) 61 | print("Created %s example pairs." % len(data)) 62 | 63 | if args.save_intermediate: 64 | save_data(args, data, "clean") 65 | 66 | # backtranslate 67 | data_btrans = [] 68 | if not args.augmentations or "backtranslation" in args.augmentations: 69 | print("Creating backtranslation examples") 70 | btrans_op = ops.Backtranslation() 71 | data_btrans = apply_transformation(data, btrans_op) 72 | print("Backtranslated %s example pairs." % len(data_btrans)) 73 | 74 | if args.save_intermediate: 75 | save_data(args, data_btrans, "btrans") 76 | 77 | data_positive = data + data_btrans 78 | save_data(args, data_positive, "positive") 79 | 80 | # create negative examples 81 | data_pronoun = [] 82 | if not args.augmentations or "pronoun_swap" in args.augmentations: 83 | print("Creating pronoun examples") 84 | pronoun_op = ops.PronounSwap() 85 | data_pronoun = apply_transformation(data_positive, pronoun_op) 86 | print("PronounSwap %s example pairs." % len(data_pronoun)) 87 | 88 | if args.save_intermediate: 89 | save_data(args, data_pronoun, "pronoun") 90 | 91 | data_dateswp = [] 92 | if not args.augmentations or "date_swap" in args.augmentations: 93 | print("Creating date swap examples") 94 | dateswap_op = ops.DateSwap() 95 | data_dateswp = apply_transformation(data_positive, dateswap_op) 96 | print("DateSwap %s example pairs." 
% len(data_dateswp))
97 |
98 |         if args.save_intermediate:
99 |             save_data(args, data_dateswp, "dateswp")
100 |
101 |     data_numswp = []
102 |     if not args.augmentations or "number_swap" in args.augmentations:
103 |         print("Creating number swap examples")
104 |         numswap_op = ops.NumberSwap()
105 |         data_numswp = apply_transformation(data_positive, numswap_op)
106 |         print("NumberSwap %s example pairs." % len(data_numswp))
107 |
108 |         if args.save_intermediate:
109 |             save_data(args, data_numswp, "numswp")
110 |
111 |     data_entswp = []
112 |     if not args.augmentations or "entity_swap" in args.augmentations:
113 |         print("Creating entity swap examples")
114 |         entswap_op = ops.EntitySwap()
115 |         data_entswp = apply_transformation(data_positive, entswap_op)
116 |         print("EntitySwap %s example pairs." % len(data_entswp))
117 |
118 |         if args.save_intermediate:
119 |             save_data(args, data_entswp, "entswp")
120 |
121 |     data_negation = []
122 |     if not args.augmentations or "negation" in args.augmentations:
123 |         print("Creating negation examples")
124 |         negation_op = ops.NegateSentences()
125 |         data_negation = apply_transformation(data_positive, negation_op)
126 |         print("Negation %s example pairs." % len(data_negation))
127 |
128 |         if args.save_intermediate:
129 |             save_data(args, data_negation, "negation")
130 |
131 |     # add noise to all
132 |     data_negative = data_pronoun + data_dateswp + data_numswp + data_entswp + data_negation
133 |     save_data(args, data_negative, "negative")
134 |
135 |     # ADD NOISE
136 |     data_pos_low_noise = []
137 |     data_neg_low_noise = []
138 |
139 |     if not args.augmentations or "noise" in args.augmentations:
140 |         # add light noise
141 |         print("Adding light noise to data")
142 |         low_noise_op = ops.AddNoise()
143 |
144 |         data_pos_low_noise = apply_transformation(data_positive, low_noise_op)
145 |         print("PositiveNoisy %s example pairs." % len(data_pos_low_noise))
146 |         save_data(args, data_pos_low_noise, "positive-noise")
147 |
148 |         data_neg_low_noise = apply_transformation(data_negative, low_noise_op)
149 |         print("NegativeNoisy %s example pairs." % len(data_neg_low_noise))
150 |         save_data(args, data_neg_low_noise, "negative-noise")
151 |
152 |
153 | if __name__ == "__main__":
154 |     PARSER = argparse.ArgumentParser()
155 |     PARSER.add_argument("data_file", type=str, help="Path to file containing source documents.")
156 |     PARSER.add_argument("--augmentations", type=str, nargs="+", default=(), help="List of data augmentations to apply to the data.")
157 |     PARSER.add_argument("--all_augmentations", action="store_true", help="Flag indicating whether all augmentations should be applied.")
158 |     PARSER.add_argument("--save_intermediate", action="store_true", help="Flag indicating whether intermediate data from each transformation should be saved in separate files.")
159 |     ARGS = PARSER.parse_args()
160 |     main(ARGS)
161 |
--------------------------------------------------------------------------------
/data_pairing/pair_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for recreating the FactCC dataset from CNN/DM Story files.
3 |
4 | CNN/DM Story files can be downloaded from https://cs.nyu.edu/~kcho/DMQA/
5 |
6 | Unpaired FactCC data should be downloaded and unpacked.
7 | CNN/DM data should be stored in a `cnndm` directory with `cnn` and `dm` sub-directories.
8 | """
9 | import argparse
10 | import json
11 | import os
12 |
13 | from tqdm import tqdm
14 |
15 |
16 | def parse_story_file(content):
17 |     """
18 |     Remove article highlights and unnecessary whitespace characters.
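    For example, a story file containing "First line.\n\nSecond line.\n\n@highlight\n\nA summary point" is reduced to "First line. Second line.".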
19 |     """
20 |     content_raw = content.split("@highlight")[0]
21 |     content = " ".join(filter(None, [x.strip() for x in content_raw.split("\n")]))
22 |     return content
23 |
24 |
25 | def main(args):
26 |     """
27 |     Walk data sub-directories and recreate examples.
28 |     """
29 |     for path, _, filenames in os.walk(args.unpaired_data):
30 |         for filename in filenames:
31 |             if ".jsonl" not in filename:
32 |                 continue
33 |
34 |             unpaired_path = os.path.join(path, filename)
35 |             print("Processing file:", unpaired_path)
36 |
37 |             with open(unpaired_path) as fd:
38 |                 dataset = [json.loads(line) for line in fd]
39 |
40 |             for example in tqdm(dataset):
41 |                 story_path = os.path.join(args.story_files, example["filepath"])
42 |
43 |                 with open(story_path) as fd:
44 |                     story_content = fd.read()
45 |                     example["text"] = parse_story_file(story_content)
46 |
47 |             paired_path = unpaired_path.replace("unpaired_", "")
48 |             os.makedirs(os.path.dirname(paired_path), exist_ok=True)
49 |             with open(paired_path, "w") as fd:
50 |                 for example in dataset:
51 |                     fd.write(json.dumps(example, ensure_ascii=False) + "\n")
52 |
53 |
54 | if __name__ == "__main__":
55 |     PARSER = argparse.ArgumentParser()
56 |     PARSER.add_argument("unpaired_data", type=str, help="Path to directory holding unpaired data")
57 |     PARSER.add_argument("story_files", type=str, help="Path to directory holding CNNDM story files")
58 |     ARGS = PARSER.parse_args()
59 |     main(ARGS)
60 |
--------------------------------------------------------------------------------
/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/factCC/5daeba51dc63f9d3197df6945501a1153a77affd/model.jpg
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/factCC/5daeba51dc63f9d3197df6945501a1153a77affd/modeling/__init__.py
--------------------------------------------------------------------------------
/modeling/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Salesforce.com, Inc.
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
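#
# BertPointer below is the FactCCX model: a BERT encoder whose pooled output
# feeds a consistency-label classifier, and whose token-level outputs feed four
# pointer heads scoring the start/end of the supporting (extraction) span in
# the source and of the inconsistent (augmentation) span in the claim.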
16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import json 20 | import logging 21 | import math 22 | import os 23 | import sys 24 | from io import open 25 | 26 | import torch 27 | from torch import nn 28 | from torch.nn import CrossEntropyLoss, MSELoss 29 | 30 | 31 | from pytorch_transformers.modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, 32 | PreTrainedModel, prune_linear_layer, add_start_docstrings) 33 | from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel, BertLayer, BertPooler 34 | 35 | 36 | class BertPointer(BertPreTrainedModel): 37 | def __init__(self, config): 38 | super(BertPointer, self).__init__(config) 39 | self.num_labels = config.num_labels 40 | 41 | self.bert = BertModel(config) 42 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 43 | 44 | # classifiers 45 | self.ext_start_classifier = nn.Linear(config.hidden_size, 1, bias=False) 46 | self.ext_end_classifier = nn.Linear(config.hidden_size, 1, bias=False) 47 | self.aug_start_classifier = nn.Linear(config.hidden_size, 1, bias=False) 48 | self.aug_end_classifier = nn.Linear(config.hidden_size, 1, bias=False) 49 | 50 | self.label_classifier = nn.Linear(config.hidden_size, self.config.num_labels) 51 | 52 | self.apply(self.init_weights) 53 | 54 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, 55 | position_ids=None, head_mask=None, 56 | ext_mask=None, ext_start_labels=None, ext_end_labels=None, 57 | aug_mask=None, aug_start_labels=None, aug_end_labels=None, 58 | loss_lambda=1.0): 59 | # run through bert 60 | bert_outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 61 | attention_mask=attention_mask, head_mask=head_mask) 62 | 63 | # label classifier 64 | pooled_output = bert_outputs[1] 65 | pooled_output = self.dropout(pooled_output) 66 | label_logits = self.label_classifier(pooled_output) 67 | 68 | # extraction classifier 69 | output = bert_outputs[0] 70 | ext_mask = ext_mask.unsqueeze(-1) 71 | ext_start_logits = self.ext_start_classifier(output) * ext_mask 72 | ext_end_logits = self.ext_end_classifier(output) * ext_mask 73 | 74 | # augmentation classifier 75 | output = bert_outputs[0] 76 | aug_mask = aug_mask.unsqueeze(-1) 77 | aug_start_logits = self.aug_start_classifier(output) * aug_mask 78 | aug_end_logits = self.aug_end_classifier(output) * aug_mask 79 | 80 | span_logits = (ext_start_logits, ext_end_logits, aug_start_logits, aug_end_logits,) 81 | outputs = (label_logits,) + span_logits + bert_outputs[2:] 82 | 83 | if labels is not None and \ 84 | ext_start_labels is not None and ext_end_labels is not None and \ 85 | aug_start_labels is not None and aug_end_labels is not None: 86 | if self.num_labels == 1: 87 | # We are doing regression 88 | loss_fct = MSELoss() 89 | loss = loss_fct(label_logits.view(-1), labels.view(-1)) 90 | else: 91 | loss_fct = CrossEntropyLoss() 92 | 93 | # label loss 94 | labels_loss = loss_fct(label_logits.view(-1, self.num_labels), labels.view(-1)) 95 | 96 | # extraction loss 97 | ext_start_loss = loss_fct(ext_start_logits.squeeze(), ext_start_labels) 98 | ext_end_loss = loss_fct(ext_end_logits.squeeze(), ext_end_labels) 99 | 100 | # augmentation loss 101 | aug_start_loss = loss_fct(aug_start_logits.squeeze(), aug_start_labels) 102 | aug_end_loss = loss_fct(aug_end_logits.squeeze(), aug_end_labels) 103 | 104 | span_loss = (ext_start_loss + ext_end_loss + aug_start_loss + aug_end_loss) / 4 105 | 106 | # combined loss 107 | 
loss = labels_loss + loss_lambda * span_loss 108 | 109 | outputs = (loss, labels_loss, span_loss, ext_start_loss, ext_end_loss, aug_start_loss, aug_end_loss) + outputs 110 | 111 | return outputs # (loss), (logits), (hidden_states), (attentions) 112 | -------------------------------------------------------------------------------- /modeling/run.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020, Salesforce.com, Inc. 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet).""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import glob 22 | import logging 23 | import os 24 | import random 25 | 26 | import wandb 27 | import numpy as np 28 | import torch 29 | 30 | from model import BertPointer 31 | from utils import (compute_metrics, convert_examples_to_features, output_modes, processors) 32 | 33 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset) 34 | from torch.utils.data.distributed import DistributedSampler 35 | from tqdm import tqdm, trange 36 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig, BertForSequenceClassification, BertTokenizer) 37 | 38 | from pytorch_transformers import AdamW, WarmupLinearSchedule 39 | 40 | logger = logging.getLogger(__name__) 41 | wandb.init(project="entailment-metric") 42 | 43 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), ()) 44 | 45 | MODEL_CLASSES = { 46 | 'pbert': (BertConfig, BertPointer, BertTokenizer), 47 | 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer), 48 | } 49 | 50 | 51 | def set_seed(args): 52 | random.seed(args.seed) 53 | np.random.seed(args.seed) 54 | torch.manual_seed(args.seed) 55 | if args.n_gpu > 0: 56 | torch.cuda.manual_seed_all(args.seed) 57 | 58 | 59 | def make_model_input(args, batch): 60 | inputs = {'input_ids': batch[0], 61 | 'attention_mask': batch[1], 62 | 'token_type_ids': batch[2], 63 | 'labels': batch[3]} 64 | 65 | # add extraction and augmentation spans for PointerBert model 66 | if args.model_type == "pbert": 67 | inputs["ext_mask"] = batch[4] 68 | inputs["ext_start_labels"] = batch[5] 69 | inputs["ext_end_labels"] = batch[6] 70 | inputs["aug_mask"] = batch[7] 71 | inputs["aug_start_labels"] = batch[8] 72 | inputs["aug_end_labels"] = batch[9] 73 | inputs["loss_lambda"] = args.loss_lambda 74 | 75 | return inputs 76 | 77 | 78 | def train(args, train_dataset, model, tokenizer): 79 | """ Train the model """ 80 | 81 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 82 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 83 | train_dataloader 
= DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 84 | 85 | if args.max_steps > 0: 86 | t_total = args.max_steps 87 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 88 | else: 89 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 90 | 91 | # Prepare optimizer and schedule (linear warmup and decay) 92 | no_decay = ['bias', 'LayerNorm.weight'] 93 | optimizer_grouped_parameters = [ 94 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, 95 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 96 | ] 97 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 98 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 99 | if args.fp16: 100 | try: 101 | from apex import amp 102 | except ImportError: 103 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 104 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 105 | 106 | # multi-gpu training (should be after apex fp16 initialization) 107 | if args.n_gpu > 1: 108 | model = torch.nn.DataParallel(model) 109 | 110 | # Distributed training (should be after apex fp16 initialization) 111 | if args.local_rank != -1: 112 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 113 | output_device=args.local_rank, 114 | find_unused_parameters=True) 115 | 116 | # Train! 117 | logger.info("***** Running training *****") 118 | logger.info(" Num examples = %d", len(train_dataset)) 119 | logger.info(" Num Epochs = %d", args.num_train_epochs) 120 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 121 | logger.info(" Total train batch size (w. 
parallel, distributed & accumulation) = %d",
122 |                 args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
123 |     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
124 |     logger.info("  Total optimization steps = %d", t_total)
125 |
126 |     global_step = 0
127 |     tr_loss, logging_loss = 0.0, 0.0
128 |     model.zero_grad()
129 |     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
130 |     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
131 |
132 |     for epoch_ix in train_iterator:
133 |         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
134 |         for step, batch in enumerate(epoch_iterator):
135 |             model.train()
136 |             batch = tuple(t.to(args.device) for t in batch)
137 |             inputs = make_model_input(args, batch)
138 |             outputs = model(**inputs)
139 |             loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
140 |
141 |             if args.n_gpu > 1:
142 |                 loss = loss.mean() # mean() to average on multi-gpu parallel training
143 |             if args.gradient_accumulation_steps > 1:
144 |                 loss = loss / args.gradient_accumulation_steps
145 |
146 |             if args.fp16:
147 |                 with amp.scale_loss(loss, optimizer) as scaled_loss:
148 |                     scaled_loss.backward()
149 |                 torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
150 |             else:
151 |                 loss.backward()
152 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
153 |
154 |             tr_loss += loss.item()
155 |             if (step + 1) % args.gradient_accumulation_steps == 0:
156 |                 optimizer.step()
157 |                 scheduler.step()  # Update learning rate schedule after the optimizer step
158 |                 model.zero_grad()
159 |                 global_step += 1
160 |
161 |                 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
162 |                     results = {}
163 |
164 |                     # Log metrics
165 |                     logits_ix = 1 if args.model_type == "bert" else 7
166 |                     logits = outputs[logits_ix]
167 |                     preds = np.argmax(logits.detach().cpu().numpy(), axis=1)
168 |                     out_label_ids = inputs['labels'].detach().cpu().numpy()
169 |
170 |                     result = compute_metrics(args.task_name, preds, out_label_ids)
171 |                     results.update(result)
172 |
173 |                     for key, value in results.items():
174 |                         wandb.log({'train_{}'.format(key): value})
175 |
176 |                     wandb.log({"train_loss": (tr_loss - logging_loss) / args.logging_steps})
177 |                     wandb.log({"train_lr": scheduler.get_lr()[0]})
178 |                     logging_loss = tr_loss
179 |
180 |             if args.max_steps > 0 and global_step > args.max_steps:
181 |                 epoch_iterator.close()
182 |                 break
183 |
184 |         if args.evaluate_during_training:
185 |             # Only evaluate when single GPU otherwise metrics may not average well
186 |             results = evaluate(args, model, tokenizer)
187 |             for key, value in results.items():
188 |                 wandb.log({'eval_{}'.format(key): value})
189 |
190 |         if args.local_rank in [-1, 0]:
191 |             # Save model checkpoint
192 |             output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(epoch_ix))
193 |             if not os.path.exists(output_dir):
194 |                 os.makedirs(output_dir)
195 |             model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
196 |             model_to_save.save_pretrained(output_dir)
197 |             torch.save(args, os.path.join(output_dir, 'training_args.bin'))
198 |             logger.info("Saving model checkpoint to %s", output_dir)
199 |
200 |         if args.max_steps > 0 and global_step > args.max_steps:
201 |             train_iterator.close()
202 |             break
203 |
204 |     return global_step, tr_loss / global_step
205 |
206 |
207 | def evaluate(args, model, tokenizer, prefix=""):
208 |     # Loop to handle MNLI double evaluation (matched, mis-matched)
209 |     eval_task_names = (args.task_name,)
210 |     eval_outputs_dirs = (args.output_dir,)
211 |
212 |     results = {}
213 |     cnt = 0
214 |     for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
215 |         eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
216 |
217 |         if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
218 |             os.makedirs(eval_output_dir)
219 |
220 |         args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
221 |         # Note that DistributedSampler samples randomly
222 |         eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
223 |         eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
224 |
225 |         # Eval!
226 |         logger.info("***** Running evaluation {} *****".format(prefix))
227 |         logger.info("  Num examples = %d", len(eval_dataset))
228 |         logger.info("  Batch size = %d", args.eval_batch_size)
229 |
230 |         eval_loss = 0.0
231 |         nb_eval_steps = 0
232 |         preds = None
233 |         out_label_ids = None
234 |
235 |         for batch in tqdm(eval_dataloader, desc="Evaluating"):
236 |             model.eval()
237 |             batch = tuple(t.to(args.device) for t in batch)
238 |
239 |             with torch.no_grad():
240 |                 inputs = make_model_input(args, batch)
241 |                 outputs = model(**inputs)
242 |
243 |                 # monitoring
244 |                 tmp_eval_loss = outputs[0]
245 |                 logits_ix = 1 if args.model_type == "bert" else 7
246 |                 logits = outputs[logits_ix]
247 |                 eval_loss += tmp_eval_loss.mean().item()
248 |             nb_eval_steps += 1
249 |
250 |             if preds is None:
251 |                 preds = logits.detach().cpu().numpy()
252 |                 out_label_ids = inputs['labels'].detach().cpu().numpy()
253 |             else:
254 |                 preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
255 |                 out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
256 |
257 |         preds = np.argmax(preds, axis=1)
258 |         result = compute_metrics(args.task_name, preds, out_label_ids)
259 |         eval_loss = eval_loss / nb_eval_steps
260 |         result["loss"] = eval_loss
261 |         results.update(result)
262 |
263 |         output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
264 |         with open(output_eval_file, "w") as writer:
265 |             logger.info("***** Eval results {} *****".format(prefix))
266 |             for key in sorted(results.keys()):
267 |                 logger.info("  %s = %s", key, str(results[key]))
268 |                 writer.write("%s = %s\n" % (key, str(results[key])))
269 |
270 |     return results
271 |
272 |
273 | def load_and_cache_examples(args, task, tokenizer, evaluate=False):
274 |     if args.local_rank not in [-1, 0]:
275 |         torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
276 |
277 |     processor = processors[task]()
278 |     output_mode = output_modes[task]
279 |     # Load data features from cache or dataset file
280 |     cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
281 |         'dev' if evaluate else 'train',
282 |         list(filter(None, args.model_name_or_path.split('/'))).pop(),
283 |         str(args.max_seq_length),
284 |         str(task)))
285 |     if os.path.exists(cached_features_file) and not args.overwrite_cache:
286 |         logger.info("Loading features from cached file %s", cached_features_file)
287 |         features = torch.load(cached_features_file)
288 |     else:
289 |         logger.info("Creating features from dataset file at %s", args.data_dir)
290 |         label_list = processor.get_labels()
291 |         examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
292 |         features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
293 |             cls_token_at_end=bool(args.model_type in ['xlnet']),  # xlnet has a cls token at the end
294 |             cls_token=tokenizer.cls_token,
295 |             cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
296 |             sep_token=tokenizer.sep_token,
297 |             sep_token_extra=bool(args.model_type in ['roberta']),
298 |             pad_on_left=bool(args.model_type in ['xlnet']),  # pad on the left for xlnet
299 |             pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
300 |             pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
301 |         if args.local_rank in [-1, 0]:
302 |             logger.info("Saving features into cached file %s", cached_features_file)
303 |             torch.save(features, cached_features_file)
304 |
305 |     if args.local_rank == 0:
306 |         torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
307 |
308 |     # Convert to Tensors and build dataset
309 |     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
310 |     all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
311 |     all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
312 |     all_ext_mask = torch.tensor([f.extraction_mask for f in features], dtype=torch.float)
313 |     all_ext_start_ids = torch.tensor([f.extraction_start_ids for f in features], dtype=torch.long)
314 |     all_ext_end_ids = torch.tensor([f.extraction_end_ids for f in features], dtype=torch.long)
315 |     all_aug_mask = torch.tensor([f.augmentation_mask for f in features], dtype=torch.float)
316 |     all_aug_start_ids = torch.tensor([f.augmentation_start_ids for f in features], dtype=torch.long)
317 |     all_aug_end_ids = torch.tensor([f.augmentation_end_ids for f in features], dtype=torch.long)
318 |
319 |     if output_mode == "classification":
320 |         all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
321 |     elif output_mode == "regression":
322 |         all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)
323 |
324 |     dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,
325 |                             all_ext_mask, all_ext_start_ids, all_ext_end_ids,
326 |                             all_aug_mask, all_aug_start_ids, all_aug_end_ids)
327 |     return dataset
328 |
329 |
330 | def main():
331 |     parser = argparse.ArgumentParser()
332 |
333 |     ## Required parameters
334 |     parser.add_argument("--data_dir", default=None, type=str, required=True,
335 |                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
336 |     parser.add_argument("--model_type", default=None, type=str, required=True,
337 |                         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
338 |     parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
339 |                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
340 |     parser.add_argument("--task_name", default=None, type=str, required=True,
341 |                         help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
342 |     parser.add_argument("--output_dir", default=None, type=str, required=True,
343 |                         help="The output directory where the model predictions and checkpoints will be written.")
344 |     parser.add_argument("--train_from_scratch", action='store_true',
345 |                         help="Whether to run training without loading pretrained weights.")
346 |
347 |     ## Other parameters
348 |     parser.add_argument("--config_name", default="", type=str,
349 |                         help="Pretrained config name or path if not the same as model_name")
350 |     parser.add_argument("--tokenizer_name", default="", type=str,
351 |                         help="Pretrained tokenizer name or path if not the same as model_name")
352 |     parser.add_argument("--cache_dir", default="", type=str,
353 |                         help="Where do you want to store the pre-trained models downloaded from s3")
354 |     parser.add_argument("--max_seq_length", default=512, type=int,
355 |                         help="The maximum total input sequence length after tokenization. Sequences longer "
356 |                              "than this will be truncated, sequences shorter will be padded.")
357 |     parser.add_argument("--do_train", action='store_true',
358 |                         help="Whether to run training.")
359 |     parser.add_argument("--do_eval", action='store_true',
360 |                         help="Whether to run eval on the dev set.")
361 |     parser.add_argument("--evaluate_during_training", action='store_true',
362 |                         help="Run evaluation during training at each logging step.")
363 |     parser.add_argument("--do_lower_case", action='store_true',
364 |                         help="Set this flag if you are using an uncased model.")
365 |
366 |     parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
367 |                         help="Batch size per GPU/CPU for training.")
368 |     parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
369 |                         help="Batch size per GPU/CPU for evaluation.")
370 |     parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
371 |                         help="Number of update steps to accumulate before performing a backward/update pass.")
372 |     parser.add_argument("--learning_rate", default=5e-5, type=float,
373 |                         help="The initial learning rate for Adam.")
374 |     parser.add_argument("--loss_lambda", default=0.1, type=float,
375 |                         help="The lambda parameter for loss mixing.")
376 |     parser.add_argument("--weight_decay", default=0.0, type=float,
377 |                         help="Weight decay if we apply some.")
378 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float,
379 |                         help="Epsilon for Adam optimizer.")
380 |     parser.add_argument("--max_grad_norm", default=1.0, type=float,
381 |                         help="Max gradient norm.")
382 |     parser.add_argument("--num_train_epochs", default=3.0, type=float,
383 |                         help="Total number of training epochs to perform.")
384 |     parser.add_argument("--max_steps", default=-1, type=int,
385 |                         help="If > 0: set total number of training steps to perform. Overrides num_train_epochs.")
386 | parser.add_argument("--warmup_steps", default=0, type=int,
387 | help="Linear warmup over warmup_steps.")
388 | 
389 | parser.add_argument('--logging_steps', type=int, default=100,
390 | help="Log every X update steps.")
391 | parser.add_argument('--save_steps', type=int, default=50,
392 | help="Save checkpoint every X update steps.")
393 | parser.add_argument("--eval_all_checkpoints", action='store_true',
394 | help="Evaluate all checkpoints starting with the same prefix as model_name and ending with a step number")
395 | parser.add_argument("--no_cuda", action='store_true',
396 | help="Avoid using CUDA when available")
397 | parser.add_argument('--overwrite_output_dir', action='store_true',
398 | help="Overwrite the content of the output directory")
399 | parser.add_argument('--overwrite_cache', action='store_true',
400 | help="Overwrite the cached training and evaluation sets")
401 | parser.add_argument('--seed', type=int, default=42,
402 | help="Random seed for initialization")
403 | 
404 | parser.add_argument('--fp16', action='store_true',
405 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
406 | parser.add_argument('--fp16_opt_level', type=str, default='O1',
407 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']. "
408 | "See details at https://nvidia.github.io/apex/amp.html")
409 | parser.add_argument("--local_rank", type=int, default=-1,
410 | help="For distributed training: local_rank")
411 | args = parser.parse_args()
412 | 
413 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
414 | raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overwrite it.".format(args.output_dir))
415 | 
416 | # Setup CUDA, GPU & distributed training
417 | if args.local_rank == -1 or args.no_cuda:
418 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
419 | args.n_gpu = torch.cuda.device_count()
420 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
421 | torch.cuda.set_device(args.local_rank)
422 | device = torch.device("cuda", args.local_rank)
423 | torch.distributed.init_process_group(backend='nccl')
424 | args.n_gpu = 1
425 | args.device = device
426 | 
427 | # Setup logging
428 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
429 | datefmt = '%m/%d/%Y %H:%M:%S',
430 | level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
431 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
432 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
433 | 
434 | # Set seed
435 | set_seed(args)
436 | 
437 | # Prepare GLUE task
438 | args.task_name = args.task_name.lower()
439 | if args.task_name not in processors:
440 | raise ValueError("Task not found: %s" % (args.task_name))
441 | processor = processors[args.task_name]()
442 | args.output_mode = output_modes[args.task_name]
443 | label_list = processor.get_labels()
444 | num_labels = len(label_list)
445 | 
446 | # Load pretrained model and tokenizer
447 | if args.local_rank not in [-1, 0]:
448 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
449 | 
450 | args.model_type = args.model_type.lower()
451 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
452 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
453 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
454 | if args.train_from_scratch:
455 | logger.info("Training model from scratch.")
456 | model = model_class(config=config)
457 | else:
458 | logger.info("Loading model from checkpoint.")
459 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
460 | 
461 | if args.local_rank == 0:
462 | torch.distributed.barrier() # End of barrier: the first process has downloaded the model & vocab, the remaining processes can now load them from the cache
463 | 
464 | model.to(args.device)
465 | 
466 | wandb.watch(model)
467 | logger.info("Training/evaluation parameters %s", args)
468 | 
469 | 
470 | # Training
471 | if args.do_train:
472 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
473 | global_step, tr_loss = train(args, train_dataset, model, tokenizer)
474 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
475 | 
476 | 
477 | # Saving best-practices: if you use default names for the model, you can reload it using from_pretrained()
478 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
479 | # Create output directory if needed
480 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
481 | os.makedirs(args.output_dir)
482 | 
483 | logger.info("Saving model checkpoint to %s", args.output_dir)
484 | # Save a trained model, configuration and tokenizer using `save_pretrained()`.
485 | # They can then be reloaded using `from_pretrained()`
486 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
487 | model_to_save.save_pretrained(args.output_dir)
488 | tokenizer.save_pretrained(args.output_dir)
489 | 
490 | # Good practice: save your training arguments together with the trained model
491 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
492 | 
493 | # Load a trained model and vocabulary that you have fine-tuned
494 | model = model_class.from_pretrained(args.output_dir)
495 | tokenizer = tokenizer_class.from_pretrained(args.output_dir)
496 | model.to(args.device)
497 | 
498 | 
499 | # Evaluation
500 | results = {}
501 | if args.do_eval and args.local_rank in [-1, 0]:
502 | checkpoints = [args.output_dir]
503 | if args.eval_all_checkpoints:
504 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
505 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
506 | logger.info("Evaluate the following checkpoints: %s", checkpoints)
507 | for checkpoint in checkpoints:
508 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
509 | model = model_class.from_pretrained(checkpoint)
510 | model.to(args.device)
511 | result = evaluate(args, model, tokenizer, prefix=global_step)
512 | result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
513 | results.update(result)
514 | 
515 | return results
516 | 
517 | 
518 | if __name__ == "__main__":
519 | main()
520 | 
--------------------------------------------------------------------------------
/modeling/scripts/factcc-eval.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Evaluate FactCC model
3 | 
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export CKPT_PATH= # absolute path to model checkpoint
8 | 
9 | export TASK_NAME=factcc_annotated
10 | export MODEL_NAME=bert-base-uncased
11 | 
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_eval \
15 | --eval_all_checkpoints \
16 | --do_lower_case \
17 | --overwrite_cache \
18 | --max_seq_length 512 \
19 | --per_gpu_eval_batch_size 12 \
20 | --model_type bert \
21 | --model_name_or_path $MODEL_NAME \
22 | --data_dir $DATA_PATH \
23 | --output_dir $CKPT_PATH
24 | 
--------------------------------------------------------------------------------
/modeling/scripts/factcc-finetune.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Fine-tune FactCC model
3 | 
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export OUTPUT_PATH= # absolute path to output directory
8 | 
9 | export TASK_NAME=factcc_generated
10 | export MODEL_NAME=bert-base-uncased
11 | 
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_train \
15 | --do_eval \
16 | --evaluate_during_training \
17 | --do_lower_case \
18 | --max_seq_length 512 \
19 | --per_gpu_train_batch_size 12 \
20 | --learning_rate 2e-5 \
21 | --num_train_epochs 10.0 \
22 | --data_dir $DATA_PATH \
23 | --model_type bert \
24 | --model_name_or_path $MODEL_NAME \
25 | --output_dir $OUTPUT_PATH/$MODEL_NAME-$TASK_NAME-finetune-$RANDOM/
26 | 
--------------------------------------------------------------------------------
/modeling/scripts/factcc-train.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Train FactCC model
3 | 
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export OUTPUT_PATH= # absolute path to output directory
8 | 
9 | export TASK_NAME=factcc_generated
10 | export MODEL_NAME=bert-base-uncased
11 | 
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_train \
15 | --do_eval \
16 | --do_lower_case \
17 | --train_from_scratch \
18 | --data_dir $DATA_PATH \
19 | --model_type bert \
20 | --model_name_or_path $MODEL_NAME \
21 | --max_seq_length 512 \
22 | --per_gpu_train_batch_size 12 \
23 | --learning_rate 2e-5 \
24 | --num_train_epochs 20.0 \
25 | --evaluate_during_training \
26 | --eval_all_checkpoints \
27 | --overwrite_cache \
28 | --output_dir $OUTPUT_PATH/$MODEL_NAME-$TASK_NAME-train-$RANDOM/
29 | 
--------------------------------------------------------------------------------
/modeling/scripts/factccx-eval.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Evaluate FactCCX model
3 | 
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export CKPT_PATH= # absolute path to model checkpoint
8 | 
9 | export TASK_NAME=factcc_annotated
10 | export MODEL_NAME=bert-base-uncased
11 | 
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_eval \
15 | --eval_all_checkpoints \
16 | --do_lower_case \
17 | --max_seq_length 512 \
18 | --per_gpu_eval_batch_size 12 \
19 | --model_type pbert \
20 | --model_name_or_path $MODEL_NAME \
21 | --data_dir $DATA_PATH \
22 | --output_dir $CKPT_PATH
23 | 
24 | 
--------------------------------------------------------------------------------
/modeling/scripts/factccx-finetune.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Fine-tune FactCCX model
3 | 
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export OUTPUT_PATH= # absolute path to output directory
8 | 
9 | export TASK_NAME=factcc_generated
10 | export MODEL_NAME=bert-base-uncased
11 | 
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_train \
15 | --do_eval \
16 | --do_lower_case \
17 | --data_dir $DATA_PATH \
18 | --model_type pbert \
19 | --model_name_or_path $MODEL_NAME \
20 | --max_seq_length 512 \
21 | --per_gpu_train_batch_size 12 \
22 | --learning_rate 2e-5 \
23 | --num_train_epochs 10.0 \
24 | --loss_lambda 0.1 \
25 | --evaluate_during_training \
26 | --eval_all_checkpoints \
27 | --overwrite_cache \
28 | --output_dir $OUTPUT_PATH/$MODEL_NAME-$TASK_NAME-finetune-$RANDOM/
29 | 
--------------------------------------------------------------------------------
/modeling/scripts/factccx-train.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Train FactCCX model
3 | 
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export OUTPUT_PATH= # absolute path to output directory
8 | 
9 | export TASK_NAME=factcc_generated
10 | export MODEL_NAME=bert-base-uncased
11 | 
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_train \
15 | --do_eval \
16 | --do_lower_case \
17 | --train_from_scratch \
18 | --data_dir $DATA_PATH \
19 | --model_type pbert \
20 | --model_name_or_path $MODEL_NAME \
21 | --max_seq_length 512 \
22 | --per_gpu_train_batch_size 12 \
23 | --learning_rate 2e-5 \
24 | --num_train_epochs 20.0 \
25 | --evaluate_during_training \
26 | --eval_all_checkpoints \
27 | --overwrite_cache \
28 | --output_dir $OUTPUT_PATH/$MODEL_NAME-$TASK_NAME-train-$RANDOM/
29 | 
--------------------------------------------------------------------------------
/modeling/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Salesforce.com, Inc.
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT classification fine-tuning: utilities to work with GLUE tasks """ 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import csv 21 | import logging 22 | import json 23 | import os 24 | import sys 25 | from io import open 26 | 27 | from scipy.stats import pearsonr, spearmanr 28 | from sklearn.metrics import f1_score, balanced_accuracy_score, accuracy_score 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | class InputExample(object): 34 | """A single training/test example for simple sequence classification.""" 35 | 36 | def __init__(self, guid, text_a, text_b=None, label=None, 37 | extraction_span=None, augmentation_span=None): 38 | """Constructs a InputExample. 39 | 40 | Args: 41 | guid: Unique id for the example. 42 | text_a: string. The untokenized text of the first sequence. For single 43 | sequence tasks, only this sequence must be specified. 44 | text_b: (Optional) string. The untokenized text of the second sequence. 45 | Only must be specified for sequence pair tasks. 46 | label: (Optional) string. The label of the example. This should be 47 | specified for train and dev examples, but not for test examples. 48 | """ 49 | self.guid = guid 50 | self.text_a = text_a 51 | self.text_b = text_b 52 | self.label = label 53 | self.extraction_span = extraction_span 54 | self.augmentation_span = augmentation_span 55 | 56 | 57 | class InputFeatures(object): 58 | """A single set of features of data.""" 59 | 60 | def __init__(self, input_ids, input_mask, segment_ids, label_id, 61 | extraction_mask=None, extraction_start_ids=None, extraction_end_ids=None, 62 | augmentation_mask=None, augmentation_start_ids=None, augmentation_end_ids=None): 63 | self.input_ids = input_ids 64 | self.input_mask = input_mask 65 | self.segment_ids = segment_ids 66 | self.label_id = label_id 67 | self.extraction_mask = extraction_mask 68 | self.extraction_start_ids = extraction_start_ids 69 | self.extraction_end_ids = extraction_end_ids 70 | self.augmentation_mask = augmentation_mask 71 | self.augmentation_start_ids = augmentation_start_ids 72 | self.augmentation_end_ids = augmentation_end_ids 73 | 74 | 75 | class DataProcessor(object): 76 | """Base class for data converters for sequence classification data sets.""" 77 | 78 | def get_train_examples(self, data_dir): 79 | """Gets a collection of `InputExample`s for the train set.""" 80 | raise NotImplementedError() 81 | 82 | def get_dev_examples(self, data_dir): 83 | """Gets a collection of `InputExample`s for the dev set.""" 84 | raise NotImplementedError() 85 | 86 | def get_labels(self): 87 | """Gets the list of labels for this data set.""" 88 | raise NotImplementedError() 89 | 90 | @classmethod 91 | def _read_tsv(cls, input_file, quotechar=None): 92 | """Reads a tab separated value file.""" 93 | with open(input_file, "r", encoding="utf-8-sig") as f: 94 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 95 | lines = [] 96 | for line in reader: 97 | if sys.version_info[0] == 2: 98 | line = list(unicode(cell, 'utf-8') for cell in line) 99 | lines.append(line) 100 | return lines 101 | 102 | @classmethod 103 | def _read_json(cls, input_file): 104 | """Reads a jsonl file.""" 105 | with open(input_file, "r", encoding="utf-8") as f: 106 | lines = [] 107 | for line in f: 108 | lines.append(json.loads(line)) 109 | return lines 110 | 111 | 112 | class FactCCGeneratedProcessor(DataProcessor): 113 | """Processor for the generated FactCC data set.""" 114 | 115 | def get_train_examples(self, data_dir): 116 | """See 
base class.""" 117 | return self._create_examples( 118 | self._read_json(os.path.join(data_dir, "data-train.jsonl")), "train") 119 | 120 | def get_dev_examples(self, data_dir): 121 | """See base class.""" 122 | return self._create_examples( 123 | self._read_json(os.path.join(data_dir, "data-dev.jsonl")), "dev") 124 | 125 | def get_labels(self): 126 | """See base class.""" 127 | return ["CORRECT", "INCORRECT"] 128 | 129 | def _create_examples(self, lines, set_type): 130 | """Creates examples for the training and dev sets.""" 131 | examples = [] 132 | for example in lines: 133 | guid = example["id"] 134 | text_a = example["text"] 135 | text_b = example["claim"] 136 | label = example["label"] 137 | extraction_span = example["extraction_span"] 138 | augmentation_span = example["augmentation_span"] 139 | 140 | examples.append( 141 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, 142 | extraction_span=extraction_span, augmentation_span=augmentation_span)) 143 | return examples 144 | 145 | 146 | class FactCCManualProcessor(DataProcessor): 147 | """Processor for the WNLI data set (GLUE version).""" 148 | 149 | def get_train_examples(self, data_dir): 150 | """See base class.""" 151 | return self._create_examples( 152 | self._read_json(os.path.join(data_dir, "data-train.jsonl")), "train") 153 | 154 | def get_dev_examples(self, data_dir): 155 | """See base class.""" 156 | return self._create_examples( 157 | self._read_json(os.path.join(data_dir, "data-dev.jsonl")), "dev") 158 | 159 | def get_labels(self): 160 | """See base class.""" 161 | return ["CORRECT", "INCORRECT"] 162 | 163 | def _create_examples(self, lines, set_type): 164 | """Creates examples for the training and dev sets.""" 165 | examples = [] 166 | for (i, example) in enumerate(lines): 167 | guid = str(i) 168 | text_a = example["text"] 169 | text_b = example["claim"] 170 | label = example["label"] 171 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 172 | return examples 173 | 174 | 175 | def convert_examples_to_features(examples, label_list, max_seq_length, 176 | tokenizer, output_mode, 177 | cls_token_at_end=False, 178 | cls_token='[CLS]', 179 | cls_token_segment_id=1, 180 | sep_token='[SEP]', 181 | sep_token_extra=False, 182 | pad_on_left=False, 183 | pad_token=0, 184 | pad_token_segment_id=0, 185 | sequence_a_segment_id=0, 186 | sequence_b_segment_id=1, 187 | mask_padding_with_zero=True): 188 | """ Loads a data file into a list of `InputBatch`s 189 | `cls_token_at_end` define the location of the CLS token: 190 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 191 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 192 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet) 193 | """ 194 | 195 | label_map = {label : i for i, label in enumerate(label_list)} 196 | 197 | features = [] 198 | for (ex_index, example) in enumerate(examples): 199 | if ex_index % 10000 == 0: 200 | logger.info("Writing example %d of %d" % (ex_index, len(examples))) 201 | 202 | tokens_a = tokenizer.tokenize(example.text_a) 203 | 204 | tokens_b = None 205 | if example.text_b: 206 | tokens_b = tokenizer.tokenize(example.text_b) 207 | # Modifies `tokens_a` and `tokens_b` in place so that the total 208 | # length is less than the specified length. 209 | # Account for [CLS], [SEP], [SEP] with "- 3". " -4" for RoBERTa. 
210 | special_tokens_count = 4 if sep_token_extra else 3
211 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
212 | else:
213 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
214 | special_tokens_count = 3 if sep_token_extra else 2
215 | if len(tokens_a) > max_seq_length - special_tokens_count:
216 | tokens_a = tokens_a[:(max_seq_length - special_tokens_count)]
217 | 
218 | # The convention in BERT is:
219 | # (a) For sequence pairs:
220 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
221 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
222 | # (b) For single sequences:
223 | # tokens: [CLS] the dog is hairy . [SEP]
224 | # type_ids: 0 0 0 0 0 0 0
225 | #
226 | # Where "type_ids" are used to indicate whether this is the first
227 | # sequence or the second sequence. The embedding vectors for `type=0` and
228 | # `type=1` were learned during pre-training and are added to the wordpiece
229 | # embedding vector (and position vector). This is not *strictly* necessary
230 | # since the [SEP] token unambiguously separates the sequences, but it makes
231 | # it easier for the model to learn the concept of sequences.
232 | #
233 | # For classification tasks, the first vector (corresponding to [CLS]) is
234 | # used as the "sentence vector". Note that this only makes sense because
235 | # the entire model is fine-tuned.
236 | tokens = tokens_a + [sep_token]
237 | if sep_token_extra:
238 | # roberta uses an extra separator b/w pairs of sentences
239 | tokens += [sep_token]
240 | segment_ids = [sequence_a_segment_id] * len(tokens)
241 | 
242 | if tokens_b:
243 | tokens += tokens_b + [sep_token]
244 | segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
245 | 
246 | if cls_token_at_end:
247 | tokens = tokens + [cls_token]
248 | segment_ids = segment_ids + [cls_token_segment_id]
249 | else:
250 | tokens = [cls_token] + tokens
251 | segment_ids = [cls_token_segment_id] + segment_ids
252 | 
253 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
254 | 
255 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
256 | # tokens are attended to.
257 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
258 | 
259 | ####### AUX LOSS DATA
260 | # get tokens_a mask
261 | extraction_span_len = len(tokens_a) + 2
262 | extraction_mask = [1 if 0 < ix < extraction_span_len else 0 for ix in range(max_seq_length)]
263 | 
264 | # get extraction labels
265 | if example.extraction_span:
266 | ext_start, ext_end = example.extraction_span
267 | extraction_start_ids = ext_start + 1
268 | extraction_end_ids = ext_end + 1
269 | else:
270 | extraction_start_ids = extraction_span_len
271 | extraction_end_ids = extraction_span_len
272 | 
273 | augmentation_mask = [1 if extraction_span_len <= ix < extraction_span_len + len(tokens_b) + 1 else 0 for ix in range(max_seq_length)]
274 | 
275 | if example.augmentation_span:
276 | aug_start, aug_end = example.augmentation_span
277 | augmentation_start_ids = extraction_span_len + aug_start
278 | augmentation_end_ids = extraction_span_len + aug_end
279 | else:
280 | last_sep_token = extraction_span_len + len(tokens_b)
281 | augmentation_start_ids = last_sep_token
282 | augmentation_end_ids = last_sep_token
283 | 
284 | # Zero-pad up to the sequence length.
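# Note: extraction_mask and augmentation_mask were already built over the full
# max_seq_length range above (extraction spans are offset by +1 for the leading
# [CLS], augmentation spans by extraction_span_len = len(tokens_a) + 2), so only
# input_ids, input_mask, and segment_ids need explicit padding below.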
285 | padding_length = max_seq_length - len(input_ids)
286 | if pad_on_left:
287 | input_ids = ([pad_token] * padding_length) + input_ids
288 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
289 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
290 | else:
291 | input_ids = input_ids + ([pad_token] * padding_length)
292 | input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
293 | segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
294 | 
295 | assert len(input_ids) == max_seq_length
296 | assert len(input_mask) == max_seq_length
297 | assert len(segment_ids) == max_seq_length
298 | 
299 | if output_mode == "classification":
300 | label_id = label_map[example.label]
301 | elif output_mode == "regression":
302 | label_id = float(example.label)
303 | else:
304 | raise KeyError(output_mode)
305 | 
306 | if ex_index < 3:
307 | logger.info("*** Example ***")
308 | logger.info("guid: %s" % (example.guid))
309 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
310 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
311 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
312 | logger.info("ext mask: %s" % " ".join([str(x) for x in extraction_mask]))
313 | logger.info("ext start: %d" % extraction_start_ids)
314 | logger.info("ext end: %d" % extraction_end_ids)
315 | logger.info("aug mask: %s" % " ".join([str(x) for x in augmentation_mask]))
316 | logger.info("aug start: %d" % augmentation_start_ids)
317 | logger.info("aug end: %d" % augmentation_end_ids)
318 | logger.info("label: %d" % label_id)
319 | 
320 | extraction_start_ids = min(extraction_start_ids, max_seq_length - 1) # clamp span targets to the last valid index
321 | extraction_end_ids = min(extraction_end_ids, max_seq_length - 1)
322 | augmentation_start_ids = min(augmentation_start_ids, max_seq_length - 1)
323 | augmentation_end_ids = min(augmentation_end_ids, max_seq_length - 1)
324 | 
325 | features.append(
326 | InputFeatures(input_ids=input_ids,
327 | input_mask=input_mask,
328 | segment_ids=segment_ids,
329 | label_id=label_id,
330 | extraction_mask=extraction_mask,
331 | extraction_start_ids=extraction_start_ids,
332 | extraction_end_ids=extraction_end_ids,
333 | augmentation_mask=augmentation_mask,
334 | augmentation_start_ids=augmentation_start_ids,
335 | augmentation_end_ids=augmentation_end_ids))
336 | return features
337 | 
338 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
339 | """Truncates a sequence pair in place to the maximum length."""
340 | 
341 | # This is a simple heuristic which will always truncate the longer sequence
342 | # one token at a time. This makes more sense than truncating an equal percent
343 | # of tokens from each, since if one sequence is very short then each token
344 | # that's truncated likely contains more information than a longer sequence.
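# For example, with max_length 10, a 9-token tokens_a, and a 5-token tokens_b,
# four tokens are popped from tokens_a (14 -> 10 total), leaving a 5/5 split.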
345 | while True: 346 | total_length = len(tokens_a) + len(tokens_b) 347 | if total_length <= max_length: 348 | break 349 | if len(tokens_a) > len(tokens_b): 350 | tokens_a.pop() 351 | else: 352 | tokens_b.pop() 353 | 354 | 355 | def simple_accuracy(preds, labels): 356 | return (preds == labels).mean() 357 | 358 | 359 | def acc_and_f1(preds, labels): 360 | acc = simple_accuracy(preds, labels) 361 | f1 = f1_score(y_true=labels, y_pred=preds) 362 | return { 363 | "acc": acc, 364 | "f1": f1, 365 | "acc_and_f1": (acc + f1) / 2, 366 | } 367 | 368 | 369 | def pearson_and_spearman(preds, labels): 370 | pearson_corr = pearsonr(preds, labels)[0] 371 | spearman_corr = spearmanr(preds, labels)[0] 372 | return { 373 | "pearson": pearson_corr, 374 | "spearmanr": spearman_corr, 375 | "corr": (pearson_corr + spearman_corr) / 2, 376 | } 377 | 378 | 379 | def complex_metric(preds, labels, prefix=""): 380 | return { 381 | prefix + "bacc": balanced_accuracy_score(y_true=labels, y_pred=preds), 382 | prefix + "f1": f1_score(y_true=labels, y_pred=preds, average="micro") 383 | } 384 | 385 | 386 | def compute_metrics(task_name, preds, labels, prefix=""): 387 | assert len(preds) == len(labels) 388 | if task_name == "factcc_generated": 389 | return complex_metric(preds, labels, prefix) 390 | elif task_name == "factcc_annotated": 391 | return complex_metric(preds, labels, prefix) 392 | else: 393 | raise KeyError(task_name) 394 | 395 | processors = { 396 | "factcc_generated": FactCCGeneratedProcessor, 397 | "factcc_annotated": FactCCManualProcessor, 398 | } 399 | 400 | output_modes = { 401 | "factcc_generated": "classification", 402 | "factcc_annotated": "classification", 403 | } 404 | 405 | GLUE_TASKS_NUM_LABELS = { 406 | "factcc_generated": 2, 407 | "factcc_annotated": 2, 408 | } 409 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch_transformers==1.0.0 2 | scipy==1.3.0 3 | spacy==2.2.3 4 | torch==1.3.1 5 | numpy==1.17.4 6 | tqdm==4.32.1 7 | apex==0.9.10dev 8 | protobuf==3.11.2 9 | scikit_learn==0.22.1 10 | google-cloud-translate==1.6.0 11 | --------------------------------------------------------------------------------
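--------------------------------------------------------------------------------
Appendix: minimal inference sketch (editorial addition, not a repository file)
--------------------------------------------------------------------------------
For readers wiring the pieces together, the following is a minimal, hypothetical sketch of how a FactCC classification checkpoint saved by run.py via save_pretrained() could be loaded and applied to a single (document, claim) pair. It assumes the checkpoint was trained with --model_type bert and that this maps to a plain BertForSequenceClassification head (run.py's MODEL_CLASSES is not shown above, so this is an assumption); the checkpoint path is a placeholder, and the truncation to max_seq_length performed by convert_examples_to_features is omitted for brevity. Label indices follow the processors above: 0 = CORRECT, 1 = INCORRECT.

import torch
from pytorch_transformers import BertTokenizer, BertForSequenceClassification

CKPT_DIR = "/path/to/factcc-checkpoint"  # placeholder: directory written by save_pretrained()

tokenizer = BertTokenizer.from_pretrained(CKPT_DIR)
model = BertForSequenceClassification.from_pretrained(CKPT_DIR)  # assumption: plain BERT classifier head
model.eval()

document = "The quick brown fox jumped over the lazy dog."
claim = "The fox jumped over the dog."

# Mirror convert_examples_to_features: [CLS] document [SEP] claim [SEP],
# with segment id 0 for the document half and 1 for the claim half.
tokens_a = tokenizer.tokenize(document)
tokens_b = tokenizer.tokenize(claim)
tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
input_ids = tokenizer.convert_tokens_to_ids(tokens)

with torch.no_grad():
    logits = model(torch.tensor([input_ids]),
                   token_type_ids=torch.tensor([segment_ids]))[0]

label_list = ["CORRECT", "INCORRECT"]  # matches FactCCGeneratedProcessor.get_labels()
print(label_list[logits.argmax(dim=-1).item()])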