├── .gitignore
├── CODEOWNERS
├── CONTRIBUTING-ARCHIVED.md
├── LICENSE.txt
├── README.md
├── data_generation
│   ├── __init__.py
│   ├── augmentation_ops.py
│   └── create_data.py
├── data_pairing
│   └── pair_data.py
├── model.jpg
├── modeling
│   ├── __init__.py
│   ├── model.py
│   ├── run.py
│   ├── scripts
│   │   ├── factcc-eval.sh
│   │   ├── factcc-finetune.sh
│   │   ├── factcc-train.sh
│   │   ├── factccx-eval.sh
│   │   ├── factccx-finetune.sh
│   │   └── factccx-train.sh
│   └── utils.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | # *.ipynb_checkpoints
77 | # */*.ipynb/*
78 | # **/*.ipynb/*
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
127 | ### MacOS
128 | .DS_Store
129 | #
130 | #
131 | **/gen_data/*
132 | # Other files
133 | .word_vectors_cache/
134 | **/wandb/*
135 | */wandb/*
136 |
137 | **/apex/*
138 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing.
2 | #ECCN:Open Source
3 |
--------------------------------------------------------------------------------
/CONTRIBUTING-ARCHIVED.md:
--------------------------------------------------------------------------------
1 | # ARCHIVED
2 |
3 | This project is `Archived` and is no longer actively maintained;
4 | we are not accepting contributions or Pull Requests.
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019, Salesforce.com, Inc.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Evaluating the Factual Consistency of Abstractive Text Summarization
2 | Authors: [Wojciech Kryściński](https://twitter.com/iam_wkr), [Bryan McCann](https://bmccann.github.io/), [Caiming Xiong](http://www.stat.ucla.edu/~caiming/), and [Richard Socher](https://www.socher.org/)
3 |
4 | ## Introduction
5 | Currently used metrics for assessing summarization algorithms do not account for whether summaries are factually consistent with source documents.
6 | We propose a weakly-supervised, model-based approach for verifying factual consistency and identifying conflicts between source documents and a generated summary.
7 | Training data is generated by applying a series of rule-based transformations to the sentences of source documents.
8 | The factual consistency model is then trained jointly for three tasks:
9 | 1) identify whether sentences remain factually consistent after transformation,
10 | 2) extract a span in the source documents to support the consistency prediction,
11 | 3) extract a span in the summary sentence that is inconsistent if one exists.
12 | Transferring this model to summaries generated by several state-of-the-art models reveals that this highly scalable approach substantially outperforms previous models,
13 | including those trained with strong supervision using standard datasets for natural language inference and fact checking.
14 | Additionally, human evaluation shows that the auxiliary span extraction tasks provide useful assistance in the process of verifying factual consistency.
15 |
16 | Paper link: https://arxiv.org/abs/1910.12840
17 |
18 |
19 |
20 |
21 | ## Table of Contents
22 |
23 | 1. [Updates](#updates)
24 | 2. [Citation](#citation)
25 | 3. [License](#license)
26 | 4. [Usage](#usage)
27 | 5. [Get Involved](#get-involved)
28 |
29 | ## Updates
30 | #### 1/27/2020
31 | Updated manually annotated data files - fixed `filepaths` in misaligned examples.
32 |
33 | Updated model checkpoint files - recomputed evaluation metrics for fixed examples.
34 |
35 | ## Citation
36 | ```
37 | @article{kryscinskiFactCC2019,
38 | author = {Wojciech Kry{\'s}ci{\'n}ski and Bryan McCann and Caiming Xiong and Richard Socher},
39 | title = {Evaluating the Factual Consistency of Abstractive Text Summarization},
40 | journal = {arXiv preprint arXiv:1910.12840},
41 | year = {2019},
42 | }
43 | ```
44 |
45 |
46 | ## License
47 | The code is released under the BSD-3 License (see `LICENSE.txt` for details), but we also ask that users respect the following:
48 |
49 | This software should not be used to promote or profit from violence, hate, and division, environmental destruction, abuse of human rights,
50 | or the destruction of people's physical and mental health.
51 |
52 |
53 | ## Usage
54 | The code in this repository uses `Python 3`.
55 | Prior to running any scripts, please make sure to install the required Python packages listed in the `requirements.txt` file.
56 |
57 | Example call:
58 | `pip3 install -r requirements.txt`
59 |
60 | ### Training and Evaluation Datasets
61 | Generated training data can be found [here](https://storage.googleapis.com/sfr-factcc-data-research/unpaired_generated_data.tar.gz).
62 |
63 | Manually annotated validation and test data can be found [here](https://storage.googleapis.com/sfr-factcc-data-research/unpaired_annotated_data.tar.gz).
64 |
65 | Both generated and manually annotated datasets require pairing with the original CNN/DailyMail articles.
66 |
67 | To recreate the datasets, follow these steps:
68 | 1. Download CNN Stories and Daily Mail Stories from https://cs.nyu.edu/~kcho/DMQA/
69 | 2. Create a `cnndm` directory and unpack the downloaded files into it
70 | 3. Download and unpack the FactCC data _(do not rename the directory)_
71 | 4. Run the `pair_data.py` script to pair the data with the original articles
72 |
73 | Example call:
74 |
75 | `python3 data_pairing/pair_data.py <unpaired-data-dir> <story-files-dir>`
76 |
77 | ### Generating Data
78 |
79 | Synthetic training data can be generated using code available in the `data_generation` directory.
80 |
81 | The data generation script expects the source documents as a single `jsonl` file, with each source document embedded in a separate json object.
82 | Each json object is required to contain an `id` key storing an example id (uniqueness is not required) and a `text` key storing the text of the source document.
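
For illustration, a single input line could look as follows (the `id` and `text` values below are made up):

```
{"id": "cnndm-0001", "text": "John Smith visited Berlin on Tuesday. He met with city officials to discuss plans for a new factory."}
```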
83 |
84 | Certain transformations rely on NER tagging; for best results, use source documents with their original (proper) casing.
85 |
86 |
87 | The following claim augmentations (transformations) are available:
88 | - `backtranslation` - Paraphrasing claim via backtranslation (requires Google Translate API key; costs apply)
89 | - `pronoun_swap` - Swapping a random pronoun in the claim
90 | - `date_swap` - Swapping a random date/time found in the claim with one present in the source article
91 | - `number_swap` - Swapping a random number found in the claim with one present in the source article
92 | - `entity_swap` - Swapping a random entity name found in the claim with one present in the source article
93 | - `negation` - Negating the meaning of the claim
94 | - `noise` - Injecting noise into the claim sentence
95 |
96 | For a detailed description of the available transformations, please refer to `Section 3.1` of the paper.
97 |
98 | To authenticate with the Google Cloud API follow [these](https://cloud.google.com/docs/authentication/getting-started) instructions.
99 |
100 | Example call:
101 |
102 | `python3 data_generation/create_data.py <source-documents-jsonl> [--augmentations list-of-augmentations]`
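
For instance, to apply only the pronoun swap and negation transformations to a (hypothetical) `source_docs.jsonl` file and keep intermediate outputs:

`python3 data_generation/create_data.py source_docs.jsonl --augmentations pronoun_swap negation --save_intermediate`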
103 |
104 | ### Model Code
105 |
106 | `FactCC` and `FactCCX` models can be trained or initialized from a checkpoint using code available in the `modeling` directory.
107 |
108 | Quickstart training, fine-tuning, and evaluation scripts are shared in the `scripts` directory.
109 | Before use, make sure to update the `*_PATH` variables with appropriate absolute paths.
110 |
111 | To customize training or evaluation settings please refer to the flags in the `run.py` file.
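
As a rough sketch, a manual evaluation run could look as follows (paths are placeholders, and the `--task_name` value must match one of the processors defined in `modeling/utils.py`):

```
python3 modeling/run.py \
  --task_name <task-name> \
  --do_eval \
  --do_lower_case \
  --model_type bert \
  --model_name_or_path /path/to/factcc-checkpoint \
  --data_dir /path/to/eval-data \
  --max_seq_length 512 \
  --output_dir /path/to/output-dir
```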
112 |
113 | To utilize Weights & Biases dashboards, log in to the service using the following command: `wandb login`.
114 |
115 | Trained `FactCC` model checkpoint can be found [here](https://storage.googleapis.com/sfr-factcc-data-research/factcc-checkpoint.tar.gz).
116 |
117 | Trained `FactCCX` model checkpoint can be found [here](https://storage.googleapis.com/sfr-factcc-data-research/factccx-checkpoint.tar.gz).
118 |
119 | *IMPORTANT:* Due to data pre-processing, the first run of training or evaluation code on a large dataset can take up to a few hours before the actual procedure starts.
120 |
121 | #### Running on other data
122 | To run pretrained `FactCC` or `FactCCX` models on your data, follow these steps:
123 | 1. Download a pre-trained model checkpoint, linked above
124 | 2. Prepare your data in `jsonl` format. Each example should be a separate `json` object with `id`, `text`, and `claim` keys
125 | representing the example id, source document, and claim sentence, respectively. Name the file `data-dev.jsonl` (see the example below)
126 | 3. Update the corresponding `*-eval.sh` script
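
A minimal `data-dev.jsonl` line could look as follows (all values are made up):

```
{"id": "example-0", "text": "John Smith visited Berlin on Tuesday to meet city officials.", "claim": "John Smith visited Paris on Tuesday."}
```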
127 |
128 |
129 | ## Get Involved
130 |
131 | Please create a GitHub issue if you have any questions, suggestions, requests, or bug reports.
132 | We welcome PRs!
133 |
--------------------------------------------------------------------------------
/data_generation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/factCC/5daeba51dc63f9d3197df6945501a1153a77affd/data_generation/__init__.py
--------------------------------------------------------------------------------
/data_generation/augmentation_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | Data augmentation (transformations) operations used to generate
3 | synthetic training data for the `FactCC` and `FactCCX` models.
4 | """
5 |
6 | import random
7 |
8 | import spacy
9 |
10 | from google.cloud import translate
11 |
12 |
13 | LABEL_MAP = {True: "CORRECT", False: "INCORRECT"}
14 |
15 |
16 | def align_ws(old_token, new_token):
17 | # Align trailing whitespaces between tokens
18 | if old_token[-1] == new_token[-1] == " ":
19 | return new_token
20 | elif old_token[-1] == " ":
21 | return new_token + " "
22 | elif new_token[-1] == " ":
23 | return new_token[:-1]
24 | else:
25 | return new_token
26 |
27 |
28 | def make_new_example(eid=None, text=None, claim=None, label=None, extraction_span=None,
29 | backtranslation=None, augmentation=None, augmentation_span=None, noise=None):
30 | # Embed example information in a json object.
31 | return {
32 | "id": eid,
33 | "text": text,
34 | "claim": claim,
35 | "label": label,
36 | "extraction_span": extraction_span,
37 | "backtranslation": backtranslation,
38 | "augmentation": augmentation,
39 | "augmentation_span": augmentation_span,
40 | "noise": noise
41 | }
42 |
43 |
44 | class Transformation():
45 | # Base class for all data transformations
46 |
47 | def __init__(self):
48 | # Spacy toolkit used for all NLP-related substeps
49 | self.spacy = spacy.load("en")
50 |
51 | def transform(self, example):
52 | # Function applies transformation on passed example
53 | pass
54 |
55 |
56 | class SampleSentences(Transformation):
57 | # Embed document as Spacy object and sample one sentence as claim
58 | def __init__(self, min_sent_len=8):
59 | super().__init__()
60 | self.min_sent_len = min_sent_len
61 |
62 | def transform(self, example):
63 | assert example["text"] is not None, "Text must be available"
64 |
65 | # split into sentences
66 | page_id = example["id"]
67 | page_text = example["text"].replace("\n", " ")
68 | page_doc = self.spacy(page_text, disable=["tagger"])
69 | sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]
70 |
71 | # sample claim
72 | claim = random.choice(sents)
73 | new_example = make_new_example(eid=page_id, text=page_doc,
74 | claim=self.spacy(claim.text),
75 | label=LABEL_MAP[True],
76 | extraction_span=(claim.start, claim.end-1),
77 | backtranslation=False, noise=False)
78 | return new_example
79 |
80 |
81 | class NegateSentences(Transformation):
82 | # Apply or remove negation from negatable tokens
83 | def __init__(self):
84 | super().__init__()
85 | self.__negatable_tokens = ("are", "is", "was", "were", "have", "has", "had",
86 | "do", "does", "did", "can", "ca", "could", "may",
87 | "might", "must", "shall", "should", "will", "would")
88 |
89 | def transform(self, example):
90 | assert example["text"] is not None, "Text must be available"
91 | assert example["claim"] is not None, "Claim must be available"
92 |
93 | new_example = dict(example)
94 | new_claim, aug_span = self.__negate_sentences(new_example["claim"])
95 |
96 | if new_claim:
97 | new_example["claim"] = new_claim
98 | new_example["label"] = LABEL_MAP[False]
99 | new_example["augmentation"] = self.__class__.__name__
100 | new_example["augmentation_span"] = aug_span
101 | return new_example
102 | else:
103 | return None
104 |
105 | def __negate_sentences(self, claim):
106 |         # find negatable token, return None if no candidates found
107 | candidate_tokens = [token for token in claim if token.text in self.__negatable_tokens]
108 |
109 | if not candidate_tokens:
110 | return None, None
111 |
112 | # choose random token to negate
113 | negated_token = random.choice(candidate_tokens)
114 | negated_ix = negated_token.i
115 | doc_len = len(claim)
116 |
117 | if negated_ix > 0:
118 | if claim[negated_ix - 1].text in self.__negatable_tokens:
119 | negated_token = claim[negated_ix - 1]
120 | negated_ix = negated_ix - 1
121 |
122 |         # check whether the auxiliary token is already negated
123 | is_negative = False
124 | if (doc_len - 1) > negated_ix:
125 | if claim[negated_ix + 1].text in ["not", "n't"]:
126 | is_negative = True
127 | elif claim[negated_ix + 1].text == "no":
128 | return None, None
129 |
130 |
131 | # negate token
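        # two cases: if the auxiliary is already negated with "n't", remove the negation token;
        # otherwise insert "not "/"n't " directly after the chosen auxiliary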
132 | claim_tokens = [token.text_with_ws for token in claim]
133 | if is_negative:
134 | if claim[negated_ix + 1].text.lower() == "n't":
135 |                 if claim[negated_ix].text.lower() == "ca":  # "ca" (from a tokenized "can't") must expand back to "can"
136 | claim_tokens[negated_ix] = "can" if claim_tokens[negated_ix].islower() else "Can"
137 | claim_tokens[negated_ix] = claim_tokens[negated_ix] + " "
138 | claim_tokens.pop(negated_ix + 1)
139 | else:
140 | if claim[negated_ix].text.lower() in ["am", "may", "might", "must", "shall", "will"]:
141 | negation = "not "
142 | else:
143 | negation = random.choice(["not ", "n't "])
144 |
145 | if negation == "n't ":
146 | if claim[negated_ix].text.lower() == "can":
147 | claim_tokens[negated_ix] = "ca" if claim_tokens[negated_ix].islower() else "Ca"
148 | else:
149 | claim_tokens[negated_ix] = claim_tokens[negated_ix][:-1]
150 | claim_tokens.insert(negated_ix + 1, negation)
151 |
152 | # create new claim object
153 | new_claim = self.spacy("".join(claim_tokens))
154 | augmentation_span = (negated_ix, negated_ix if is_negative else negated_ix + 1)
155 |
156 | if new_claim.text == claim.text:
157 | return None, None
158 | else:
159 | return new_claim, augmentation_span
160 |
161 |
162 | class Backtranslation(Transformation):
163 | # Paraphrase sentence via backtranslation with Google Translate API
164 | # Requires API Key for Google Cloud SDK, additional charges DO apply
165 | def __init__(self, dst_lang=None):
166 | super().__init__()
167 |
168 | self.src_lang = "en"
169 | self.dst_lang = dst_lang
170 | self.accepted_langs = ["fr", "de", "zh-TW", "es", "ru"]
171 | self.translator = translate.Client()
172 |
173 | def transform(self, example):
174 | assert example["text"] is not None, "Text must be available"
175 | assert example["claim"] is not None, "Claim must be available"
176 |
177 | new_example = dict(example)
178 | new_claim, _ = self.__backtranslate(new_example["claim"])
179 |
180 | if new_claim:
181 | new_example["claim"] = new_claim
182 | new_example["backtranslation"] = True
183 | return new_example
184 | else:
185 | return None
186 |
187 | def __backtranslate(self, claim):
188 |         # choose destination language: the one passed in, or a random one from the accepted list
189 | dst_lang = self.dst_lang if self.dst_lang else random.choice(self.accepted_langs)
190 |
191 | # translate to intermediate language and back
192 | claim_trans = self.translator.translate(claim.text, target_language=dst_lang, format_="text")
193 | claim_btrans = self.translator.translate(claim_trans["translatedText"], target_language=self.src_lang, format_="text")
194 |
195 | # create new claim object
196 | new_claim = self.spacy(claim_btrans["translatedText"])
197 | augmentation_span = (new_claim[0].i, new_claim[-1].i)
198 |
199 | if claim.text == new_claim.text:
200 | return None, None
201 | else:
202 | return new_claim, augmentation_span
203 |
204 |
205 | class PronounSwap(Transformation):
206 | # Swap randomly chosen pronoun
207 | def __init__(self, prob_swap=0.5):
208 | super().__init__()
209 |
210 | self.class2pronoun_map = {
211 | "SUBJECT": ["you", "he", "she", "we", "they"],
212 | "OBJECT": ["me", "you", "him", "her", "us", "them"],
213 |             "POSSESSIVE": ["my", "your", "his", "her", "its", "our", "your", "their"],
214 |             "REFLEXIVE": ["myself", "yourself", "himself", "herself", "itself", "ourselves", "yourselves", "themselves"]
215 | }
216 |
217 | self.pronoun2class_map = {pronoun: key for (key, values) in self.class2pronoun_map.items() for pronoun in values}
218 | self.pronouns = {pronoun for (key, values) in self.class2pronoun_map.items() for pronoun in values}
219 |
220 | def transform(self, example):
221 | assert example["text"] is not None, "Text must be available"
222 | assert example["claim"] is not None, "Claim must be available"
223 |
224 | new_example = dict(example)
225 | new_claim, aug_span = self.__swap_pronouns(new_example["claim"])
226 |
227 | if new_claim:
228 | new_example["claim"] = new_claim
229 | new_example["label"] = LABEL_MAP[False]
230 | new_example["augmentation"] = self.__class__.__name__
231 | new_example["augmentation_span"] = aug_span
232 | return new_example
233 | else:
234 | return None
235 |
236 | def __swap_pronouns(self, claim):
237 | # find pronouns
238 | claim_pronouns = [token for token in claim if token.text.lower() in self.pronouns]
239 |
240 | if not claim_pronouns:
241 | return None, None
242 |
243 | # find pronoun replacement
244 | chosen_token = random.choice(claim_pronouns)
245 | chosen_ix = chosen_token.i
246 | chosen_class = self.pronoun2class_map[chosen_token.text.lower()]
247 |
248 | candidate_tokens = [token for token in self.class2pronoun_map[chosen_class] if token != chosen_token.text.lower()]
249 |
250 | if not candidate_tokens:
251 | return None, None
252 |
253 | # swap pronoun and update indices
254 | swapped_token = random.choice(candidate_tokens)
255 | swapped_token = align_ws(chosen_token.text_with_ws, swapped_token)
256 | swapped_token = swapped_token if chosen_token.text.islower() else swapped_token.capitalize()
257 |
258 | claim_tokens = [token.text_with_ws for token in claim]
259 | claim_tokens[chosen_ix] = swapped_token
260 |
261 | # create new claim object
262 | new_claim = self.spacy("".join(claim_tokens))
263 | augmentation_span = (chosen_ix, chosen_ix)
264 |
265 | if claim.text == new_claim.text:
266 | return None, None
267 | else:
268 | return new_claim, augmentation_span
269 |
270 |
271 | class NERSwap(Transformation):
272 | # Swap NER objects - parent class
273 | def __init__(self):
274 | super().__init__()
275 | self.categories = ()
276 |
277 | def transform(self, example):
278 | assert example["text"] is not None, "Text must be available"
279 | assert example["claim"] is not None, "Claim must be available"
280 |
281 | new_example = dict(example)
282 | new_claim, aug_span = self.__swap_entities(new_example["text"], new_example["claim"])
283 |
284 | if new_claim:
285 | new_example["claim"] = new_claim
286 | new_example["label"] = LABEL_MAP[False]
287 | new_example["augmentation"] = self.__class__.__name__
288 | new_example["augmentation_span"] = aug_span
289 | return new_example
290 | else:
291 | return None
292 |
293 | def __swap_entities(self, text, claim):
294 | # find entities in given category
295 | text_ents = [ent for ent in text.ents if ent.label_ in self.categories]
296 | claim_ents = [ent for ent in claim.ents if ent.label_ in self.categories]
297 |
298 | if not claim_ents or not text_ents:
299 | return None, None
300 |
301 | # choose entity to replace and find possible replacement in source
302 | replaced_ent = random.choice(claim_ents)
303 | candidate_ents = [ent for ent in text_ents if ent.text != replaced_ent.text and ent.text not in replaced_ent.text and replaced_ent.text not in ent.text]
304 |
305 | if not candidate_ents:
306 | return None, None
307 |
308 | # update claim and indices
309 | swapped_ent = random.choice(candidate_ents)
310 | claim_tokens = [token.text_with_ws for token in claim]
311 | swapped_token = align_ws(replaced_ent.text_with_ws, swapped_ent.text_with_ws)
312 | claim_swapped = claim_tokens[:replaced_ent.start] + [swapped_token] + claim_tokens[replaced_ent.end:]
313 |
314 | # create new claim object
315 | new_claim = self.spacy("".join(claim_swapped))
316 | augmentation_span = (replaced_ent.start, replaced_ent.start + len(swapped_ent) - 1)
317 |
318 | if new_claim.text == claim.text:
319 | return None, None
320 | else:
321 | return new_claim, augmentation_span
322 |
323 |
324 | class EntitySwap(NERSwap):
325 | # NER swapping class specialized for entities (people, companies, locations, etc.)
326 | def __init__(self):
327 | super().__init__()
328 | self.categories = ("PERSON", "ORG", "NORP", "FAC", "GPE", "LOC", "PRODUCT",
329 | "WORK_OF_ART", "EVENT")
330 |
331 |
332 | class NumberSwap(NERSwap):
333 | # NER swapping class specialized for numbers (excluding dates)
334 | def __init__(self):
335 | super().__init__()
336 |
337 | self.categories = ("PERCENT", "MONEY", "QUANTITY", "CARDINAL")
338 |
339 |
340 | class DateSwap(NERSwap):
341 | # NER swapping class specialized for dates and time
342 | def __init__(self):
343 | super().__init__()
344 |
345 | self.categories = ("DATE", "TIME")
346 |
347 |
348 | class AddNoise(Transformation):
349 | # Inject noise into claims
350 | def __init__(self, noise_prob=0.05, delete_prob=0.8):
351 | super().__init__()
352 |
353 | self.noise_prob = noise_prob
354 | self.delete_prob = delete_prob
355 | self.spacy = spacy.load("en")
356 |
357 | def transform(self, example):
358 | assert example["text"] is not None, "Text must be available"
359 | assert example["claim"] is not None, "Claim must be available"
360 |
361 | new_example = dict(example)
362 | claim = new_example["claim"]
363 | aug_span = new_example["augmentation_span"]
364 | new_claim, aug_span = self.__add_noise(claim, aug_span)
365 |
366 | if new_claim:
367 | new_example["claim"] = new_claim
368 | new_example["augmentation_span"] = aug_span
369 | new_example["noise"] = True
370 | return new_example
371 | else:
372 | return None
373 |
374 | def __add_noise(self, claim, aug_span):
375 | claim_tokens = [token.text_with_ws for token in claim]
376 |
377 | new_claim = []
378 | for ix, token in enumerate(claim_tokens):
379 | # don't modify text inside an augmented span
380 | apply_augmentation = True
381 | if aug_span:
382 | span_start, span_end = aug_span
383 | if span_start <= ix <= span_end:
384 | apply_augmentation = False
385 |
386 | # decide whether to add noise
387 | if apply_augmentation and random.random() < self.noise_prob:
388 | # decide whether to replicate or delete token
389 | if random.random() < self.delete_prob:
390 | # update spans and skip token
391 | if aug_span:
392 | span_start, span_end = aug_span
393 | if ix < span_start:
394 | span_start -= 1
395 | span_end -= 1
396 | aug_span = span_start, span_end
397 | if len(new_claim) > 0:
398 | if new_claim[-1][-1] != " ":
399 | new_claim[-1] = new_claim[-1] + " "
400 | continue
401 | else:
402 | if aug_span:
403 | span_start, span_end = aug_span
404 | if ix < span_start:
405 | span_start += 1
406 | span_end += 1
407 | aug_span = span_start, span_end
408 | new_claim.append(token)
409 | new_claim.append(token)
410 | new_claim = self.spacy("".join(new_claim))
411 |
412 | if claim.text == new_claim.text:
413 | return None, None
414 | else:
415 | return new_claim, aug_span
416 |
--------------------------------------------------------------------------------
/data_generation/create_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for generating synthetic data for FactCC training.
3 |
4 | Script expects source documents in `jsonl` format with each source document
5 | embedded in a separate json object.
6 |
7 | Json objects are required to contain `id` and `text` keys.
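
Example call (file name is a placeholder):
    python3 data_generation/create_data.py source_docs.jsonl --augmentations pronoun_swap negation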
8 | """
9 |
10 | import argparse
11 | import json
12 | import os
13 |
14 | from tqdm import tqdm
15 |
16 | import augmentation_ops as ops
17 |
18 |
19 |
20 | def load_source_docs(file_path, to_dict=False):
21 | with open(file_path, encoding="utf-8") as f:
22 | data = [json.loads(line) for line in f]
23 |
24 | if to_dict:
25 | data = {example["id"]: example for example in data}
26 | return data
27 |
28 |
29 | def save_data(args, data, name_suffix):
30 | output_file = os.path.splitext(args.data_file)[0] + "-" + name_suffix + ".jsonl"
31 |
32 | with open(output_file, "w", encoding="utf-8") as fd:
33 | for example in data:
34 | example = dict(example)
35 | example["text"] = example["text"].text
36 | example["claim"] = example["claim"].text
37 | fd.write(json.dumps(example, ensure_ascii=False) + "\n")
38 |
39 |
40 | def apply_transformation(data, operation):
41 | new_data = []
42 | for example in tqdm(data):
43 | try:
44 | new_example = operation.transform(example)
45 | if new_example:
46 | new_data.append(new_example)
47 | except Exception as e:
48 | print("Caught exception:", e)
49 | return new_data
50 |
51 |
52 | def main(args):
53 | # load data
54 | source_docs = load_source_docs(args.data_file, to_dict=False)
55 | print("Loaded %d source documents." % len(source_docs))
56 |
57 | # create or load positive examples
58 | print("Creating data examples")
59 | sclaims_op = ops.SampleSentences()
60 | data = apply_transformation(source_docs, sclaims_op)
61 | print("Created %s example pairs." % len(data))
62 |
63 | if args.save_intermediate:
64 | save_data(args, data, "clean")
65 |
66 | # backtranslate
67 | data_btrans = []
68 | if not args.augmentations or "backtranslation" in args.augmentations:
69 | print("Creating backtranslation examples")
70 | btrans_op = ops.Backtranslation()
71 | data_btrans = apply_transformation(data, btrans_op)
72 | print("Backtranslated %s example pairs." % len(data_btrans))
73 |
74 | if args.save_intermediate:
75 | save_data(args, data_btrans, "btrans")
76 |
77 | data_positive = data + data_btrans
78 | save_data(args, data_positive, "positive")
79 |
80 | # create negative examples
81 | data_pronoun = []
82 | if not args.augmentations or "pronoun_swap" in args.augmentations:
83 | print("Creating pronoun examples")
84 | pronoun_op = ops.PronounSwap()
85 | data_pronoun = apply_transformation(data_positive, pronoun_op)
86 | print("PronounSwap %s example pairs." % len(data_pronoun))
87 |
88 | if args.save_intermediate:
89 | save_data(args, data_pronoun, "pronoun")
90 |
91 | data_dateswp = []
92 | if not args.augmentations or "date_swap" in args.augmentations:
93 | print("Creating date swap examples")
94 | dateswap_op = ops.DateSwap()
95 | data_dateswp = apply_transformation(data_positive, dateswap_op)
96 | print("DateSwap %s example pairs." % len(data_dateswp))
97 |
98 | if args.save_intermediate:
99 | save_data(args, data_dateswp, "dateswp")
100 |
101 | data_numswp = []
102 | if not args.augmentations or "number_swap" in args.augmentations:
103 | print("Creating number swap examples")
104 | numswap_op = ops.NumberSwap()
105 | data_numswp = apply_transformation(data_positive, numswap_op)
106 | print("NumberSwap %s example pairs." % len(data_numswp))
107 |
108 | if args.save_intermediate:
109 | save_data(args, data_numswp, "numswp")
110 |
111 | data_entswp = []
112 | if not args.augmentations or "entity_swap" in args.augmentations:
113 | print("Creating entity swap examples")
114 | entswap_op = ops.EntitySwap()
115 | data_entswp = apply_transformation(data_positive, entswap_op)
116 | print("EntitySwap %s example pairs." % len(data_entswp))
117 |
118 | if args.save_intermediate:
119 | save_data(args, data_entswp, "entswp")
120 |
121 | data_negation = []
122 | if not args.augmentations or "negation" in args.augmentations:
123 | print("Creating negation examples")
124 | negation_op = ops.NegateSentences()
125 | data_negation = apply_transformation(data_positive, negation_op)
126 | print("Negation %s example pairs." % len(data_negation))
127 |
128 | if args.save_intermediate:
129 | save_data(args, data_negation, "negation")
130 |
131 | # add noise to all
132 | data_negative = data_pronoun + data_dateswp + data_numswp + data_entswp + data_negation
133 | save_data(args, data_negative, "negative")
134 |
135 | # ADD NOISE
136 | data_pos_low_noise = []
137 | data_neg_low_noise = []
138 |
139 | if not args.augmentations or "noise" in args.augmentations:
140 | # add light noise
141 | print("Adding light noise to data")
142 | low_noise_op = ops.AddNoise()
143 |
144 | data_pos_low_noise = apply_transformation(data_positive, low_noise_op)
145 | print("PositiveNoisy %s example pairs." % len(data_pos_low_noise))
146 | save_data(args, data_pos_low_noise, "positive-noise")
147 |
148 | data_neg_low_noise = apply_transformation(data_negative, low_noise_op)
149 | print("NegativeNoisy %s example pairs." % len(data_neg_low_noise))
150 | save_data(args, data_neg_low_noise, "negative-noise")
151 |
152 |
153 | if __name__ == "__main__":
154 | PARSER = argparse.ArgumentParser()
155 | PARSER.add_argument("data_file", type=str, help="Path to file containing source documents.")
156 | PARSER.add_argument("--augmentations", type=str, nargs="+", default=(), help="List of data augmentations to apply to the data.")
157 | PARSER.add_argument("--all_augmentations", action="store_true", help="Flag indicating whether all augmentations should be applied.")
158 | PARSER.add_argument("--save_intermediate", action="store_true", help="Flag indicating whether intermediate data from each transformation should be saved in separate files.")
159 | ARGS = PARSER.parse_args()
160 | main(ARGS)
161 |
--------------------------------------------------------------------------------
/data_pairing/pair_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for recreating the FactCC dataset from CNN/DM Story files.
3 |
4 | CNN/DM Story files can be downloaded from https://cs.nyu.edu/~kcho/DMQA/
5 |
6 | Unpaired FactCC data should be downloaded and unpacked.
7 | CNN/DM data should be stored in a `cnndm` directory with `cnn` and `dm` sub-directories.
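
Example call (directory names are placeholders):
    python3 data_pairing/pair_data.py unpaired_data cnndm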
8 | """
9 | import argparse
10 | import json
11 | import os
12 |
13 | from tqdm import tqdm
14 |
15 |
16 | def parse_story_file(content):
17 | """
18 | Remove article highlights and unnecessary white characters.
19 | """
20 | content_raw = content.split("@highlight")[0]
21 | content = " ".join(filter(None, [x.strip() for x in content_raw.split("\n")]))
22 | return content
23 |
24 |
25 | def main(args):
26 | """
27 | Walk data sub-directories and recreate examples
28 | """
29 | for path, _, filenames in os.walk(args.unpaired_data):
30 | for filename in filenames:
31 |             if ".jsonl" not in filename:
32 | continue
33 |
34 | unpaired_path = os.path.join(path, filename)
35 | print("Processing file:", unpaired_path)
36 |
37 | with open(unpaired_path) as fd:
38 | dataset = [json.loads(line) for line in fd]
39 |
40 | for example in tqdm(dataset):
41 | story_path = os.path.join(args.story_files, example["filepath"])
42 |
43 | with open(story_path) as fd:
44 | story_content = fd.read()
45 | example["text"] = parse_story_file(story_content)
46 |
47 | paired_path = unpaired_path.replace("unpaired_", "")
48 | os.makedirs(os.path.dirname(paired_path), exist_ok=True)
49 | with open(paired_path, "w") as fd:
50 | for example in dataset:
51 | fd.write(json.dumps(example, ensure_ascii=False) + "\n")
52 |
53 |
54 | if __name__ == "__main__":
55 | PARSER = argparse.ArgumentParser()
56 | PARSER.add_argument("unpaired_data", type=str, help="Path to directory holding unpaired data")
57 | PARSER.add_argument("story_files", type=str, help="Path to directory holding CNNDM story files")
58 | ARGS = PARSER.parse_args()
59 | main(ARGS)
60 |
--------------------------------------------------------------------------------
/model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/factCC/5daeba51dc63f9d3197df6945501a1153a77affd/model.jpg
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/factCC/5daeba51dc63f9d3197df6945501a1153a77affd/modeling/__init__.py
--------------------------------------------------------------------------------
/modeling/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Salesforce.com, Inc.
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from __future__ import absolute_import, division, print_function, unicode_literals
18 |
19 | import json
20 | import logging
21 | import math
22 | import os
23 | import sys
24 | from io import open
25 |
26 | import torch
27 | from torch import nn
28 | from torch.nn import CrossEntropyLoss, MSELoss
29 |
30 |
31 | from pytorch_transformers.modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig,
32 | PreTrainedModel, prune_linear_layer, add_start_docstrings)
33 | from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel, BertLayer, BertPooler
34 |
35 |
36 | class BertPointer(BertPreTrainedModel):
37 | def __init__(self, config):
38 | super(BertPointer, self).__init__(config)
39 | self.num_labels = config.num_labels
40 |
41 | self.bert = BertModel(config)
42 | self.dropout = nn.Dropout(config.hidden_dropout_prob)
43 |
44 | # classifiers
45 | self.ext_start_classifier = nn.Linear(config.hidden_size, 1, bias=False)
46 | self.ext_end_classifier = nn.Linear(config.hidden_size, 1, bias=False)
47 | self.aug_start_classifier = nn.Linear(config.hidden_size, 1, bias=False)
48 | self.aug_end_classifier = nn.Linear(config.hidden_size, 1, bias=False)
49 |
50 | self.label_classifier = nn.Linear(config.hidden_size, self.config.num_labels)
51 |
52 | self.apply(self.init_weights)
53 |
54 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
55 | position_ids=None, head_mask=None,
56 | ext_mask=None, ext_start_labels=None, ext_end_labels=None,
57 | aug_mask=None, aug_start_labels=None, aug_end_labels=None,
58 | loss_lambda=1.0):
59 | # run through bert
60 | bert_outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
61 | attention_mask=attention_mask, head_mask=head_mask)
62 |
63 | # label classifier
64 | pooled_output = bert_outputs[1]
65 | pooled_output = self.dropout(pooled_output)
66 | label_logits = self.label_classifier(pooled_output)
67 |
68 | # extraction classifier
69 | output = bert_outputs[0]
70 | ext_mask = ext_mask.unsqueeze(-1)
71 | ext_start_logits = self.ext_start_classifier(output) * ext_mask
72 | ext_end_logits = self.ext_end_classifier(output) * ext_mask
73 |
74 | # augmentation classifier
75 | output = bert_outputs[0]
76 | aug_mask = aug_mask.unsqueeze(-1)
77 | aug_start_logits = self.aug_start_classifier(output) * aug_mask
78 | aug_end_logits = self.aug_end_classifier(output) * aug_mask
79 |
80 | span_logits = (ext_start_logits, ext_end_logits, aug_start_logits, aug_end_logits,)
81 | outputs = (label_logits,) + span_logits + bert_outputs[2:]
82 |
83 | if labels is not None and \
84 | ext_start_labels is not None and ext_end_labels is not None and \
85 | aug_start_labels is not None and aug_end_labels is not None:
86 | if self.num_labels == 1:
87 | # We are doing regression
88 | loss_fct = MSELoss()
89 | loss = loss_fct(label_logits.view(-1), labels.view(-1))
90 | else:
91 | loss_fct = CrossEntropyLoss()
92 |
93 | # label loss
94 | labels_loss = loss_fct(label_logits.view(-1, self.num_labels), labels.view(-1))
95 |
96 | # extraction loss
97 | ext_start_loss = loss_fct(ext_start_logits.squeeze(), ext_start_labels)
98 | ext_end_loss = loss_fct(ext_end_logits.squeeze(), ext_end_labels)
99 |
100 | # augmentation loss
101 | aug_start_loss = loss_fct(aug_start_logits.squeeze(), aug_start_labels)
102 | aug_end_loss = loss_fct(aug_end_logits.squeeze(), aug_end_labels)
103 |
104 | span_loss = (ext_start_loss + ext_end_loss + aug_start_loss + aug_end_loss) / 4
105 |
106 | # combined loss
107 | loss = labels_loss + loss_lambda * span_loss
108 |
109 | outputs = (loss, labels_loss, span_loss, ext_start_loss, ext_end_loss, aug_start_loss, aug_end_loss) + outputs
110 |
111 |         return outputs  # (7 loss terms if labels given), label_logits, 4 span logits, (hidden_states), (attentions)
112 |
--------------------------------------------------------------------------------
/modeling/run.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Salesforce.com, Inc.
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ Finetuning BERT-based models for the FactCC factual consistency classification task."""
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import argparse
21 | import glob
22 | import logging
23 | import os
24 | import random
25 |
26 | import wandb
27 | import numpy as np
28 | import torch
29 |
30 | from model import BertPointer
31 | from utils import (compute_metrics, convert_examples_to_features, output_modes, processors)
32 |
33 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
34 | from torch.utils.data.distributed import DistributedSampler
35 | from tqdm import tqdm, trange
36 | from pytorch_transformers import (WEIGHTS_NAME, BertConfig, BertForSequenceClassification, BertTokenizer)
37 |
38 | from pytorch_transformers import AdamW, WarmupLinearSchedule
39 |
40 | logger = logging.getLogger(__name__)
41 | wandb.init(project="entailment-metric")
42 |
43 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), ())
44 |
45 | MODEL_CLASSES = {
46 | 'pbert': (BertConfig, BertPointer, BertTokenizer),
47 | 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
48 | }
49 |
50 |
51 | def set_seed(args):
52 | random.seed(args.seed)
53 | np.random.seed(args.seed)
54 | torch.manual_seed(args.seed)
55 | if args.n_gpu > 0:
56 | torch.cuda.manual_seed_all(args.seed)
57 |
58 |
59 | def make_model_input(args, batch):
60 | inputs = {'input_ids': batch[0],
61 | 'attention_mask': batch[1],
62 | 'token_type_ids': batch[2],
63 | 'labels': batch[3]}
64 |
65 | # add extraction and augmentation spans for PointerBert model
66 | if args.model_type == "pbert":
67 | inputs["ext_mask"] = batch[4]
68 | inputs["ext_start_labels"] = batch[5]
69 | inputs["ext_end_labels"] = batch[6]
70 | inputs["aug_mask"] = batch[7]
71 | inputs["aug_start_labels"] = batch[8]
72 | inputs["aug_end_labels"] = batch[9]
73 | inputs["loss_lambda"] = args.loss_lambda
74 |
75 | return inputs
76 |
77 |
78 | def train(args, train_dataset, model, tokenizer):
79 | """ Train the model """
80 |
81 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
82 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
83 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
84 |
85 | if args.max_steps > 0:
86 | t_total = args.max_steps
87 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
88 | else:
89 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
90 |
91 | # Prepare optimizer and schedule (linear warmup and decay)
92 | no_decay = ['bias', 'LayerNorm.weight']
93 | optimizer_grouped_parameters = [
94 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
95 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
96 | ]
97 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
98 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
99 | if args.fp16:
100 | try:
101 | from apex import amp
102 | except ImportError:
103 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
104 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
105 |
106 | # multi-gpu training (should be after apex fp16 initialization)
107 | if args.n_gpu > 1:
108 | model = torch.nn.DataParallel(model)
109 |
110 | # Distributed training (should be after apex fp16 initialization)
111 | if args.local_rank != -1:
112 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
113 | output_device=args.local_rank,
114 | find_unused_parameters=True)
115 |
116 | # Train!
117 | logger.info("***** Running training *****")
118 | logger.info(" Num examples = %d", len(train_dataset))
119 | logger.info(" Num Epochs = %d", args.num_train_epochs)
120 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
121 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
122 | args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
123 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
124 | logger.info(" Total optimization steps = %d", t_total)
125 |
126 | global_step = 0
127 | tr_loss, logging_loss = 0.0, 0.0
128 | model.zero_grad()
129 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
130 |     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
131 |
132 | for epoch_ix in train_iterator:
133 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
134 | for step, batch in enumerate(epoch_iterator):
135 | model.train()
136 | batch = tuple(t.to(args.device) for t in batch)
137 | inputs = make_model_input(args, batch)
138 | outputs = model(**inputs)
139 | loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
140 |
141 | if args.n_gpu > 1:
142 | loss = loss.mean() # mean() to average on multi-gpu parallel training
143 | if args.gradient_accumulation_steps > 1:
144 | loss = loss / args.gradient_accumulation_steps
145 |
146 | if args.fp16:
147 | with amp.scale_loss(loss, optimizer) as scaled_loss:
148 | scaled_loss.backward()
149 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
150 | else:
151 | loss.backward()
152 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
153 |
154 | tr_loss += loss.item()
155 | if (step + 1) % args.gradient_accumulation_steps == 0:
156 | scheduler.step() # Update learning rate schedule
157 | optimizer.step()
158 | model.zero_grad()
159 | global_step += 1
160 |
161 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
162 | results = {}
163 |
164 | # Log metrics
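                    # BertPointer (pbert) prepends 7 loss terms to its outputs, so its label
                    # logits sit at index 7; BertForSequenceClassification returns (loss, logits, ...)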
165 | logits_ix = 1 if args.model_type == "bert" else 7
166 | logits = outputs[logits_ix]
167 | preds = np.argmax(logits.detach().cpu().numpy(), axis=1)
168 | out_label_ids = inputs['labels'].detach().cpu().numpy()
169 |
170 | result = compute_metrics(args.task_name, preds, out_label_ids)
171 | results.update(result)
172 |
173 | for key, value in results.items():
174 | wandb.log({'train_{}'.format(key): value})
175 |
176 | wandb.log({"train_loss": (tr_loss - logging_loss) / args.logging_steps})
177 | wandb.log({"train_lr": scheduler.get_lr()[0]})
178 | logging_loss = tr_loss
179 |
180 | if args.max_steps > 0 and global_step > args.max_steps:
181 | epoch_iterator.close()
182 | break
183 |
184 | if args.evaluate_during_training:
185 | # Only evaluate when single GPU otherwise metrics may not average well
186 | results = evaluate(args, model, tokenizer)
187 | for key, value in results.items():
188 | wandb.log({'eval_{}'.format(key): value})
189 |
190 | if args.local_rank in [-1, 0]:
191 | # Save model checkpoint
192 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(epoch_ix))
193 | if not os.path.exists(output_dir):
194 | os.makedirs(output_dir)
195 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
196 | model_to_save.save_pretrained(output_dir)
197 | torch.save(args, os.path.join(output_dir, 'training_args.bin'))
198 | logger.info("Saving model checkpoint to %s", output_dir)
199 |
200 | if args.max_steps > 0 and global_step > args.max_steps:
201 | train_iterator.close()
202 | break
203 |
204 | return global_step, tr_loss / global_step
205 |
206 |
207 | def evaluate(args, model, tokenizer, prefix=""):
208 | # Loop to handle MNLI double evaluation (matched, mis-matched)
209 | eval_task_names = (args.task_name,)
210 | eval_outputs_dirs = (args.output_dir,)
211 |
212 | results = {}
213 | cnt = 0
214 | for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
215 | eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
216 |
217 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
218 | os.makedirs(eval_output_dir)
219 |
220 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
221 | # Note that DistributedSampler samples randomly
222 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
223 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
224 |
225 | # Eval!
226 | logger.info("***** Running evaluation {} *****".format(prefix))
227 | logger.info(" Num examples = %d", len(eval_dataset))
228 | logger.info(" Batch size = %d", args.eval_batch_size)
229 |
230 | eval_loss = 0.0
231 | nb_eval_steps = 0
232 | preds = None
233 | out_label_ids = None
234 |
235 | for batch in tqdm(eval_dataloader, desc="Evaluating"):
236 | model.eval()
237 | batch = tuple(t.to(args.device) for t in batch)
238 |
239 | with torch.no_grad():
240 | inputs = make_model_input(args, batch)
241 | outputs = model(**inputs)
242 |
243 | # monitoring
244 | tmp_eval_loss = outputs[0]
245 | logits_ix = 1 if args.model_type == "bert" else 7
246 | logits = outputs[logits_ix]
247 | eval_loss += tmp_eval_loss.mean().item()
248 | nb_eval_steps += 1
249 |
250 | if preds is None:
251 | preds = logits.detach().cpu().numpy()
252 | out_label_ids = inputs['labels'].detach().cpu().numpy()
253 | else:
254 | preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
255 | out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
256 |
257 | preds = np.argmax(preds, axis=1)
258 | result = compute_metrics(args.task_name, preds, out_label_ids)
259 | eval_loss = eval_loss / nb_eval_steps
260 | result["loss"] = eval_loss
261 | results.update(result)
262 |
263 | output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
264 | with open(output_eval_file, "w") as writer:
265 | logger.info("***** Eval results {} *****".format(prefix))
266 | for key in sorted(results.keys()):
267 |                 logger.info("  %s = %s", key, str(results[key]))
268 |                 writer.write("%s = %s\n" % (key, str(results[key])))
269 |
270 | return results
271 |
272 |
273 | def load_and_cache_examples(args, task, tokenizer, evaluate=False):
274 | if args.local_rank not in [-1, 0]:
275 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
276 |
277 | processor = processors[task]()
278 | output_mode = output_modes[task]
279 | # Load data features from cache or dataset file
280 | cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
281 | 'dev' if evaluate else 'train',
282 | list(filter(None, args.model_name_or_path.split('/'))).pop(),
283 | str(args.max_seq_length),
284 | str(task)))
285 | if os.path.exists(cached_features_file):
286 | logger.info("Loading features from cached file %s", cached_features_file)
287 | features = torch.load(cached_features_file)
288 | else:
289 | logger.info("Creating features from dataset file at %s", args.data_dir)
290 | label_list = processor.get_labels()
291 | examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
292 | features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
293 | cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
294 | cls_token=tokenizer.cls_token,
295 | cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
296 | sep_token=tokenizer.sep_token,
297 | sep_token_extra=bool(args.model_type in ['roberta']),
298 | pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
299 | pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
300 | pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
301 | if args.local_rank in [-1, 0]:
302 | logger.info("Saving features into cached file %s", cached_features_file)
303 | torch.save(features, cached_features_file)
304 |
305 | if args.local_rank == 0:
306 | torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
307 |
308 | # Convert to Tensors and build dataset
309 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
310 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
311 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
312 | all_ext_mask = torch.tensor([f.extraction_mask for f in features], dtype=torch.float)
313 | all_ext_start_ids = torch.tensor([f.extraction_start_ids for f in features], dtype=torch.long)
314 | all_ext_end_ids = torch.tensor([f.extraction_end_ids for f in features], dtype=torch.long)
315 | all_aug_mask = torch.tensor([f.augmentation_mask for f in features], dtype=torch.float)
316 | all_aug_start_ids = torch.tensor([f.augmentation_start_ids for f in features], dtype=torch.long)
317 | all_aug_end_ids = torch.tensor([f.augmentation_end_ids for f in features], dtype=torch.long)
318 |
319 | if output_mode == "classification":
320 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
321 | elif output_mode == "regression":
322 | all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)
323 |
324 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,
325 | all_ext_mask, all_ext_start_ids, all_ext_end_ids,
326 | all_aug_mask, all_aug_start_ids, all_aug_end_ids)
327 | return dataset
328 |
329 |
330 | def main():
331 | parser = argparse.ArgumentParser()
332 |
333 | ## Required parameters
334 | parser.add_argument("--data_dir", default=None, type=str, required=True,
335 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
336 | parser.add_argument("--model_type", default=None, type=str, required=True,
337 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
338 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
339 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
340 | parser.add_argument("--task_name", default=None, type=str, required=True,
341 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
342 | parser.add_argument("--output_dir", default=None, type=str, required=True,
343 | help="The output directory where the model predictions and checkpoints will be written.")
344 | parser.add_argument("--train_from_scratch", action='store_true',
345 | help="Whether to run training without loading pretrained weights.")
346 |
347 | ## Other parameters
348 | parser.add_argument("--config_name", default="", type=str,
349 | help="Pretrained config name or path if not the same as model_name")
350 | parser.add_argument("--tokenizer_name", default="", type=str,
351 | help="Pretrained tokenizer name or path if not the same as model_name")
352 | parser.add_argument("--cache_dir", default="", type=str,
353 | help="Where do you want to store the pre-trained models downloaded from s3")
354 | parser.add_argument("--max_seq_length", default=512, type=int,
355 | help="The maximum total input sequence length after tokenization. Sequences longer "
356 | "than this will be truncated, sequences shorter will be padded.")
357 | parser.add_argument("--do_train", action='store_true',
358 | help="Whether to run training.")
359 | parser.add_argument("--do_eval", action='store_true',
360 | help="Whether to run eval on the dev set.")
361 | parser.add_argument("--evaluate_during_training", action='store_true',
362 | help="Run evaluation during training at each logging step.")
363 | parser.add_argument("--do_lower_case", action='store_true',
364 | help="Set this flag if you are using an uncased model.")
365 |
366 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
367 | help="Batch size per GPU/CPU for training.")
368 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
369 | help="Batch size per GPU/CPU for evaluation.")
370 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
371 | help="Number of updates steps to accumulate before performing a backward/update pass.")
372 | parser.add_argument("--learning_rate", default=5e-5, type=float,
373 | help="The initial learning rate for Adam.")
374 | parser.add_argument("--loss_lambda", default=0.1, type=float,
375 | help="The lambda parameter for loss mixing.")
376 | parser.add_argument("--weight_decay", default=0.0, type=float,
377 | help="Weight deay if we apply some.")
378 | parser.add_argument("--adam_epsilon", default=1e-8, type=float,
379 | help="Epsilon for Adam optimizer.")
380 | parser.add_argument("--max_grad_norm", default=1.0, type=float,
381 | help="Max gradient norm.")
382 | parser.add_argument("--num_train_epochs", default=3.0, type=float,
383 | help="Total number of training epochs to perform.")
384 | parser.add_argument("--max_steps", default=-1, type=int,
385 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
386 | parser.add_argument("--warmup_steps", default=0, type=int,
387 | help="Linear warmup over warmup_steps.")
388 |
389 | parser.add_argument('--logging_steps', type=int, default=100,
390 | help="Log every X updates steps.")
391 | parser.add_argument('--save_steps', type=int, default=50,
392 | help="Save checkpoint every X updates steps.")
393 | parser.add_argument("--eval_all_checkpoints", action='store_true',
394 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
395 | parser.add_argument("--no_cuda", action='store_true',
396 | help="Avoid using CUDA when available")
397 | parser.add_argument('--overwrite_output_dir', action='store_true',
398 | help="Overwrite the content of the output directory")
399 | parser.add_argument('--overwrite_cache', action='store_true',
400 | help="Overwrite the cached training and evaluation sets")
401 | parser.add_argument('--seed', type=int, default=42,
402 | help="random seed for initialization")
403 |
404 | parser.add_argument('--fp16', action='store_true',
405 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
406 | parser.add_argument('--fp16_opt_level', type=str, default='O1',
407 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
408 | "See details at https://nvidia.github.io/apex/amp.html")
409 | parser.add_argument("--local_rank", type=int, default=-1,
410 | help="For distributed training: local_rank")
411 | args = parser.parse_args()
412 |
413 | if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
414 | raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
415 |
416 | # Setup CUDA, GPU & distributed training
417 | if args.local_rank == -1 or args.no_cuda:
418 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
419 | args.n_gpu = torch.cuda.device_count()
420 | else: # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
421 | torch.cuda.set_device(args.local_rank)
422 | device = torch.device("cuda", args.local_rank)
423 | torch.distributed.init_process_group(backend='nccl')
424 | args.n_gpu = 1
425 | args.device = device
426 |
427 | # Setup logging
428 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
429 | datefmt='%m/%d/%Y %H:%M:%S',
430 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
431 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
432 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
433 |
434 | # Set seed
435 | set_seed(args)
436 |
437 | # Prepare FactCC task
438 | args.task_name = args.task_name.lower()
439 | if args.task_name not in processors:
440 | raise ValueError("Task not found: %s" % (args.task_name))
441 | processor = processors[args.task_name]()
442 | args.output_mode = output_modes[args.task_name]
443 | label_list = processor.get_labels()
444 | num_labels = len(label_list)
445 |
446 | # Load pretrained model and tokenizer
447 | if args.local_rank not in [-1, 0]:
448 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
449 |
450 | args.model_type = args.model_type.lower()
451 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
452 | config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
453 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
454 | if args.train_from_scratch:
455 | logger.info("Training model from scratch.")
456 | model = model_class(config=config)
457 | else:
458 | logger.info("Loading model from checkpoint.")
459 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
460 |
461 | if args.local_rank == 0:
462 | torch.distributed.barrier() # End of barrier: the other processes can now load the downloaded model & vocab from the cache
463 |
464 | model.to(args.device)
465 |
466 | wandb.watch(model)
467 | logger.info("Training/evaluation parameters %s", args)
468 |
469 |
470 | # Training
471 | if args.do_train:
472 | train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
473 | global_step, tr_loss = train(args, train_dataset, model, tokenizer)
474 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
475 |
476 |
477 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
478 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
479 | # Create output directory if needed
480 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
481 | os.makedirs(args.output_dir)
482 |
483 | logger.info("Saving model checkpoint to %s", args.output_dir)
484 | # Save a trained model, configuration and tokenizer using `save_pretrained()`.
485 | # They can then be reloaded using `from_pretrained()`
486 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
487 | model_to_save.save_pretrained(args.output_dir)
488 | tokenizer.save_pretrained(args.output_dir)
489 |
490 | # Good practice: save your training arguments together with the trained model
491 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
492 |
493 | # Load a trained model and vocabulary that you have fine-tuned
494 | model = model_class.from_pretrained(args.output_dir)
495 | tokenizer = tokenizer_class.from_pretrained(args.output_dir)
496 | model.to(args.device)
497 |
498 |
499 | # Evaluation
500 | results = {}
501 | if args.do_eval and args.local_rank in [-1, 0]:
502 | checkpoints = [args.output_dir]
503 | if args.eval_all_checkpoints:
504 | checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
505 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
506 | logger.info("Evaluate the following checkpoints: %s", checkpoints)
507 | for checkpoint in checkpoints:
508 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
509 | model = model_class.from_pretrained(checkpoint)
510 | model.to(args.device)
511 | result = evaluate(args, model, tokenizer, prefix=global_step)
512 | result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
513 | results.update(result)
514 |
515 | return results
516 |
517 |
518 | if __name__ == "__main__":
519 | main()
520 |
--------------------------------------------------------------------------------
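The ten tensors in the TensorDataset built by load_and_cache_examples are positional, so batches must be unpacked in the same order they were constructed. A minimal sketch, assuming `dataset` is the value returned above (this driver itself is not part of the repository):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=8)
for batch in loader:
    # Same field order as the TensorDataset construction in load_and_cache_examples
    (input_ids, input_mask, segment_ids, label_ids,
     ext_mask, ext_start_ids, ext_end_ids,
     aug_mask, aug_start_ids, aug_end_ids) = batch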
/modeling/scripts/factcc-eval.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Evaluate FactCC model
3 |
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export CKPT_PATH= # absolute path to model checkpoint
8 |
9 | export TASK_NAME=factcc_annotated
10 | export MODEL_NAME=bert-base-uncased
11 |
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_eval \
15 | --eval_all_checkpoints \
16 | --do_lower_case \
17 | --overwrite_cache \
18 | --max_seq_length 512 \
19 | --per_gpu_eval_batch_size 12 \
20 | --model_type bert \
21 | --model_name_or_path $MODEL_NAME \
22 | --data_dir $DATA_PATH \
23 | --output_dir $CKPT_PATH
24 |
--------------------------------------------------------------------------------
/modeling/scripts/factcc-finetune.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Fine-tune FactCC model
3 |
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export OUTPUT_PATH= # absolute path to output directory for checkpoints
8 |
9 | export TASK_NAME=factcc_generated
10 | export MODEL_NAME=bert-base-uncased
11 |
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_train \
15 | --do_eval \
16 | --evaluate_during_training \
17 | --do_lower_case \
18 | --max_seq_length 512 \
19 | --per_gpu_train_batch_size 12 \
20 | --learning_rate 2e-5 \
21 | --num_train_epochs 10.0 \
22 | --data_dir $DATA_PATH \
23 | --model_type bert \
24 | --model_name_or_path $MODEL_NAME \
25 | --output_dir $OUTPUT_PATH/$MODEL_NAME-$TASK_NAME-finetune-$RANDOM/
26 |
--------------------------------------------------------------------------------
/modeling/scripts/factcc-train.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Train FactCC model
3 |
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export OUTPUT_PATH= # absolute path to output directory for checkpoints
8 |
9 | export TASK_NAME=factcc_generated
10 | export MODEL_NAME=bert-base-uncased
11 |
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_train \
15 | --do_eval \
16 | --do_lower_case \
17 | --train_from_scratch \
18 | --data_dir $DATA_PATH \
19 | --model_type bert \
20 | --model_name_or_path $MODEL_NAME \
21 | --max_seq_length 512 \
22 | --per_gpu_train_batch_size 12 \
23 | --learning_rate 2e-5 \
24 | --num_train_epochs 20.0 \
25 | --evaluate_during_training \
26 | --eval_all_checkpoints \
27 | --overwrite_cache \
28 | --output_dir $OUTPUT_PATH/$MODEL_NAME-$TASK_NAME-train-$RANDOM/
29 |
--------------------------------------------------------------------------------
/modeling/scripts/factccx-eval.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Evaluate FactCCX model
3 |
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export CKPT_PATH= # absolute path to model checkpoint
8 |
9 | export TASK_NAME=factcc_annotated
10 | export MODEL_NAME=bert-base-uncased
11 |
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_eval \
15 | --eval_all_checkpoints \
16 | --do_lower_case \
17 | --max_seq_length 512 \
18 | --per_gpu_eval_batch_size 12 \
19 | --model_type pbert \
20 | --model_name_or_path $MODEL_NAME \
21 | --data_dir $DATA_PATH \
22 | --output_dir $CKPT_PATH
23 |
24 |
--------------------------------------------------------------------------------
/modeling/scripts/factccx-finetune.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Fine-tune FactCCX model
3 |
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export OUTPUT_PATH= # absolute path to output directory for checkpoints
8 |
9 | export TASK_NAME=factcc_generated
10 | export MODEL_NAME=bert-base-uncased
11 |
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_train \
15 | --do_eval \
16 | --do_lower_case \
17 | --data_dir $DATA_PATH \
18 | --model_type pbert \
19 | --model_name_or_path $MODEL_NAME \
20 | --max_seq_length 512 \
21 | --per_gpu_train_batch_size 12 \
22 | --learning_rate 2e-5 \
23 | --num_train_epochs 10.0 \
24 | --loss_lambda 0.1 \
25 | --evaluate_during_training \
26 | --eval_all_checkpoints \
27 | --overwrite_cache \
28 | --output_dir $OUTPUT_PATH/$MODEL_NAME-$TASK_NAME-finetune-$RANDOM/
29 |
--------------------------------------------------------------------------------
/modeling/scripts/factccx-train.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # Train FactCCX model
3 |
4 | # UPDATE PATHS BEFORE RUNNING SCRIPT
5 | export CODE_PATH= # absolute path to modeling directory
6 | export DATA_PATH= # absolute path to data directory
7 | export OUTPUT_PATH= # absolute path to output directory for checkpoints
8 |
9 | export TASK_NAME=factcc_generated
10 | export MODEL_NAME=bert-base-uncased
11 |
12 | python3 $CODE_PATH/run.py \
13 | --task_name $TASK_NAME \
14 | --do_train \
15 | --do_eval \
16 | --do_lower_case \
17 | --train_from_scratch \
18 | --data_dir $DATA_PATH \
19 | --model_type pbert \
20 | --model_name_or_path $MODEL_NAME \
21 | --max_seq_length 512 \
22 | --per_gpu_train_batch_size 12 \
23 | --learning_rate 2e-5 \
24 | --num_train_epochs 20.0 \
25 | --evaluate_during_training \
26 | --eval_all_checkpoints \
27 | --overwrite_cache \
28 | --output_dir $OUTPUT_PATH/$MODEL_NAME-$TASK_NAME-train-$RANDOM/
29 |
--------------------------------------------------------------------------------
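Each train/finetune script derives a unique run directory by appending bash's $RANDOM (an integer in 0-32767) to the output path. A rough Python equivalent, assuming OUTPUT_PATH is exported as in the scripts (the variable names are illustrative):

import os
import random

model_name, task_name = "bert-base-uncased", "factcc_generated"
# randrange(32768) mirrors the 0-32767 range of bash's $RANDOM
run_dir = os.path.join(os.environ["OUTPUT_PATH"],
                       f"{model_name}-{task_name}-train-{random.randrange(32768)}")
os.makedirs(run_dir, exist_ok=True)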
/modeling/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020, Salesforce.com, Inc.
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ BERT classification fine-tuning: utilities to work with GLUE tasks """
17 |
18 | from __future__ import absolute_import, division, print_function
19 |
20 | import csv
21 | import logging
22 | import json
23 | import os
24 | import sys
25 | from io import open
26 |
27 | from scipy.stats import pearsonr, spearmanr
28 | from sklearn.metrics import f1_score, balanced_accuracy_score, accuracy_score
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | class InputExample(object):
34 | """A single training/test example for simple sequence classification."""
35 |
36 | def __init__(self, guid, text_a, text_b=None, label=None,
37 | extraction_span=None, augmentation_span=None):
38 | """Constructs a InputExample.
39 |
40 | Args:
41 | guid: Unique id for the example.
42 | text_a: string. The untokenized text of the first sequence. For single
43 | sequence tasks, only this sequence must be specified.
44 | text_b: (Optional) string. The untokenized text of the second sequence.
45 | Must only be specified for sequence pair tasks.
46 | label: (Optional) string. The label of the example. This should be
47 | specified for train and dev examples, but not for test examples.
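extraction_span: (Optional) [start, end] token indices into text_a marking
    the source span the claim was derived from; offsets are shifted past
    [CLS] in convert_examples_to_features.
augmentation_span: (Optional) [start, end] token indices into text_b marking
    the span modified by the augmentation transformation.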
48 | """
49 | self.guid = guid
50 | self.text_a = text_a
51 | self.text_b = text_b
52 | self.label = label
53 | self.extraction_span = extraction_span
54 | self.augmentation_span = augmentation_span
55 |
56 |
57 | class InputFeatures(object):
58 | """A single set of features of data."""
59 |
60 | def __init__(self, input_ids, input_mask, segment_ids, label_id,
61 | extraction_mask=None, extraction_start_ids=None, extraction_end_ids=None,
62 | augmentation_mask=None, augmentation_start_ids=None, augmentation_end_ids=None):
63 | self.input_ids = input_ids
64 | self.input_mask = input_mask
65 | self.segment_ids = segment_ids
66 | self.label_id = label_id
67 | self.extraction_mask = extraction_mask
68 | self.extraction_start_ids = extraction_start_ids
69 | self.extraction_end_ids = extraction_end_ids
70 | self.augmentation_mask = augmentation_mask
71 | self.augmentation_start_ids = augmentation_start_ids
72 | self.augmentation_end_ids = augmentation_end_ids
73 |
74 |
75 | class DataProcessor(object):
76 | """Base class for data converters for sequence classification data sets."""
77 |
78 | def get_train_examples(self, data_dir):
79 | """Gets a collection of `InputExample`s for the train set."""
80 | raise NotImplementedError()
81 |
82 | def get_dev_examples(self, data_dir):
83 | """Gets a collection of `InputExample`s for the dev set."""
84 | raise NotImplementedError()
85 |
86 | def get_labels(self):
87 | """Gets the list of labels for this data set."""
88 | raise NotImplementedError()
89 |
90 | @classmethod
91 | def _read_tsv(cls, input_file, quotechar=None):
92 | """Reads a tab separated value file."""
93 | with open(input_file, "r", encoding="utf-8-sig") as f:
94 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
95 | lines = []
96 | for line in reader:
97 | if sys.version_info[0] == 2:
98 | line = list(unicode(cell, 'utf-8') for cell in line)
99 | lines.append(line)
100 | return lines
101 |
102 | @classmethod
103 | def _read_json(cls, input_file):
104 | """Reads a jsonl file."""
105 | with open(input_file, "r", encoding="utf-8") as f:
106 | lines = []
107 | for line in f:
108 | lines.append(json.loads(line))
109 | return lines
110 |
111 |
112 | class FactCCGeneratedProcessor(DataProcessor):
113 | """Processor for the generated FactCC data set."""
114 |
115 | def get_train_examples(self, data_dir):
116 | """See base class."""
117 | return self._create_examples(
118 | self._read_json(os.path.join(data_dir, "data-train.jsonl")), "train")
119 |
120 | def get_dev_examples(self, data_dir):
121 | """See base class."""
122 | return self._create_examples(
123 | self._read_json(os.path.join(data_dir, "data-dev.jsonl")), "dev")
124 |
125 | def get_labels(self):
126 | """See base class."""
127 | return ["CORRECT", "INCORRECT"]
128 |
129 | def _create_examples(self, lines, set_type):
130 | """Creates examples for the training and dev sets."""
131 | examples = []
132 | for example in lines:
133 | guid = example["id"]
134 | text_a = example["text"]
135 | text_b = example["claim"]
136 | label = example["label"]
137 | extraction_span = example["extraction_span"]
138 | augmentation_span = example["augmentation_span"]
139 |
140 | examples.append(
141 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label,
142 | extraction_span=extraction_span, augmentation_span=augmentation_span))
143 | return examples
144 |
145 |
146 | class FactCCManualProcessor(DataProcessor):
147 | """Processor for the WNLI data set (GLUE version)."""
148 |
149 | def get_train_examples(self, data_dir):
150 | """See base class."""
151 | return self._create_examples(
152 | self._read_json(os.path.join(data_dir, "data-train.jsonl")), "train")
153 |
154 | def get_dev_examples(self, data_dir):
155 | """See base class."""
156 | return self._create_examples(
157 | self._read_json(os.path.join(data_dir, "data-dev.jsonl")), "dev")
158 |
159 | def get_labels(self):
160 | """See base class."""
161 | return ["CORRECT", "INCORRECT"]
162 |
163 | def _create_examples(self, lines, set_type):
164 | """Creates examples for the training and dev sets."""
165 | examples = []
166 | for (i, example) in enumerate(lines):
167 | guid = str(i)
168 | text_a = example["text"]
169 | text_b = example["claim"]
170 | label = example["label"]
171 | examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
172 | return examples
173 |
174 |
175 | def convert_examples_to_features(examples, label_list, max_seq_length,
176 | tokenizer, output_mode,
177 | cls_token_at_end=False,
178 | cls_token='[CLS]',
179 | cls_token_segment_id=1,
180 | sep_token='[SEP]',
181 | sep_token_extra=False,
182 | pad_on_left=False,
183 | pad_token=0,
184 | pad_token_segment_id=0,
185 | sequence_a_segment_id=0,
186 | sequence_b_segment_id=1,
187 | mask_padding_with_zero=True):
188 | """ Loads a data file into a list of `InputBatch`s
189 | `cls_token_at_end` define the location of the CLS token:
190 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
191 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
192 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
193 | """
194 |
195 | label_map = {label : i for i, label in enumerate(label_list)}
196 |
197 | features = []
198 | for (ex_index, example) in enumerate(examples):
199 | if ex_index % 10000 == 0:
200 | logger.info("Writing example %d of %d" % (ex_index, len(examples)))
201 |
202 | tokens_a = tokenizer.tokenize(example.text_a)
203 |
204 | tokens_b = None
205 | if example.text_b:
206 | tokens_b = tokenizer.tokenize(example.text_b)
207 | # Modifies `tokens_a` and `tokens_b` in place so that the total
208 | # length is less than the specified length.
209 | # Account for [CLS], [SEP], [SEP] with "- 3" ("- 4" for RoBERTa).
210 | special_tokens_count = 4 if sep_token_extra else 3
211 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
212 | else:
213 | # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
214 | special_tokens_count = 3 if sep_token_extra else 2
215 | if len(tokens_a) > max_seq_length - special_tokens_count:
216 | tokens_a = tokens_a[:(max_seq_length - special_tokens_count)]
217 |
218 | # The convention in BERT is:
219 | # (a) For sequence pairs:
220 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
221 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
222 | # (b) For single sequences:
223 | # tokens: [CLS] the dog is hairy . [SEP]
224 | # type_ids: 0 0 0 0 0 0 0
225 | #
226 | # Where "type_ids" are used to indicate whether this is the first
227 | # sequence or the second sequence. The embedding vectors for `type=0` and
228 | # `type=1` were learned during pre-training and are added to the wordpiece
229 | # embedding vector (and position vector). This is not *strictly* necessary
230 | # since the [SEP] token unambiguously separates the sequences, but it makes
231 | # it easier for the model to learn the concept of sequences.
232 | #
233 | # For classification tasks, the first vector (corresponding to [CLS]) is
234 | # used as the "sentence vector". Note that this only makes sense because
235 | # the entire model is fine-tuned.
236 | tokens = tokens_a + [sep_token]
237 | if sep_token_extra:
238 | # roberta uses an extra separator b/w pairs of sentences
239 | tokens += [sep_token]
240 | segment_ids = [sequence_a_segment_id] * len(tokens)
241 |
242 | if tokens_b:
243 | tokens += tokens_b + [sep_token]
244 | segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
245 |
246 | if cls_token_at_end:
247 | tokens = tokens + [cls_token]
248 | segment_ids = segment_ids + [cls_token_segment_id]
249 | else:
250 | tokens = [cls_token] + tokens
251 | segment_ids = [cls_token_segment_id] + segment_ids
252 |
253 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
254 |
255 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
256 | # tokens are attended to.
257 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
258 |
259 | ####### AUX LOSS DATA
260 | # Mask over tokens_a and its trailing [SEP] ([CLS] at position 0 is excluded)
261 | extraction_span_len = len(tokens_a) + 2
262 | extraction_mask = [1 if 0 < ix < extraction_span_len else 0 for ix in range(max_seq_length)]
263 |
264 | # get extraction labels
265 | if example.extraction_span:
266 | ext_start, ext_end = example.extraction_span
267 | extraction_start_ids = ext_start + 1
268 | extraction_end_ids = ext_end + 1
269 | else:
270 | extraction_start_ids = extraction_span_len
271 | extraction_end_ids = extraction_span_len
272 |
273 | augmentation_mask = [1 if extraction_span_len <= ix < extraction_span_len + len(tokens_b) + 1 else 0 for ix in range(max_seq_length)]
274 |
275 | if example.augmentation_span:
276 | aug_start, aug_end = example.augmentation_span
277 | augmentation_start_ids = extraction_span_len + aug_start
278 | augmentation_end_ids = extraction_span_len + aug_end
279 | else:
280 | last_sep_token = extraction_span_len + len(tokens_b)
281 | augmentation_start_ids = last_sep_token
282 | augmentation_end_ids = last_sep_token
283 |
284 | # Zero-pad up to the sequence length.
285 | padding_length = max_seq_length - len(input_ids)
286 | if pad_on_left:
287 | input_ids = ([pad_token] * padding_length) + input_ids
288 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
289 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
290 | else:
291 | input_ids = input_ids + ([pad_token] * padding_length)
292 | input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
293 | segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
294 |
295 | assert len(input_ids) == max_seq_length
296 | assert len(input_mask) == max_seq_length
297 | assert len(segment_ids) == max_seq_length
298 |
299 | if output_mode == "classification":
300 | label_id = label_map[example.label]
301 | elif output_mode == "regression":
302 | label_id = float(example.label)
303 | else:
304 | raise KeyError(output_mode)
305 |
306 | if ex_index < 3:
307 | logger.info("*** Example ***")
308 | logger.info("guid: %s" % (example.guid))
309 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
310 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
311 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
312 | logger.info("ext mask: %s" % " ".join([str(x) for x in extraction_mask]))
313 | logger.info("ext start: %d" % extraction_start_ids)
314 | logger.info("ext end: %d" % extraction_end_ids)
315 | logger.info("aug mask: %s" % " ".join([str(x) for x in augmentation_mask]))
316 | logger.info("aug start: %d" % augmentation_start_ids)
317 | logger.info("aug end: %d" % augmentation_end_ids)
318 | logger.info("label: %d" % label_id)
319 |
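# Clamp span indices to the last valid position; the literal 511
# assumes max_seq_length == 512 (the value the scripts pass).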
320 | extraction_start_ids = min(extraction_start_ids, 511)
321 | extraction_end_ids = min(extraction_end_ids, 511)
322 | augmentation_start_ids = min(augmentation_start_ids, 511)
323 | augmentation_end_ids = min(augmentation_end_ids, 511)
324 |
325 | features.append(
326 | InputFeatures(input_ids=input_ids,
327 | input_mask=input_mask,
328 | segment_ids=segment_ids,
329 | label_id=label_id,
330 | extraction_mask=extraction_mask,
331 | extraction_start_ids=extraction_start_ids,
332 | extraction_end_ids=extraction_end_ids,
333 | augmentation_mask=augmentation_mask,
334 | augmentation_start_ids=augmentation_start_ids,
335 | augmentation_end_ids=augmentation_end_ids))
336 | return features
337 |
338 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
339 | """Truncates a sequence pair in place to the maximum length."""
340 |
341 | # This is a simple heuristic which will always truncate the longer sequence
342 | # one token at a time. This makes more sense than truncating an equal percent
343 | # of tokens from each, since if one sequence is very short then each token
344 | # that's truncated likely contains more information than a longer sequence.
345 | while True:
346 | total_length = len(tokens_a) + len(tokens_b)
347 | if total_length <= max_length:
348 | break
349 | if len(tokens_a) > len(tokens_b):
350 | tokens_a.pop()
351 | else:
352 | tokens_b.pop()
353 |
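# Worked example (illustrative): with max_length=8, len(tokens_a)=6 and
# len(tokens_b)=4, the loop pops twice from tokens_a (the longer side each
# time), leaving 4 + 4 = 8 tokens in total.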
354 |
355 | def simple_accuracy(preds, labels):
356 | return (preds == labels).mean()
357 |
358 |
359 | def acc_and_f1(preds, labels):
360 | acc = simple_accuracy(preds, labels)
361 | f1 = f1_score(y_true=labels, y_pred=preds)
362 | return {
363 | "acc": acc,
364 | "f1": f1,
365 | "acc_and_f1": (acc + f1) / 2,
366 | }
367 |
368 |
369 | def pearson_and_spearman(preds, labels):
370 | pearson_corr = pearsonr(preds, labels)[0]
371 | spearman_corr = spearmanr(preds, labels)[0]
372 | return {
373 | "pearson": pearson_corr,
374 | "spearmanr": spearman_corr,
375 | "corr": (pearson_corr + spearman_corr) / 2,
376 | }
377 |
378 |
379 | def complex_metric(preds, labels, prefix=""):
380 | return {
381 | prefix + "bacc": balanced_accuracy_score(y_true=labels, y_pred=preds),
382 | prefix + "f1": f1_score(y_true=labels, y_pred=preds, average="micro")
383 | }
384 |
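# For example (illustrative): complex_metric(preds=[0, 1, 1], labels=[0, 1, 0],
# prefix="dev_") returns {"dev_bacc": 0.75, "dev_f1": 0.667}: balanced accuracy
# averages per-class recall (0.5 and 1.0), while micro-F1 equals plain accuracy (2/3).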
385 |
386 | def compute_metrics(task_name, preds, labels, prefix=""):
387 | assert len(preds) == len(labels)
388 | if task_name == "factcc_generated":
389 | return complex_metric(preds, labels, prefix)
390 | elif task_name == "factcc_annotated":
391 | return complex_metric(preds, labels, prefix)
392 | else:
393 | raise KeyError(task_name)
394 |
395 | processors = {
396 | "factcc_generated": FactCCGeneratedProcessor,
397 | "factcc_annotated": FactCCManualProcessor,
398 | }
399 |
400 | output_modes = {
401 | "factcc_generated": "classification",
402 | "factcc_annotated": "classification",
403 | }
404 |
405 | GLUE_TASKS_NUM_LABELS = {
406 | "factcc_generated": 2,
407 | "factcc_annotated": 2,
408 | }
409 |
--------------------------------------------------------------------------------
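The processors and convert_examples_to_features above can be exercised end to end. A minimal, hypothetical driver (the data path is a placeholder, the driver itself is not part of the repository, and the tokenizer attributes mirror how run.py calls this function under the pinned pytorch_transformers==1.0.0):

from pytorch_transformers import BertTokenizer

from utils import FactCCGeneratedProcessor, convert_examples_to_features

processor = FactCCGeneratedProcessor()
examples = processor.get_dev_examples("/path/to/data")  # reads data-dev.jsonl
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
features = convert_examples_to_features(
    examples, processor.get_labels(), 512, tokenizer, "classification",
    cls_token=tokenizer.cls_token, cls_token_segment_id=0,
    sep_token=tokenizer.sep_token,
    pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0])
print(len(features), features[0].input_ids[:10])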
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch_transformers==1.0.0
2 | scipy==1.3.0
3 | spacy==2.2.3
4 | torch==1.3.1
5 | numpy==1.17.4
6 | tqdm==4.32.1
7 | apex==0.9.10dev  # note: the "apex" package on PyPI is unrelated to NVIDIA Apex; for --fp16, install Apex from https://github.com/NVIDIA/apex
8 | protobuf==3.11.2
9 | scikit_learn==0.22.1
10 | google-cloud-translate==1.6.0
11 |
--------------------------------------------------------------------------------