├── .gitignore
├── K2019
│   ├── FactCC.ipynb
│   ├── README.me
│   ├── process_cnndm.py
│   └── test
│       └── data.json
├── README.md
├── build_dataset
│   ├── augmentation_ops.py
│   ├── create_data.py
│   ├── create_data.sh
│   └── create_data_cc.sh
└── model
    ├── data_utils.py
    ├── main.py
    ├── main_data.py
    ├── model.ipynb
    ├── model.py
    ├── run.py
    ├── test.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | **/.vscode/
2 | **/.ipynb_checkpoints/
3 | **/__pycache__/
4 | **/results/
5 | **/cnn_dm-bin/
6 | **/bart.large.cnn/
7 | **/bart.large/
8 | **/nohup.out
9 | **/*.hypo
10 | **/*.tokenized
11 | **/*.zip
12 | *.bin
13 | *.tar.gz
14 | *.jsonl
15 | *.source
16 | **/cnn-dailymail/
--------------------------------------------------------------------------------
/K2019/FactCC.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": 2,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "import json\n",
10 |     "from tqdm import tqdm\n",
11 |     "\n",
12 |     "from process_cnndm import get_art_abs"
13 |    ]
14 |   },
15 |   {
16 |    "cell_type": "code",
17 |    "execution_count": 3,
18 |    "metadata": {},
19 |    "outputs": [],
20 |    "source": [
21 |     "val_or_test = 'test'\n",
22 |     "test_path = '{}/data-dev.jsonl'.format(val_or_test)\n",
23 |     "output_path = '{}/test.source'.format(val_or_test)"
24 |    ]
25 |   },
26 |   {
27 |    "cell_type": "code",
28 |    "execution_count": 4,
29 |    "metadata": {},
30 |    "outputs": [],
31 |    "source": [
32 |     "factcc_test = []\n",
33 |     "with open(test_path, 'r', encoding='utf-8') as f:\n",
34 |     "    for line in f:\n",
35 |     "        line = line.strip()\n",
36 |     "        factcc_test.append(json.loads(line))"
37 |    ]
38 |   },
39 |   {
40 |    "cell_type": "code",
41 |    "execution_count": 5,
42 |    "metadata": {},
43 |    "outputs": [
44 |     {
45 |      "name": "stdout",
46 |      "output_type": "stream",
47 |      "text": [
48 |       "503\n",
49 |       "{'claim': 'it has already been viewed more than 1 million times.', 'label': 'CORRECT', 'filepath': 'cnndm/cnn/stories/dbce61d253b9e770529817b484aeb8b0cca76a73.story', 'id': 'cnn-test-dbce61d253b9e770529817b484aeb8b0cca76a73'}\n"
50 |      ]
51 |     }
52 |    ],
53 |    "source": [
54 |     "print(len(factcc_test))\n",
55 |     "print(factcc_test[111])"
56 |    ]
57 |   },
58 |   {
59 |    "cell_type": "code",
60 |    "execution_count": 6,
61 |    "metadata": {},
62 |    "outputs": [],
63 |    "source": [
64 |     "def get_type_and_id(origin_path):\n",
65 |     "    first_slash = origin_path.find(\"/\")\n",
66 |     "    second_slash = origin_path.find(\"/\", first_slash + 1)\n",
67 |     "    last_slash = origin_path.rfind(\"/\")\n",
68 |     "    data_type = origin_path[first_slash + 1: second_slash]\n",
69 |     "    assert data_type in ['cnn', 'dm']\n",
70 |     "    data_id = origin_path[last_slash + 1:]\n",
71 |     "    \n",
72 |     "    return data_type, data_id"
73 |    ]
74 |   },
75 |   {
76 |    "cell_type": "code",
77 |    "execution_count": 13,
78 |    "metadata": {},
79 |    "outputs": [],
80 |    "source": [
81 |     "def compose_path(data_type, data_id):\n",
82 |     "    return \"/home/ml/cadencao/cnn-dailymail/cnn/stories/{}\".format(data_id)"
83 |    ]
84 |   },
85 |   {
86 |    "cell_type": "code",
87 |    "execution_count": 14,
88 |    "metadata": {},
89 |    "outputs": [
90 |     {
91 |      "name": "stdout",
92 |      "output_type": "stream",
93 |      "text": [
94 |       "- claim:\n",
95 |       "georgia southern university was in mourning after five nursing students died.\n",
96 |       "- story:\n",
97 |       "(CNN)Georgia Southern University was in mourning Thursday after five nursing students were killed the day before in a 
multivehicle wreck near Savannah. Caitlyn Baggett, Morgan Bass, Emily Clark, Abbie Deloach and Catherine (McKay) Pittman -- all juniors -- were killed in the Wednesday morning crash as they were traveling to a hospital in Savannah, according to the school website. Fellow nursing students Brittney McDaniel and Megan Richards were injured as was another person, who was not identified by the Georgia State Patrol. The young women were on their way to finish their first set of clinical rotations. \"Today should have been a day of celebration for this bright group of students,\" at St. Joseph's/Candler hospital said in a Facebook posting. \"It was their last day of clinical rotations ... in their first year of nursing school.\" Clinicals include hands-on instruction at a health care facility. A post commander for the Georgia State Patrol said a tractor-trailer smashed into an eastbound line of cars that had slowed for a prior accident on Interstate 16. \"He came along from behind them and he just did not stop for those cars,\" Sgt. Chris Nease said. There were four passenger vehicles and three tractor-trailers involved in the 5:45 a.m. accident. The women who were killed were in two cars, a Toyota Corolla and a Ford Escape. One of their vehicles caught on fire, Nease said, but it will take an investigation to determine whether the women died on impact. CNN Savannah affiliate WTOC reported one witness tried to help. \"Right about the time I got here, the car was just about catching on fire,\" Cayne Monroe told the station. \"The car just burned up really quickly. And I run up there, but there was nothing anyone could do. I've never witnessed something like that in my life. It was pretty tragic.\" The state patrol said the truck driver is from Louisiana. The 55-year-old man had not been charged as of Thursday evening, Nease told CNN. \"Every one of our students contributes in no small measure to the Eagle Nation,\" university President Brooks A. Keel said in a statement. \"The loss of any student, especially in a tragic way, is particularly painful. Losing five students is almost incomprehensible.\" Georgia Southern flew flags at half-staff and counseling was offered to students. A campuswide vigil was held Thursday night. On the university's Twitter page, a tear was added to the profile logo of the eagle mascot. The school has a student body of about 20,000 and is in Statesboro, about 60 miles from Savannah. \"You could tell that they really loved what they did,\" Sherry Danello, vice president of patient care services and chief nursing officer at St. Joseph's/Candler, said on the hospital's Facebook posting. 
\"They didn't just go through the task, they really connected to the patients.\" Luke Bryan, a country music star and school alumnus, tweeted his condolences: \"Praying for everyone at Georgia Southern and the families who lost loved ones.\" CNN's Matthew Stucker contributed to this report.\n" 98 | ] 99 | } 100 | ], 101 | "source": [ 102 | "with open(output_path, 'w', encoding='utf-8') as fout:\n", 103 | " for i, ft in enumerate(factcc_test):\n", 104 | " new_file_path = compose_path(*get_type_and_id(ft['filepath']))\n", 105 | " story, _ = get_art_abs(new_file_path)\n", 106 | " fout.write(ft['claim'] + ' ' + story + '\\n')\n", 107 | " \n", 108 | " # print out samples\n", 109 | " if i == 0:\n", 110 | " print('- claim:')\n", 111 | " print(ft['claim'])\n", 112 | " print('- story:')\n", 113 | " print(story)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [] 122 | } 123 | ], 124 | "metadata": { 125 | "kernelspec": { 126 | "display_name": "Python 3", 127 | "language": "python", 128 | "name": "python3" 129 | }, 130 | "language_info": { 131 | "codemirror_mode": { 132 | "name": "ipython", 133 | "version": 3 134 | }, 135 | "file_extension": ".py", 136 | "mimetype": "text/x-python", 137 | "name": "python", 138 | "nbconvert_exporter": "python", 139 | "pygments_lexer": "ipython3", 140 | "version": "3.7.7" 141 | } 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 4 145 | } 146 | -------------------------------------------------------------------------------- /K2019/README.me: -------------------------------------------------------------------------------- 1 | ## Files 2 | 3 | `val/data-dev.jsonl` - Manually annotated validation data, claim sentences were generated by summarization models. 4 | `test/data-dev.jsonl` - Manually annotated test data, claim sentences were generated by summarization models. 5 | 6 | 7 | ## Data Schema 8 | 9 | Each example is stored as a separate JSON object. 10 | 11 | JSON objects have the following keys: 12 | `id` - ID of article associated with the example, matches ID's from the CNN/DM Story files. NOT UNIQUE. 13 | `filepath` - Filepath to the Story file associated with the example. 14 | `claim` - Claim sentence generated by a summarization model. 
15 | `label` - Example label.
16 | 
17 | ## Citation
18 | 
19 | ```
20 | @article{kryscinskiFactCC2019,
21 |   author = {Wojciech Kry{\'s}ci{\'n}ski and Bryan McCann and Caiming Xiong and Richard Socher},
22 |   title = {Evaluating the Factual Consistency of Abstractive Text Summarization},
23 |   journal = {arXiv preprint arXiv:1910.12840},
24 |   year = {2019},
25 | }
26 | ```
27 | 
28 | ## Contact
29 | 
30 | Corresponding authors are Wojciech Kryściński (kryscinski@salesforce.com) and Bryan McCann (bmccann@salesforce.com).
31 | Issues or questions can also be posted on GitHub: https://github.com/salesforce/factCC
--------------------------------------------------------------------------------
/K2019/process_cnndm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | 
4 | dm_single_close_quote = u'\u2019' # unicode
5 | dm_double_close_quote = u'\u201d'
6 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence
7 | 
8 | def read_text_file(text_file):
9 |     lines = []
10 |     with open(text_file, "r") as f:
11 |         for line in f:
12 |             lines.append(line.strip())
13 |     return lines
14 | 
15 | 
16 | def fix_missing_period(line):
17 |     """Adds a period to a line that is missing a period"""
18 |     if "@highlight" in line: return line
19 |     if line == "": return line
20 |     if line[-1] in END_TOKENS: return line
21 |     return line + " ."
22 | 
23 | 
24 | def get_art_abs(story_file):
25 |     lines = read_text_file(story_file)
26 | 
27 |     # Put periods on the ends of lines that are missing them
28 |     lines = [fix_missing_period(line) for line in lines]
29 | 
30 |     # Separate out article and abstract sentences
31 |     article_lines, highlights = [], []
32 | 
33 |     next_is_highlight = False
34 |     for idx, line in enumerate(lines):
35 |         if line == "":
36 |             continue  # empty line
37 |         elif line.startswith("@highlight"):
38 |             next_is_highlight = True
39 |         elif next_is_highlight:
40 |             highlights.append(line)
41 |         else:
42 |             article_lines.append(line)
43 | 
44 |     # Make article into a single string
45 |     article = ' '.join(article_lines)
46 | 
47 |     # Make abstract into a single string
48 |     abstract = ' '.join(highlights)
49 | 
50 |     return article, abstract
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Factual Error Correction for Abstractive Summarization Models
2 | 
3 | This directory contains code necessary to replicate the training and evaluation for the EMNLP 2020 paper ["Factual Error Correction for Abstractive Summarization Models"](https://arxiv.org/abs/2010.08712) by Meng Cao, Yue Dong, Jiapeng Wu and Jackie Chi Kit Cheung.
4 | 
5 | ## Directory Structure
6 | 
7 | Our code is organized into four subdirectories:
8 | 
9 | * `build_dataset`: code for building the artificial training & test dataset.
10 | * `cnn-dailymail`: directory for the cnn-dailymail summarization dataset.
11 | * `K2019`: directory for the manually annotated dataset by Kryscinski et al. (2019).
12 | * `model`: wrapper for the fairseq BART model for training.
13 | 
14 | ## (1) Build Dataset
15 | To build the training dataset, first download the processed cnn-dailymail dataset from [this link](https://drive.google.com/file/d/1uqONBkA_5rTd9CA8j5_2BeBeZqKaWTht/view?usp=sharing). Unzip and save the downloaded files in `cnn-dailymail`.
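Each line of the `*.jsonl` files produced by the data creation script below is one JSON object following the schema built by `make_new_example` in `build_dataset/augmentation_ops.py`. The record here is only an illustrative sketch — the id, texts, and span values are invented, not taken from the actual data:

```json
{
  "id": 11,
  "text": "Full source article text ...",
  "summary": "The reference summary sentence.",
  "claim": "The corrupted summary sentence with a swapped entity.",
  "label": "INCORRECT",
  "extraction_span": null,
  "backtranslation": false,
  "augmentation": "EntitySwap",
  "augmentation_span": [5, 6],
  "noise": false
}
```

`label` is `"CORRECT"` for clean or backtranslated pairs and `"INCORRECT"` for corrupted ones, and `augmentation` names the transformation class that generated the example.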
16 | 
17 | Then, run the data creation bash script to build the training data:
18 | 
19 | ```
20 | cd build_dataset
21 | sh create_data.sh
22 | ```
23 | 
24 | ## (2) Model Training
25 | We use [BART](https://arxiv.org/abs/1910.13461) as our base model. To download and use the BART model, follow the instructions [here](https://github.com/pytorch/fairseq/tree/master/examples/bart).
26 | 
27 | ## (3) K2019
28 | The annotated cnn-dailymail test set from [Kryscinski et al. (2019)](https://arxiv.org/pdf/1910.12840.pdf).
--------------------------------------------------------------------------------
/build_dataset/augmentation_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | Data augmentation (transformations) operations used to generate
3 | synthetic training data for the `FactCC` and `FactCCX` models.
4 | """
5 | 
6 | import random
7 | 
8 | import spacy
9 | 
10 | from google.cloud import translate_v2 as translate
11 | 
12 | 
13 | LABEL_MAP = {True: "CORRECT", False: "INCORRECT"}
14 | 
15 | 
16 | def align_ws(old_token, new_token):
17 |     # Align trailing whitespaces between tokens
18 |     if old_token[-1] == new_token[-1] == " ":
19 |         return new_token
20 |     elif old_token[-1] == " ":
21 |         return new_token + " "
22 |     elif new_token[-1] == " ":
23 |         return new_token[:-1]
24 |     else:
25 |         return new_token
26 | 
27 | 
28 | def make_new_example(eid=None, text=None, summary=None, claim=None, label=None, extraction_span=None,
29 |                      backtranslation=None, augmentation=None, augmentation_span=None, noise=None):
30 |     # Embed example information in a json object.
31 |     return {
32 |         "id": eid,
33 |         "text": text,
34 |         "summary": summary,
35 |         "claim": claim,
36 |         "label": label,
37 |         "extraction_span": extraction_span,
38 |         "backtranslation": backtranslation,
39 |         "augmentation": augmentation,
40 |         "augmentation_span": augmentation_span,
41 |         "noise": noise
42 |     }
43 | 
44 | 
45 | class Transformation():
46 |     # Base class for all data transformations
47 | 
48 |     def __init__(self):
49 |         # Spacy toolkit used for all NLP-related substeps
50 |         self.spacy = spacy.load("en")
51 | 
52 |     def transform(self, example):
53 |         # Function applies transformation on passed example
54 |         pass
55 | 
56 | 
57 | class FormatTransformation(Transformation):
58 |     # add new keys to the example
59 |     def __init__(self):
60 |         super().__init__()
61 | 
62 |     def transform(self, example):
63 |         page_doc = self.spacy(example["text"], disable=["tagger"])
64 |         claim, summary = self.spacy(example["summary"]), self.spacy(example["summary"])  # claim starts out identical to the summary
65 |         new_example = make_new_example(eid=example["id"],
66 |                                        text=page_doc,
67 |                                        summary=summary,
68 |                                        claim=claim,
69 |                                        label=LABEL_MAP[True],
70 |                                        backtranslation=False, noise=False)
71 |         return new_example
72 | 
73 | 
74 | class SampleSentences(Transformation):
75 |     # Embed document as Spacy object and sample one sentence as claim
76 |     def __init__(self, min_sent_len=8):
77 |         super().__init__()
78 |         self.min_sent_len = min_sent_len
79 | 
80 |     def transform(self, example):
81 |         assert example["text"] is not None, "Text must be available"
82 | 
83 |         # split into sentences
84 |         page_id = example["id"]
85 |         page_text = example["text"].replace("\n", " ")
86 |         page_doc = self.spacy(page_text, disable=["tagger"])
87 |         sents = [sent for sent in page_doc.sents if len(sent) >= self.min_sent_len]
88 | 
89 |         # sample claim
90 |         claim = random.choice(sents)
91 |         new_example = make_new_example(eid=page_id, text=page_doc,
92 |                                        claim=self.spacy(claim.text),
93 |                                        label=LABEL_MAP[True],
94 |                                        extraction_span=(claim.start, claim.end-1),
95 |                                        backtranslation=False, noise=False)
96 |         return new_example
97 | 
98 | 
99 | class NegateSentences(Transformation):
100 |     # Apply or remove negation from negatable tokens
101 |     def __init__(self):
102 |         super().__init__()
103 |         self.__negatable_tokens = ("are", "is", "was", "were", "have", "has", "had",
104 |                                    "do", "does", "did", "can", "ca", "could", "may",
105 |                                    "might", "must", "shall", "should", "will", "would")
106 | 
107 |     def transform(self, example):
108 |         assert example["text"] is not None, "Text must be available"
109 |         assert example["claim"] is not None, "Claim must be available"
110 | 
111 |         new_example = dict(example)
112 |         new_claim, aug_span = self.__negate_sentences(new_example["claim"])
113 | 
114 |         if new_claim:
115 |             new_example["claim"] = new_claim
116 |             new_example["label"] = LABEL_MAP[False]
117 |             new_example["augmentation"] = self.__class__.__name__
118 |             new_example["augmentation_span"] = aug_span
119 |             return new_example
120 |         else:
121 |             return None
122 | 
123 |     def __negate_sentences(self, claim):
124 |         # find negatable token, return None if no candidates found
125 |         candidate_tokens = [token for token in claim if token.text in self.__negatable_tokens]
126 | 
127 |         if not candidate_tokens:
128 |             return None, None
129 | 
130 |         # choose random token to negate
131 |         negated_token = random.choice(candidate_tokens)
132 |         negated_ix = negated_token.i
133 |         doc_len = len(claim)
134 | 
135 |         if negated_ix > 0:
136 |             if claim[negated_ix - 1].text in self.__negatable_tokens:
137 |                 negated_token = claim[negated_ix - 1]
138 |                 negated_ix = negated_ix - 1
139 | 
140 |         # check whether token is negative
141 |         is_negative = False
142 |         if (doc_len - 1) > negated_ix:
143 |             if claim[negated_ix + 1].text in ["not", "n't"]:
144 |                 is_negative = True
145 |             elif claim[negated_ix + 1].text == "no":
146 |                 return None, None
147 | 
148 | 
149 |         # negate token
150 |         claim_tokens = [token.text_with_ws for token in claim]
151 |         if is_negative:
152 |             if claim[negated_ix + 1].text.lower() == "n't":
153 |                 if claim[negated_ix].text.lower() == "ca":  # check the auxiliary itself so "ca n't" restores to "can" (was claim[negated_ix + 1], which can never equal "ca" inside this branch)
154 |                     claim_tokens[negated_ix] = "can" if claim_tokens[negated_ix].islower() else "Can"
155 |                 claim_tokens[negated_ix] = claim_tokens[negated_ix] + " "
156 |                 claim_tokens.pop(negated_ix + 1)
157 |         else:
158 |             if claim[negated_ix].text.lower() in ["am", "may", "might", "must", "shall", "will"]:
159 |                 negation = "not "
160 |             else:
161 |                 negation = random.choice(["not ", "n't "])
162 | 
163 |             if negation == "n't ":
164 |                 if claim[negated_ix].text.lower() == "can":
165 |                     claim_tokens[negated_ix] = "ca" if claim_tokens[negated_ix].islower() else "Ca"
166 |                 else:
167 |                     claim_tokens[negated_ix] = claim_tokens[negated_ix][:-1]
168 |             claim_tokens.insert(negated_ix + 1, negation)
169 | 
170 |         # create new claim object
171 |         new_claim = self.spacy("".join(claim_tokens))
172 |         augmentation_span = (negated_ix, negated_ix if is_negative else negated_ix + 1)
173 | 
174 |         if new_claim.text == claim.text:
175 |             return None, None
176 |         else:
177 |             return new_claim, augmentation_span
178 | 
179 | 
180 | class Backtranslation(Transformation):
181 |     # Paraphrase sentence via backtranslation with Google Translate API
182 |     # Requires API Key for Google Cloud SDK, additional charges DO apply
183 |     def __init__(self, dst_lang=None):
184 |         super().__init__()
185 | 
186 |         self.src_lang = "en"
187 |         self.dst_lang = dst_lang
188 |         self.accepted_langs = ["fr", "de", "zh-TW", "es", "ru"]
189 |         self.translator = translate.Client()
190 | 
191 |     def 
transform(self, example): 192 | assert example["text"] is not None, "Text must be available" 193 | assert example["claim"] is not None, "Claim must be available" 194 | 195 | new_example = dict(example) 196 | new_claim, _ = self.__backtranslate(new_example["claim"]) 197 | 198 | if new_claim: 199 | new_example["claim"] = new_claim 200 | new_example["backtranslation"] = True 201 | return new_example 202 | else: 203 | return None 204 | 205 | def __backtranslate(self, claim): 206 | # chose destination language, passed or random from list 207 | dst_lang = self.dst_lang if self.dst_lang else random.choice(self.accepted_langs) 208 | 209 | # translate to intermediate language and back 210 | claim_trans = self.translator.translate(claim.text, target_language=dst_lang, format_="text") 211 | claim_btrans = self.translator.translate(claim_trans["translatedText"], target_language=self.src_lang, format_="text") 212 | 213 | # create new claim object 214 | new_claim = self.spacy(claim_btrans["translatedText"]) 215 | augmentation_span = (new_claim[0].i, new_claim[-1].i) 216 | 217 | if claim.text == new_claim.text: 218 | return None, None 219 | else: 220 | return new_claim, augmentation_span 221 | 222 | 223 | class PronounSwap(Transformation): 224 | # Swap randomly chosen pronoun 225 | def __init__(self, prob_swap=0.5): 226 | super().__init__() 227 | 228 | self.class2pronoun_map = { 229 | "SUBJECT": ["you", "he", "she", "we", "they"], 230 | "OBJECT": ["me", "you", "him", "her", "us", "them"], 231 | "POSSESSIVE": ["my", "your", "his", "her", "its", "our", "your", "their"], 232 | "REFLEXIVE": ["myself", "yourself", "himself", "itself", "ourselves", "yourselves", "themselves"] 233 | } 234 | 235 | self.pronoun2class_map = {pronoun: key for (key, values) in self.class2pronoun_map.items() for pronoun in values} 236 | self.pronouns = {pronoun for (key, values) in self.class2pronoun_map.items() for pronoun in values} 237 | 238 | def transform(self, example): 239 | assert example["text"] is not None, "Text must be available" 240 | assert example["claim"] is not None, "Claim must be available" 241 | 242 | new_example = dict(example) 243 | new_claim, aug_span = self.__swap_pronouns(new_example["claim"]) 244 | 245 | if new_claim: 246 | new_example["claim"] = new_claim 247 | new_example["label"] = LABEL_MAP[False] 248 | new_example["augmentation"] = self.__class__.__name__ 249 | new_example["augmentation_span"] = aug_span 250 | return new_example 251 | else: 252 | return None 253 | 254 | def __swap_pronouns(self, claim): 255 | # find pronouns 256 | claim_pronouns = [token for token in claim if token.text.lower() in self.pronouns] 257 | 258 | if not claim_pronouns: 259 | return None, None 260 | 261 | # find pronoun replacement 262 | chosen_token = random.choice(claim_pronouns) 263 | chosen_ix = chosen_token.i 264 | chosen_class = self.pronoun2class_map[chosen_token.text.lower()] 265 | 266 | candidate_tokens = [token for token in self.class2pronoun_map[chosen_class] if token != chosen_token.text.lower()] 267 | 268 | if not candidate_tokens: 269 | return None, None 270 | 271 | # swap pronoun and update indices 272 | swapped_token = random.choice(candidate_tokens) 273 | swapped_token = align_ws(chosen_token.text_with_ws, swapped_token) 274 | swapped_token = swapped_token if chosen_token.text.islower() else swapped_token.capitalize() 275 | 276 | claim_tokens = [token.text_with_ws for token in claim] 277 | claim_tokens[chosen_ix] = swapped_token 278 | 279 | # create new claim object 280 | new_claim = 
self.spacy("".join(claim_tokens)) 281 | augmentation_span = (chosen_ix, chosen_ix) 282 | 283 | if claim.text == new_claim.text: 284 | return None, None 285 | else: 286 | return new_claim, augmentation_span 287 | 288 | 289 | class NERSwap(Transformation): 290 | # Swap NER objects - parent class 291 | def __init__(self): 292 | super().__init__() 293 | self.categories = () 294 | 295 | def transform(self, example): 296 | assert example["text"] is not None, "Text must be available" 297 | assert example["claim"] is not None, "Claim must be available" 298 | 299 | new_example = dict(example) 300 | new_claim, aug_span = self.__swap_entities(new_example["text"], new_example["claim"]) 301 | 302 | if new_claim: 303 | new_example["claim"] = new_claim 304 | new_example["label"] = LABEL_MAP[False] 305 | new_example["augmentation"] = self.__class__.__name__ 306 | new_example["augmentation_span"] = aug_span 307 | return new_example 308 | else: 309 | return None 310 | 311 | def __swap_entities(self, text, claim): 312 | # find entities in given category 313 | text_ents = [ent for ent in text.ents if ent.label_ in self.categories] 314 | claim_ents = [ent for ent in claim.ents if ent.label_ in self.categories] 315 | 316 | if not claim_ents or not text_ents: 317 | return None, None 318 | 319 | # choose entity to replace and find possible replacement in source 320 | replaced_ent = random.choice(claim_ents) 321 | candidate_ents = [ent for ent in text_ents if ent.text != replaced_ent.text and ent.text not in replaced_ent.text and replaced_ent.text not in ent.text] 322 | 323 | if not candidate_ents: 324 | return None, None 325 | 326 | # update claim and indices 327 | swapped_ent = random.choice(candidate_ents) 328 | claim_tokens = [token.text_with_ws for token in claim] 329 | swapped_token = align_ws(replaced_ent.text_with_ws, swapped_ent.text_with_ws) 330 | claim_swapped = claim_tokens[:replaced_ent.start] + [swapped_token] + claim_tokens[replaced_ent.end:] 331 | 332 | # create new claim object 333 | new_claim = self.spacy("".join(claim_swapped)) 334 | augmentation_span = (replaced_ent.start, replaced_ent.start + len(swapped_ent) - 1) 335 | 336 | if new_claim.text == claim.text: 337 | return None, None 338 | else: 339 | return new_claim, augmentation_span 340 | 341 | 342 | class EntitySwap(NERSwap): 343 | # NER swapping class specialized for entities (people, companies, locations, etc.) 
344 | def __init__(self): 345 | super().__init__() 346 | self.categories = ("PERSON", "ORG", "NORP", "FAC", "GPE", "LOC", "PRODUCT", 347 | "WORK_OF_ART", "EVENT") 348 | 349 | 350 | class NumberSwap(NERSwap): 351 | # NER swapping class specialized for numbers (excluding dates) 352 | def __init__(self): 353 | super().__init__() 354 | 355 | self.categories = ("PERCENT", "MONEY", "QUANTITY", "CARDINAL") 356 | 357 | 358 | class DateSwap(NERSwap): 359 | # NER swapping class specialized for dates and time 360 | def __init__(self): 361 | super().__init__() 362 | 363 | self.categories = ("DATE", "TIME") 364 | 365 | 366 | class AddNoise(Transformation): 367 | # Inject noise into claims 368 | def __init__(self, noise_prob=0.05, delete_prob=0.8): 369 | super().__init__() 370 | self.noise_prob = noise_prob 371 | self.delete_prob = delete_prob 372 | self.spacy = spacy.load("en") 373 | 374 | def transform(self, example): 375 | assert example["text"] is not None, "Text must be available" 376 | assert example["claim"] is not None, "Claim must be available" 377 | 378 | new_example = dict(example) 379 | claim = new_example["claim"] 380 | aug_span = new_example["augmentation_span"] 381 | new_claim, aug_span = self.__add_noise(claim, aug_span) 382 | 383 | if new_claim: 384 | new_example["claim"] = new_claim 385 | new_example["augmentation_span"] = aug_span 386 | new_example["noise"] = True 387 | return new_example 388 | else: 389 | return None 390 | 391 | def __add_noise(self, claim, aug_span): 392 | claim_tokens = [token.text_with_ws for token in claim] 393 | 394 | new_claim = [] 395 | for ix, token in enumerate(claim_tokens): 396 | # don't modify text inside an augmented span 397 | apply_augmentation = True 398 | if aug_span: 399 | span_start, span_end = aug_span 400 | if span_start <= ix <= span_end: 401 | apply_augmentation = False 402 | 403 | # decide whether to add noise 404 | if apply_augmentation and random.random() < self.noise_prob: 405 | # decide whether to replicate or delete token 406 | if random.random() < self.delete_prob: 407 | # update spans and skip token 408 | if aug_span: 409 | span_start, span_end = aug_span 410 | if ix < span_start: 411 | span_start -= 1 412 | span_end -= 1 413 | aug_span = span_start, span_end 414 | if len(new_claim) > 0: 415 | if new_claim[-1][-1] != " ": 416 | new_claim[-1] = new_claim[-1] + " " 417 | continue 418 | else: # duplicate token 419 | if aug_span: 420 | span_start, span_end = aug_span 421 | if ix < span_start: 422 | span_start += 1 423 | span_end += 1 424 | aug_span = span_start, span_end 425 | new_claim.append(token) 426 | new_claim.append(token) 427 | new_claim = self.spacy("".join(new_claim)) 428 | 429 | if claim.text == new_claim.text: 430 | return None, None 431 | else: 432 | return new_claim, aug_span 433 | -------------------------------------------------------------------------------- /build_dataset/create_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for generating synthetic data for FactCC training. 3 | 4 | Script expects source documents in `jsonl` format with each source document 5 | embedded in a separate json object. 6 | 7 | Json objects are required to contain `id` and `text` keys. 8 | """ 9 | 10 | import argparse 11 | import json 12 | import os 13 | import pickle 14 | 15 | from tqdm import tqdm 16 | 17 | import augmentation_ops as ops 18 | 19 | 20 | def read_fairseq_files(source_path, target_path): 21 | """Read fairseq documents (source) and summaries (target). 
22 | """ 23 | data = [] 24 | with open(source_path, 'r', encoding='utf-8') as source, \ 25 | open(target_path, 'r', encoding='utf-8') as target: 26 | for i, (s, t) in enumerate(zip(source, target)): 27 | s, t = s.strip(), t.strip() 28 | data.append({ 29 | 'id': i, 'text': s, 'summary': t 30 | }) 31 | return data 32 | 33 | 34 | def load_source_docs(file_path, to_dict=False): 35 | with open(file_path, 'r', encoding="utf-8") as f: 36 | data = [json.loads(line) for line in f] 37 | 38 | if to_dict: 39 | data = {example["id"]: example for example in data} 40 | return data 41 | 42 | 43 | def load_pickle(file_path): 44 | return pickle.load(open(file_path, 'rb')) 45 | 46 | 47 | def save_data(args, data, name_suffix): 48 | output_file = os.path.splitext(args.source_file)[0] + "-" + name_suffix + ".jsonl" 49 | 50 | with open(output_file, "w", encoding="utf-8") as fd: 51 | for example in data: 52 | example = dict(example) 53 | example["text"] = example["text"].text 54 | example["summary"] = example["summary"].text 55 | example["claim"] = example["claim"].text 56 | fd.write(json.dumps(example, ensure_ascii=False) + "\n") 57 | 58 | 59 | def apply_transformation(data, operation): 60 | new_data = [] 61 | for example in tqdm(data): 62 | try: 63 | new_example = operation.transform(example) 64 | if new_example: 65 | new_data.append(new_example) 66 | except Exception as e: 67 | print("Caught exception:", e) 68 | return new_data 69 | 70 | 71 | def main(args): 72 | # load data 73 | # source_docs = load_source_docs(args.data_file, to_dict=False) 74 | # print("Loaded %d source documents." % len(source_docs)) 75 | 76 | # # create or load positive examples 77 | # print("Creating data examples") 78 | # sclaims_op = ops.SampleSentences() 79 | # data = apply_transformation(source_docs, sclaims_op) 80 | # {'id', 'text', 'claim', 'label': 'CORRECT', 'extraction_span': None, 81 | # 'backtranslation': False, 'augmentation': None, 'augmentation_span': None, 'noise': False} 82 | # print("Created %s example pairs." % len(data)) 83 | 84 | # print("Augmentations: {}".format(args.augmentations)) 85 | # # load data 86 | # source_docs = load_source_docs(args.data_file, to_dict=False) 87 | # print("Loaded %d source documents & summaries." % len(source_docs)) 88 | 89 | # load data 90 | source_docs = read_fairseq_files(args.source_file, args.target_file) 91 | print("Loaded %d source documents & summaries." % len(source_docs)) 92 | 93 | # create or load positive examples 94 | print("Creating data examples") 95 | format_op = ops.FormatTransformation() 96 | data = apply_transformation(source_docs, format_op) 97 | print("Created %s example pairs." % len(data)) 98 | 99 | if args.save_intermediate: 100 | save_data(args, data, "clean") 101 | 102 | # backtranslate 103 | data_btrans = [] 104 | if not args.augmentations or "backtranslation" in args.augmentations: 105 | print("Creating backtranslation examples") 106 | btrans_op = ops.Backtranslation() 107 | data_btrans = apply_transformation(data, btrans_op) 108 | print("Backtranslated %s example pairs." % len(data_btrans)) 109 | 110 | if args.save_intermediate: 111 | save_data(args, data_btrans, "btrans") 112 | 113 | data_positive = data_btrans 114 | else: 115 | data_positive = data 116 | 117 | # save original & translation data 118 | # data_positive = data + data_btrans 119 | # if args.save_ensemble: 120 | # print("- Positive %s example pairs." 
% len(data_positive))
121 |     #     save_data(args, data_positive, "positive")
122 | 
123 |     # create negative examples
124 |     data_pronoun = []
125 |     if not args.augmentations or "pronoun_swap" in args.augmentations:
126 |         print("Creating pronoun examples")
127 |         pronoun_op = ops.PronounSwap()
128 |         data_pronoun = apply_transformation(data_positive, pronoun_op)
129 |         print("PronounSwap %s example pairs." % len(data_pronoun))
130 | 
131 |         if args.save_intermediate:
132 |             save_data(args, data_pronoun, "pronoun")
133 | 
134 |     data_dateswp = []
135 |     if not args.augmentations or "date_swap" in args.augmentations:
136 |         print("Creating date swap examples")
137 |         dateswap_op = ops.DateSwap()
138 |         data_dateswp = apply_transformation(data_positive, dateswap_op)
139 |         print("DateSwap %s example pairs." % len(data_dateswp))
140 | 
141 |         if args.save_intermediate:
142 |             save_data(args, data_dateswp, "dateswp")
143 | 
144 |     data_numswp = []
145 |     if not args.augmentations or "number_swap" in args.augmentations:
146 |         print("Creating number swap examples")
147 |         numswap_op = ops.NumberSwap()
148 |         data_numswp = apply_transformation(data_positive, numswap_op)
149 |         print("NumberSwap %s example pairs." % len(data_numswp))
150 | 
151 |         if args.save_intermediate:
152 |             save_data(args, data_numswp, "numswp")
153 | 
154 |     data_entswp = []
155 |     if not args.augmentations or "entity_swap" in args.augmentations:
156 |         print("Creating entity swap examples")
157 |         entswap_op = ops.EntitySwap()
158 |         data_entswp = apply_transformation(data_positive, entswap_op)
159 |         print("EntitySwap %s example pairs." % len(data_entswp))
160 | 
161 |         if args.save_intermediate:
162 |             save_data(args, data_entswp, "entswp")
163 | 
164 |     data_negation = []
165 |     if not args.augmentations or "negation" in args.augmentations:
166 |         print("Creating negation examples")
167 |         negation_op = ops.NegateSentences()
168 |         data_negation = apply_transformation(data_positive, negation_op)
169 |         print("Negation %s example pairs." % len(data_negation))
170 | 
171 |         if args.save_intermediate:
172 |             save_data(args, data_negation, "negation")
173 | 
174 |     # combine all negative examples (defined unconditionally so the noise step below can always use it)
175 |     data_negative = data_pronoun + data_dateswp + data_numswp + data_entswp + data_negation
176 |     if args.save_ensemble:
177 |         print("- Negative %s example pairs." % len(data_negative))
178 |         save_data(args, data_negative, "negative")
179 | 
180 |     # add light noise
181 |     data_pos_low_noise, data_neg_low_noise = [], []
182 |     if not args.augmentations or "noise" in args.augmentations:
183 |         print("Adding light noise to data")
184 |         low_noise_op = ops.AddNoise()
185 | 
186 |         data_pos_low_noise = apply_transformation(data_positive, low_noise_op)
187 |         print("- PositiveNoisy %s example pairs." % len(data_pos_low_noise))
188 |         save_data(args, data_pos_low_noise, "positive-noise")
189 | 
190 |         data_neg_low_noise = apply_transformation(data_negative, low_noise_op)
191 |         print("- NegativeNoisy %s example pairs."
% len(data_neg_low_noise)) 192 | save_data(args, data_neg_low_noise, "negative-noise") 193 | 194 | 195 | if __name__ == "__main__": 196 | PARSER = argparse.ArgumentParser() 197 | 198 | # PARSER.add_argument("data_file", type=str, help="Path to file containing source documents.") 199 | PARSER.add_argument("--source_file", type=str, help="Path to file contains source documents.") 200 | PARSER.add_argument("--target_file", type=str, help="Path to file contains target summaries.") 201 | PARSER.add_argument("--augmentations", type=str, nargs="+", default=(), help="List of data augmentation applied to data.") 202 | PARSER.add_argument("--all_augmentations", action="store_true", help="Flag whether all augmentation should be applied.") 203 | PARSER.add_argument("--save_intermediate", action="store_true", help="Flag whether intermediate data from each transformation should be saved in separate files.") 204 | PARSER.add_argument("--save_ensemble", action="store_true", help="Flag whether all negative data should be saved in a single file.") 205 | ARGS = PARSER.parse_args() 206 | 207 | main(ARGS) -------------------------------------------------------------------------------- /build_dataset/create_data.sh: -------------------------------------------------------------------------------- 1 | FILE_TYPE=val 2 | 3 | python create_data.py --source_file ../cnn-dailymail/$FILE_TYPE.source --target_file ../cnn-dailymail/$FILE_TYPE.target --augmentations entity_swap pronoun_swap date_swap number_swap --save_intermediate 4 | -------------------------------------------------------------------------------- /build_dataset/create_data_cc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --account=rpp-bengioy # Yoshua pays for your job 3 | #SBATCH --cpus-per-task=6 # Ask for 6 CPUs 4 | #SBATCH --gres=gpu:1 # Ask for 1 GPU 5 | #SBATCH --mem=32G # Ask for 32 GB of RAM 6 | #SBATCH --time=3:00:00 # The job will run for 3 hours 7 | #SBATCH -o /scratch//slurm-%j.out # Write the log in $SCRATCH 8 | FILE_TYPE=train 9 | 10 | # 1. Create your environement locally 11 | module load miniconda3 12 | source activate py37 13 | 14 | # 2. Copy your dataset on the compute node 15 | # IMPORTANT: Your dataset must be compressed in one single file (zip, hdf5, ...)!!! 16 | cp $SCRATCH/summarization/cnn_dm/fairseq_files/$FILE_TYPE.source $SLURM_TMPDIR 17 | cp $SCRATCH/summarization/cnn_dm/fairseq_files/$FILE_TYPE.target $SLURM_TMPDIR 18 | 19 | # 3. Eventually unzip your dataset 20 | # unzip $SLURM_TMPDIR/ -d $SLURM_TMPDIR 21 | 22 | # 4. Launch your job, tell it to save the model in $SLURM_TMPDIR 23 | # and look for the dataset into $SLURM_TMPDIR 24 | python create_data.py --source_file $SLURM_TMPDIR/$FILE_TYPE.source --target_file $SLURM_TMPDIR/$FILE_TYPE.target --augmentations entity_swap pronoun_swap date_swap number_swap --save_intermediate 25 | 26 | # 5. Copy whatever you want to save on $SCRATCH 27 | cp $SLURM_TMPDIR/*.jsonl $SCRATCH/summarization/cnn_dm/corrupted_files 28 | -------------------------------------------------------------------------------- /model/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """Build dataset for model training and testing. 
5 | 6 | Author: Meng Cao 7 | """ 8 | import os 9 | import json 10 | import copy 11 | import torch 12 | import pickle 13 | 14 | from torch.utils.data import TensorDataset 15 | 16 | BOS_TOKEN, EOS_TOKEN = '[BOS]', '[EOS]' 17 | 18 | 19 | class InputExample(object): 20 | """ 21 | A single training/test example for simple sequence classification. 22 | 23 | Args: 24 | guid: Unique id for the example. 25 | text_a: string. The untokenized text of the first sequence. For single 26 | sequence tasks, only this sequence must be specified. 27 | text_b: (Optional) string. The untokenized text of the second sequence. 28 | Only must be specified for sequence pair tasks. 29 | label: (Optional) string. The label of the example. This should be 30 | specified for train and dev examples, but not for test examples. 31 | """ 32 | def __init__(self, guid, text_a, text_b, label=None): 33 | self.guid = guid 34 | self.text_a = text_a 35 | self.text_b = text_b 36 | self.label = label 37 | 38 | def __repr__(self): 39 | return str(self.to_json_string()) 40 | 41 | def to_dict(self): 42 | """Serializes this instance to a Python dictionary.""" 43 | output = copy.deepcopy(self.__dict__) 44 | return output 45 | 46 | def to_json_string(self): 47 | """Serializes this instance to a JSON string.""" 48 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 49 | 50 | 51 | class InputFeatures(object): 52 | """ 53 | A single set of features of data. 54 | 55 | Args: 56 | input_ids: Indices of input sequence tokens in the vocabulary. 57 | attention_mask: Mask to avoid performing attention on padding token indices. 58 | Mask values selected in ``[0, 1]``: 59 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 60 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 
61 | """ 62 | 63 | def __init__(self, input_ids, attention_mask, token_type_ids): 64 | self.input_ids = input_ids 65 | self.attention_mask = attention_mask 66 | self.token_type_ids = token_type_ids 67 | 68 | def __repr__(self): 69 | return str(self.to_json_string()) 70 | 71 | def to_dict(self): 72 | """Serializes this instance to a Python dictionary.""" 73 | output = copy.deepcopy(self.__dict__) 74 | return output 75 | 76 | def to_json_string(self): 77 | """Serializes this instance to a JSON string.""" 78 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 79 | 80 | 81 | class Processor(object): 82 | """Processor for the Leader-prize competition data set""" 83 | 84 | def get_train_examples(self, data_dir): 85 | """See base class.""" 86 | return self._create_examples( 87 | os.path.join(data_dir, "train.jsonl"), "train") 88 | 89 | def get_dev_examples(self, data_dir): 90 | """See base class.""" 91 | return self._create_examples( 92 | os.path.join(data_dir, "val.jsonl"), "dev") 93 | 94 | def get_test_examples(self, data_dir): 95 | """See base class.""" 96 | return self._create_examples( 97 | os.path.join(data_dir, "test.jsonl"), "test") 98 | 99 | def _create_examples(self, file_path, set_type): 100 | """Creates examples for the training and dev sets.""" 101 | with open(file_path, 'r', encoding="utf-8") as f: 102 | data = [json.loads(line) for line in f] 103 | 104 | examples = [] 105 | for i, d in enumerate(data): 106 | guid = d['id'] 107 | # text_a: story, text_b: noisy_summary, label: summary 108 | text_a, text_b, label = d['text'], d['claim'], d['summary'] 109 | examples.append( 110 | InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) 111 | return examples 112 | 113 | 114 | def load_and_cache_examples(args, tokenizer, evaluate=False): 115 | # Make sure only the first process in distributed training process the dataset, and the others will use the cache 116 | if args.local_rank not in [-1, 0] and not evaluate: 117 | torch.distributed.barrier() 118 | 119 | processor = Processor() 120 | # Load data features from cache or dataset file 121 | cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}'.format( 122 | 'dev' if evaluate else 'train', 123 | args.model_type, 124 | str(args.max_summary_length))) 125 | cached_guids_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_guids'.format( 126 | 'dev' if evaluate else 'train', 127 | args.model_type, 128 | str(args.max_summary_length))) 129 | 130 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 131 | args.logger.info("Loading features from cached file %s", cached_features_file) 132 | features = torch.load(cached_features_file) 133 | guids = torch.load(cached_guids_file) 134 | else: 135 | args.logger.info("Creating features from dataset file at %s", args.data_dir) 136 | 137 | examples = processor.get_dev_examples(args.data_dir) if evaluate \ 138 | else processor.get_train_examples(args.data_dir) 139 | 140 | pad_token_id = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0] 141 | # story_features, guids = convert_inputs_to_features(examples, tokenizer, 142 | # summary=False, 143 | # max_length=args.max_story_length, 144 | # pad_on_left=False, 145 | # pad_token=pad_token_id, 146 | # pad_token_segment_id=0) 147 | 148 | # summary_features, _ = convert_inputs_to_features(examples, tokenizer, 149 | # summary=True, 150 | # max_length=args.max_summary_length, 151 | # pad_on_left=False, 152 | # pad_token=pad_token_id, 153 | # pad_token_segment_id=0) 154 | 155 | text_features, guids = 
convert_text_and_summary(examples, tokenizer,
156 |                                                          pad_on_left=False,
157 |                                                          pad_token=pad_token_id,
158 |                                                          pad_token_segment_id=0)
159 | 
160 |         src_ids, tgt_ids = convert_outputs_to_features(examples, tokenizer,
161 |                                                        max_length=args.max_summary_length,
162 |                                                        pad_on_left=False,
163 |                                                        pad_token=pad_token_id,
164 |                                                        pad_token_segment_id=0)
165 |         features = [text_features, src_ids, tgt_ids]
166 | 
167 |         if args.local_rank in [-1, 0]:
168 |             args.logger.info("Saving features into cached file %s", cached_features_file)
169 |             torch.save(features, cached_features_file)
170 |             torch.save(guids, cached_guids_file)
171 | 
172 |     # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
173 |     if args.local_rank == 0 and not evaluate:
174 |         torch.distributed.barrier()
175 | 
176 |     # story_features, summary_features, src_ids, tgt_ids = features
177 |     text_features, src_ids, tgt_ids = features
178 | 
179 |     # Convert to Tensors and build dataset
180 |     text_ids = torch.tensor([f.input_ids for f in text_features], dtype=torch.long)
181 |     text_attention_mask = torch.tensor([f.attention_mask for f in text_features], dtype=torch.long)
182 |     text_token_type_ids = torch.tensor([f.token_type_ids for f in text_features], dtype=torch.long)
183 | 
184 |     # story_ids = torch.tensor([f.input_ids for f in story_features], dtype=torch.long)
185 |     # story_attention_mask = torch.tensor([f.attention_mask for f in story_features], dtype=torch.long)
186 |     # story_token_type_ids = torch.tensor([f.token_type_ids for f in story_features], dtype=torch.long)
187 | 
188 |     # summary_ids = torch.tensor([f.input_ids for f in summary_features], dtype=torch.long)
189 |     # summary_attention_mask = torch.tensor([f.attention_mask for f in summary_features], dtype=torch.long)
190 |     # summary_token_type_ids = torch.tensor([f.token_type_ids for f in summary_features], dtype=torch.long)
191 | 
192 |     src_attention_mask = torch.tensor([f.attention_mask for f in src_ids], dtype=torch.long)  # read the masks first,
193 |     src_token_type_ids = torch.tensor([f.token_type_ids for f in src_ids], dtype=torch.long)  # before src_ids
194 |     src_ids = torch.tensor([f.input_ids for f in src_ids], dtype=torch.long)                  # is overwritten here
195 |     tgt_ids = torch.tensor([f.input_ids for f in tgt_ids], dtype=torch.long)
196 | 
197 |     # dataset = TensorDataset(story_ids, story_attention_mask, story_token_type_ids,
198 |     #                         summary_ids, summary_attention_mask, summary_token_type_ids
199 |     #                         src_ids, tgt_ids, src_attention_mask, src_token_type_ids)
200 | 
201 |     dataset = TensorDataset(text_ids, text_attention_mask, text_token_type_ids,
202 |                             src_ids, tgt_ids, src_attention_mask, src_token_type_ids)
203 | 
204 |     return dataset, guids
205 | 
206 | 
207 | def tokenzie(tokenizer, sentence, max_length, add_special_tokens=True):
208 |     """Tokenize a string and return the encoding dict.
209 |     """
210 |     inputs = tokenizer.encode_plus(
211 |         sentence,
212 |         add_special_tokens=add_special_tokens,
213 |         max_length=max_length,
214 |     )
215 |     return inputs  # callers index inputs["input_ids"] / inputs["token_type_ids"]
216 | 
217 | 
218 | def convert_outputs_to_features(examples, tokenizer,
219 |                                 max_length=200,
220 |                                 pad_on_left=False,
221 |                                 pad_token=0,
222 |                                 pad_token_segment_id=0,
223 |                                 mask_padding_with_zero=True):
224 |     """
225 |     Loads a data file into a list of ``InputFeatures``
226 | 
227 |     Args:
228 |         examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples.
229 | tokenizer: Instance of a tokenizer that will tokenize the examples 230 | max_length: Maximum example length 231 | pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default) 232 | pad_token: Padding token 233 | pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet 234 | where it is 4) 235 | mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values 236 | and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for 237 | actual values) 238 | 239 | Returns: 240 | If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset`` 241 | containing the task-specific features. If the input is a list of ``InputExamples``, will return 242 | a list of task-specific ``InputFeatures`` which can be fed to the model. 243 | 244 | """ 245 | src_features, tgt_features = [], [] 246 | for (ex_index, example) in enumerate(examples): 247 | inputs = tokenzie(tokenizer, example.label, max_length - 1, False) 248 | input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"] 249 | 250 | # add [BOS] & [EOS] tokens 251 | source_ids = [tokenizer.bos_token_id] + input_ids 252 | target_ids = input_ids + [tokenizer.eos_token_id] 253 | token_type_ids = token_type_ids + [0] 254 | 255 | # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to. 256 | attention_mask = [1 if mask_padding_with_zero else 0] * len(source_ids) 257 | 258 | # Zero-pad up to the sequence length. 259 | padding_length = max_length - len(source_ids) 260 | if pad_on_left: 261 | source_ids = ([pad_token] * padding_length) + source_ids 262 | target_ids = ([pad_token] * padding_length) + target_ids 263 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 264 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids 265 | else: 266 | source_ids = source_ids + ([pad_token] * padding_length) 267 | target_ids = target_ids + ([pad_token] * padding_length) 268 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 269 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 270 | 271 | assert len(source_ids) == max_length, \ 272 | "Error with input length {} vs {}".format(len(input_ids), max_length) 273 | assert len(target_ids) == max_length, \ 274 | "Error with input length {} vs {}".format(len(input_ids), max_length) 275 | assert len(attention_mask) == max_length, \ 276 | "Error with input length {} vs {}".format(len(attention_mask), max_length) 277 | assert len(token_type_ids) == max_length, \ 278 | "Error with input length {} vs {}".format(len(token_type_ids), max_length) 279 | 280 | # if ex_index < 5: 281 | # print("*** Example ***") 282 | # print("guid: %s" % (example.guid)) 283 | # print("input_ids: %s" % " ".join([str(x) for x in input_ids])) 284 | # print("attention_mask: %s" % " ".join([str(x) for x in attention_mask])) 285 | # print("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids])) 286 | 287 | src_features.append(InputFeatures(input_ids=source_ids, attention_mask=attention_mask, 288 | token_type_ids=token_type_ids)) 289 | tgt_features.append(InputFeatures(input_ids=target_ids, attention_mask=attention_mask, 290 | token_type_ids=token_type_ids)) 291 | return src_features, tgt_features 292 | 293 | 294 | def convert_inputs_to_features(examples, tokenizer, 295 | 
summary=True,
296 |                                max_length=512,
297 |                                pad_on_left=False,
298 |                                pad_token=0,
299 |                                pad_token_segment_id=0,
300 |                                mask_padding_with_zero=True):
301 |     """
302 |     Loads a data file into a list of ``InputFeatures``
303 | 
304 |     Args:
305 |         examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples.
306 |         tokenizer: Instance of a tokenizer that will tokenize the examples.
307 |         summary: if true, convert the summary (text_b); otherwise convert the story (text_a)
308 |         max_length: Maximum example length
309 |         pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
310 |         pad_token: Padding token
311 |         pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet
312 |             where it is 4)
313 |         mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
314 |             and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
315 |             actual values)
316 | 
317 |     Returns:
318 |         If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset``
319 |         containing the task-specific features. If the input is a list of ``InputExamples``, will return
320 |         a list of task-specific ``InputFeatures`` which can be fed to the model.
321 | 
322 |     """
323 |     features, guids = [], []
324 |     for (ex_index, example) in enumerate(examples):
325 |         # if ex_index % 10000 == 0:
326 |         #     print("Writing example %d" % (ex_index))
327 |         example_data = example.text_b if summary else example.text_a
328 | 
329 |         inputs = tokenzie(tokenizer, example_data, max_length, False)
330 |         input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
331 | 
332 |         # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
333 |         attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
334 | 
335 |         # Zero-pad up to the sequence length.
336 | padding_length = max_length - len(input_ids) 337 | if pad_on_left: 338 | input_ids = ([pad_token] * padding_length) + input_ids 339 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 340 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids 341 | else: 342 | input_ids = input_ids + ([pad_token] * padding_length) 343 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 344 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 345 | 346 | assert len(input_ids) == max_length, \ 347 | "Error with input length {} vs {}".format(len(input_ids), max_length) 348 | assert len(attention_mask) == max_length, \ 349 | "Error with input length {} vs {}".format(len(attention_mask), max_length) 350 | assert len(token_type_ids) == max_length, \ 351 | "Error with input length {} vs {}".format(len(token_type_ids), max_length) 352 | 353 | # if ex_index < 5: 354 | # print("*** Example ***") 355 | # print("guid: %s" % (example.guid)) 356 | # print("input_ids: %s" % " ".join([str(x) for x in input_ids])) 357 | # print("attention_mask: %s" % " ".join([str(x) for x in attention_mask])) 358 | # print("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids])) 359 | 360 | guids.append(example.guid) 361 | features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, 362 | token_type_ids=token_type_ids)) 363 | return features, guids 364 | 365 | 366 | def convert_text_and_summary(examples, tokenizer, 367 | max_length=512, 368 | pad_on_left=False, 369 | pad_token=0, 370 | pad_token_segment_id=0, 371 | mask_padding_with_zero=True): 372 | """ 373 | Loads a data file into a list of ``InputFeatures`` 374 | 375 | Args: 376 | examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples. 377 | tokenizer: Instance of a tokenizer that will tokenize the examples. 378 | max_length: Maximum example length 379 | pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default) 380 | pad_token: Padding token 381 | pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet 382 | where it is 4) 383 | mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values 384 | and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for 385 | actual values) 386 | 387 | Returns: 388 | If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset`` 389 | containing the task-specific features. If the input is a list of ``InputExamples``, will return 390 | a list of task-specific ``InputFeatures`` which can be fed to the model. 391 | 392 | """ 393 | features, guids = [], [] 394 | for ex_index, example in enumerate(examples): 395 | text, summary = example.text_a, example.text_b 396 | 397 | text_ = tokenzie(tokenizer, text, 400, False) 398 | text_ids, text_token_type_ids = text_["input_ids"], text_["token_type_ids"] 399 | 400 | summary_ = tokenzie(tokenizer, summary, 100, False) 401 | summary_ids, summary_token_type_ids = summary_["input_ids"], summary_["token_type_ids"] 402 | 403 | sep_token_id = tokenizer.sep_token_id 404 | input_ids = summary_ids + [sep_token_id] + text_ids + [sep_token_id] 405 | token_type_ids = summary_token_type_ids + [0] + text_token_type_ids + [0] 406 | 407 | # The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to. 
408 | attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 409 | 410 | # Zero-pad up to the sequence length. 411 | padding_length = max_length - len(input_ids) 412 | if pad_on_left: 413 | input_ids = ([pad_token] * padding_length) + input_ids 414 | attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask 415 | token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids 416 | else: 417 | input_ids = input_ids + ([pad_token] * padding_length) 418 | attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 419 | token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length) 420 | 421 | assert len(input_ids) == max_length, \ 422 | "Error with input length {} vs {}".format(len(input_ids), max_length) 423 | assert len(attention_mask) == max_length, \ 424 | "Error with input length {} vs {}".format(len(attention_mask), max_length) 425 | assert len(token_type_ids) == max_length, \ 426 | "Error with input length {} vs {}".format(len(token_type_ids), max_length) 427 | 428 | # if ex_index < 5: 429 | # print("*** Example ***") 430 | # print("guid: %s" % (example.guid)) 431 | # print("input_ids: %s" % " ".join([str(x) for x in input_ids])) 432 | # print("attention_mask: %s" % " ".join([str(x) for x in attention_mask])) 433 | # print("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids])) 434 | 435 | guids.append(example.guid) 436 | features.append(InputFeatures(input_ids=input_ids, attention_mask=attention_mask, 437 | token_type_ids=token_type_ids)) 438 | return features, guids 439 | -------------------------------------------------------------------------------- /model/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import torch 6 | import argparse 7 | 8 | from os.path import join 9 | from datetime import datetime 10 | from utils import get_logger, set_seed 11 | from model import Model, MODEL_CLASSES, ALL_MODELS 12 | from data_utils import load_and_cache_examples, processors 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 14 | from transformers import BertModel 15 | 16 | 17 | def main(): 18 | # directory for training outputs 19 | output_dir = "results/{:%Y%m%d_%H%M%S}/".format(datetime.now()) 20 | 21 | # required parameters 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument("--model_type", default=None, type=str, required=True, 25 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 26 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 27 | help="Path to pre-trained model or shortcut name: " + ", ".join(ALL_MODELS)) 28 | 29 | parser.add_argument('--load_transformer', action='store_true', default=False, 30 | help="If need to load transformer.") 31 | parser.add_argument('--transformer_path', type=str, default='', 32 | help="The path to pre-trained transformer.") 33 | 34 | # other parameters 35 | parser.add_argument("--task_name", default='lpc', type=str, 36 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) 37 | parser.add_argument("--config_name", default="", type=str, 38 | help="Pretrained config name or path if not the same as model_name") 39 | parser.add_argument("--tokenizer_name", default="", type=str, 40 | help="Pretrained tokenizer name or path if not the same as model_name") 41 | 
parser.add_argument("--weight_decay", default=0.0, type=float, 42 | help="Weight deay if we apply some.") 43 | parser.add_argument("--learning_rate", default=2e-5, type=float, 44 | help="The initial learning rate for Adam.") 45 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 46 | help="Epsilon for Adam optimizer.") 47 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 48 | help="Max gradient norm.") 49 | parser.add_argument("--warmup_steps", default=5, type=int, 50 | help="Linear warmup over warmup_steps.") 51 | parser.add_argument("--per_gpu_train_batch_size", default=4, type=int, 52 | help="Batch size for training.") 53 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 54 | help="Batch size for evaluation.") 55 | parser.add_argument("--no_cuda", default=False, type=bool, 56 | help="Do not use cuda.") 57 | parser.add_argument("--do_lower_case", default=True, type=bool, 58 | help="Do lower case.") 59 | parser.add_argument("--use_pretrained", default=False, type=bool, 60 | help="If use pre-trained model weights.") 61 | parser.add_argument("--seed", default=610, type=int, 62 | help="Random seed.") 63 | parser.add_argument("--num_labels", default=3, type=int, 64 | help="Classification label number.") 65 | parser.add_argument("--num_epochs", default=30, type=int, 66 | help="Total number of training epochs to perform.") 67 | parser.add_argument("--scheduler", default='warmup', type=str, 68 | help="Which type of scheduler to use.") 69 | parser.add_argument("--local_rank", type=int, default=-1, 70 | help="For distributed training: local_rank") 71 | parser.add_argument("--max_seq_length", default=256, type=int, 72 | help="The maximum total input sequence length after tokenization. Sequences longer " 73 | "than this will be truncated, sequences shorter will be padded.") 74 | parser.add_argument('--overwrite_cache', action='store_true', default=False, 75 | help="Overwrite the cached training and evaluation sets") 76 | parser.add_argument('--write_summary', default=True, type=bool, 77 | help="If write summary into tensorboard.") 78 | parser.add_argument('--fp16', action='store_true', default=False, 79 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 80 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 81 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
82 | "See details at https://nvidia.github.io/apex/amp.html") 83 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 84 | help="Number of updates steps to accumulate before performing a backward/update pass.") 85 | 86 | # data directory 87 | parser.add_argument("--data_dir", default='../data/TF-IDF', type=str, 88 | help="data directory where pickle dataset is stored.") 89 | parser.add_argument("--output_dir", default=output_dir, type=str, 90 | help="output directory for model, log file and summary.") 91 | parser.add_argument("--log_path", default=join(output_dir, "log.txt"), type=str, 92 | help="Path to log.txt.") 93 | parser.add_argument("--summary_path", default=join(output_dir, "summary"), type=str, 94 | help="Path to summary file.") 95 | parser.add_argument("--model_dir", default=join(output_dir, "model/"), type=str, 96 | help="where to load pre-trained model.") 97 | parser.add_argument("--checkpoint", default=join(output_dir, "model/"), type=str, 98 | help="Where to load pre-trained transformer model.") 99 | 100 | args = parser.parse_args() 101 | 102 | if not os.path.exists(args.model_dir): 103 | os.makedirs(args.model_dir) 104 | 105 | args.logger = get_logger(args.log_path) 106 | 107 | # Setup CUDA, GPU & distributed training 108 | args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 109 | args.n_gpu = torch.cuda.device_count() 110 | args.logger.info("- device: {}, n_gpu: {}".format(args.device, args.n_gpu)) 111 | 112 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 113 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 114 | 115 | # set seed 116 | set_seed(args.seed) 117 | 118 | # build model 119 | args.logger.info("Building model...") 120 | model = Model(args) 121 | 122 | # load transformers 123 | if args.load_transformer: 124 | args.logger.info("Loading pre-trained transformer...") 125 | weights = BertModel.from_pretrained(args.transformer_path) 126 | model.load_transformer(weights) 127 | 128 | # build dataset 129 | args.logger.info("Loading dataset...") 130 | train_dataset, _ = load_and_cache_examples(args, args.task_name, model.tokenizer, evaluate=False) 131 | eval_dataset, _ = load_and_cache_examples(args, args.task_name, model.tokenizer, evaluate=True) 132 | 133 | train_sampler = RandomSampler(train_dataset) 134 | train_dataloader = DataLoader(train_dataset, 135 | sampler=train_sampler, 136 | batch_size=args.train_batch_size) 137 | 138 | eval_sampler = SequentialSampler(eval_dataset) 139 | eval_dataloader = DataLoader(eval_dataset,· 140 | sampler=eval_sampler, 141 | batch_size=args.eval_batch_size) 142 | 143 | # training 144 | args.logger.info("Start training !!!") 145 | model.fit(train_dataloader, eval_dataloader) 146 | 147 | # test & get report 148 | args.logger.info("Loading best mode and start testing:") 149 | model.load_weights(args.model_dir) 150 | model.evaluate(eval_dataloader, True) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /model/main_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import torch 6 | import argparse 7 | 8 | from os.path import join 9 | from datetime import datetime 10 | from utils import get_logger, set_seed 11 | from model import Model, MODEL_CLASSES, ALL_MODELS 12 | from data_utils import load_and_cache_examples, 
processors 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 14 | from transformers import BertTokenizer 15 | 16 | 17 | def main(): 18 | # directory for training outputs 19 | output_dir = "results/{:%Y%m%d_%H%M%S}/".format(datetime.now()) 20 | 21 | # required parameters 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument("--model_type", default=None, type=str, required=True, 25 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 26 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 27 | help="Path to pre-trained model or shortcut name: " + ", ".join(ALL_MODELS)) 28 | 29 | # data directory 30 | parser.add_argument("--data_dir", default='', type=str, 31 | help="data directory where pickle dataset is stored.") 32 | parser.add_argument("--output_dir", default=output_dir, type=str, 33 | help="output directory for model, log file and summary.") 34 | parser.add_argument("--log_path", default=join(output_dir, "log.txt"), type=str, 35 | help="Path to log.txt.") 36 | parser.add_argument("--summary_path", default=join(output_dir, "summary"), type=str, 37 | help="Path to summary file.") 38 | parser.add_argument("--model_dir", default=join(output_dir, "model/"), type=str, 39 | help="where to load pre-trained model.") 40 | parser.add_argument("--local_rank", type=int, default=-1, 41 | help="For distributed training: local_rank") 42 | parser.add_argument("--max_summary_length", default=200, type=int, 43 | help="The maximum total summary length after tokenization. Sequences longer " 44 | "than this will be truncated, sequences shorter will be padded.") 45 | parser.add_argument('--overwrite_cache', action='store_true', default=False, 46 | help="Overwrite the cached training and evaluation sets") 47 | parser.add_argument("--no_cuda", default=False, type=bool, help="Do not use cuda.") 48 | parser.add_argument("--seed", default=610, type=int, help="Random seed.") 49 | args = parser.parse_args() 50 | args.logger = get_logger(args.log_path) 51 | 52 | # Setup CUDA, GPU & distributed training 53 | args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 54 | args.n_gpu = torch.cuda.device_count() 55 | args.logger.info("- device: {}, n_gpu: {}".format(args.device, args.n_gpu)) 56 | 57 | # set seed 58 | set_seed(args.seed) 59 | 60 | # load tokenizer 61 | tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path) 62 | 63 | # build dataset 64 | args.logger.info("Loading dataset...") 65 | train_dataset, _ = load_and_cache_examples(args, tokenizer, evaluate=False) 66 | eval_dataset, _ = load_and_cache_examples(args, tokenizer, evaluate=True) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /model/model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import math\n", 10 | "import torch\n", 11 | "import torch.nn as nn\n", 12 | "\n", 13 | "from fairseq.models.bart import BARTModel" 14 | ] 15 | }, 16 | { 17 | "cell_type": "markdown", 18 | "metadata": {}, 19 | "source": [ 20 | "#### Load Model and Tokenizer" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "bart = BARTModel.from_pretrained('/home/ml/cadencao/Downloads/BART_models/bart.large.xsum',\n", 30 | " checkpoint_file='model.pt',\n", 31 | " data_name_or_path='/home/ml/cadencao/Downloads/BART_models/bart.large.xsum')"
32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "name": "stdout", 41 | "output_type": "stream", 42 | "text": [ 43 | "- activate evaluation mode\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "bart.cuda()\n", 49 | "bart.eval()\n", 50 | "bart.half()\n", 51 | "print('- activate evaluation mode')" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "metadata": {}, 57 | "source": [ 58 | "#### Data Preparation" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 18, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "val_source_path = '/home/ml/cadencao/XSum/fairseq_files/val.bpe.source'\n", 68 | "val_target_path = '/home/ml/cadencao/XSum/fairseq_files/val.bpe.target'" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": 19, 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "bpe_val_source, bpe_val_target = [], []\n", 78 | "with open(val_source_path, 'r') as sf, open(val_target_path, 'r') as tf:\n", 79 | " for s, t in zip(sf, tf):\n", 80 | " bpe_val_source.append(s.strip())\n", 81 | " bpe_val_target.append(t.strip())" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 20, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "name": "stdout", 91 | "output_type": "stream", 92 | "text": [ 93 | "tensor([ 0, 15622, 320, 1754, 1470, 1321, 33, 57, 303, 2181,\n", 94 | " 9, 28304, 5, 15331, 31, 5, 7314, 9, 80, 4585,\n", 95 | " 9886, 10, 529, 59, 633, 2599, 4, 2],\n", 96 | " dtype=torch.int32)\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "torch_tokens = bart.task.source_dictionary.encode_line(' ' + bpe_val_target[0] + ' ', append_eos=False)\n", 102 | "print(torch_tokens)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 21, 108 | "metadata": {}, 109 | "outputs": [ 110 | { 111 | "data": { 112 | "text/plain": [ 113 | "'Three former Air France employees have been found guilty of ripping the shirts from the backs of two executives fleeing a meeting about job cuts.'" 114 | ] 115 | }, 116 | "execution_count": 21, 117 | "metadata": {}, 118 | "output_type": "execute_result" 119 | } 120 | ], 121 | "source": [ 122 | "bart.decode(torch_tokens.long())" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [] 131 | } 132 | ], 133 | "metadata": { 134 | "kernelspec": { 135 | "display_name": "Python 3", 136 | "language": "python", 137 | "name": "python3" 138 | }, 139 | "language_info": { 140 | "codemirror_mode": { 141 | "name": "ipython", 142 | "version": 3 143 | }, 144 | "file_extension": ".py", 145 | "mimetype": "text/x-python", 146 | "name": "python", 147 | "nbconvert_exporter": "python", 148 | "pygments_lexer": "ipython3", 149 | "version": "3.6.10" 150 | } 151 | }, 152 | "nbformat": 4, 153 | "nbformat_minor": 4 154 | } 155 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | 8 | from transformers import (BertConfig, 9 | BertForSequenceClassification, BertTokenizer, 10 | RobertaConfig, 11 | RobertaForSequenceClassification, 12 | RobertaTokenizer, 13 | XLMConfig, XLMForSequenceClassification, 14 | XLMTokenizer, XLNetConfig, 15 | XLNetForSequenceClassification, 16 | XLNetTokenizer, 17 | 
DistilBertConfig, 18 | DistilBertForSequenceClassification, 19 | DistilBertTokenizer) 20 | from transformers import AdamW, WarmupLinearSchedule 21 | from sklearn.metrics import classification_report, f1_score 22 | from torch.utils.tensorboard import SummaryWriter 23 | from torch.optim.lr_scheduler import ExponentialLR 24 | # from apex import amp  # required (uncomment) when running with --fp16; the AMP branches below depend on it 25 | from tqdm import tqdm 26 | 27 | 28 | MODEL_CLASSES = { 29 | 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer), 30 | 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 31 | 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 32 | 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 33 | 'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer) 34 | } 35 | 36 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, 37 | RobertaConfig, DistilBertConfig)), ()) 38 | 39 | 40 | class Model: 41 | """Transformer-based sequence classification model (BERT, XLNet, XLM, RoBERTa or DistilBERT). 42 | """ 43 | def __init__(self, args): 44 | """Model initialization. 45 | """ 46 | self.args = args 47 | self.logger = args.logger 48 | 49 | self._build_model() 50 | self.model.to(args.device) 51 | 52 | self.optimizer = self._get_optimizer(self._group_parameters(self.model)) 53 | self.scheduler = self._get_scheduler(self.optimizer) 54 | 55 | # Amp: Automatic Mixed Precision 56 | if self.args.fp16: 57 | self.model, self.optimizer = amp.initialize(self.model, 58 | self.optimizer, 59 | opt_level=args.fp16_opt_level) 60 | self.logger.info("- Automatic Mixed Precision (AMP) is used.") 61 | else: 62 | self.logger.info("- NO Automatic Mixed Precision (AMP) :/") 63 | 64 | # multi-gpu training (should be after apex fp16 initialization) 65 | if args.n_gpu > 1: 66 | self.logger.info("- Let's use {} GPUs!".format(torch.cuda.device_count())) 67 | self.model = nn.DataParallel(self.model) 68 | else: 69 | self.logger.info("- Train the model on single GPU :/") 70 | 71 | # tensorboard 72 | if args.write_summary: 73 | self.logger.info("- Let's use tensorboard on local rank {} device :)".format(args.local_rank)) 74 | self.writer = SummaryWriter(self.args.summary_path) 75 | else: self.writer = None  # keep the attribute defined so later "if self.writer" checks are safe 76 | def _build_model(self): 77 | """Build model. 78 | """ 79 | model_type = self.args.model_type.lower() 80 | self.config_class, self.model_class, self.tokenizer_class = MODEL_CLASSES[model_type] 81 | if self.args.use_pretrained: 82 | self.load_weights(self.args.checkpoint) 83 | else: 84 | self._load_from_library(self.args) 85 | 86 | def _load_from_library(self, args): 87 | """Download and initialize config, tokenizer and model weights from the transformers library. 88 | """ 89 | self.logger.info("- Downloading model...") 90 | config = self.config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, 91 | num_labels=args.num_labels, 92 | finetuning_task=args.task_name) 93 | self.tokenizer = self.tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name 94 | else args.model_name_or_path, 95 | do_lower_case=args.do_lower_case) 96 | self.model = self.model_class.from_pretrained(args.model_name_or_path, 97 | from_tf=bool('.ckpt' in args.model_name_or_path), 98 | config=config) 99 | 100 | def _group_parameters(self, model): 101 | """Specify which parameters use weight decay and which do not. 
102 | """ 103 | no_decay = ['bias', 'LayerNorm.weight'] 104 | optimizer_grouped_parameters = [ 105 | {'params': 106 | [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 107 | 'weight_decay': self.args.weight_decay}, 108 | {'params': 109 | [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 110 | 'weight_decay': 0.0} 111 | ] 112 | return optimizer_grouped_parameters 113 | 114 | def _get_optimizer(self, optimizer_grouped_parameters): 115 | """Get optimizer for model training. 116 | """ 117 | optimizer = AdamW(optimizer_grouped_parameters, 118 | lr=self.args.learning_rate, 119 | eps=self.args.adam_epsilon) 120 | return optimizer 121 | 122 | def _get_scheduler(self, optimizer): 123 | """Get scheduler for adjusting learning rate. 124 | """ 125 | if self.args.scheduler == 'warmup': 126 | scheduler = WarmupLinearSchedule(optimizer, 127 | warmup_steps=self.args.warmup_steps, 128 | t_total=self.args.num_epochs) 129 | elif self.args.scheduler == 'exponential': 130 | scheduler = ExponentialLR(optimizer, 0.95) 131 | return scheduler 132 | 133 | def load_weights(self, checkpoint): 134 | """Load pre-trained model weights. 135 | """ 136 | self.logger.info("- Load pre-trained model from: {}".format(checkpoint)) 137 | self.model = self.model_class.from_pretrained(checkpoint) 138 | self.tokenizer = self.tokenizer_class.from_pretrained(checkpoint, 139 | do_lower_case=self.args.do_lower_case) 140 | self.model.to(self.args.device) 141 | return self.model, self.tokenizer 142 | 143 | def load_transformer(self, weights): 144 | """Load pre-trained model weights. 145 | """ 146 | # Take care of distributed/parallel training 147 | model_to_save = self.model.module if hasattr(self.model, 'module') else self.model 148 | model_type = self.args.model_type.lower() 149 | if model_type == 'bert': 150 | model_to_save.bert.load_state_dict(weights.state_dict()) 151 | elif model_type == 'xlnet': 152 | model_to_save.transformer.load_state_dict(weights.state_dict()) 153 | elif model_type == 'distilbert': 154 | model_to_save.distilbert.load_state_dict(weights.state_dict()) 155 | else: 156 | raise Exception("Unknow model type!") 157 | 158 | def save_model(self, output_path): 159 | """Save model's weights. 160 | """ 161 | # Take care of distributed/parallel training 162 | model_to_save = self.model.module if hasattr(self.model, 'module') else self.model 163 | model_to_save.save_pretrained(output_path) 164 | self.tokenizer.save_pretrained(output_path) 165 | torch.save(self.args, os.path.join(output_path, 'training_args.bin')) 166 | self.logger.info("- model, tokenzier and args is saved at: {}".format(output_path)) 167 | 168 | def loss_batch(self, inputs, optimizer=None, step=None): 169 | """Calculate loss on a single batch of data. 
170 | """ 171 | if optimizer: 172 | assert step is not None 173 | outputs = self.model(**inputs) 174 | loss, logits = outputs[0], outputs[1] 175 | 176 | if self.args.n_gpu > 1: 177 | loss = loss.mean() # mean() to average on multi-gpu parallel training 178 | if self.args.gradient_accumulation_steps > 1: 179 | loss = loss / self.args.gradient_accumulation_steps 180 | 181 | if optimizer is not None: 182 | if self.args.fp16: 183 | with amp.scale_loss(loss, optimizer) as scaled_loss: 184 | scaled_loss.backward() 185 | else: 186 | loss.backward() 187 | 188 | if (step + 1) % self.args.gradient_accumulation_steps == 0: 189 | if self.args.fp16: 190 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 191 | self.args.max_grad_norm) 192 | else: 193 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 194 | self.args.max_grad_norm) 195 | optimizer.step() # update model parameters 196 | optimizer.zero_grad() # clean all gradients 197 | 198 | return loss.item(), logits.detach() 199 | 200 | def train_epoch(self, train_dataloader, optimizer, epoch): 201 | """Train the model for one single epoch. 202 | """ 203 | self.model.train() # set the model to training mode 204 | epoch_iterator = tqdm(train_dataloader, desc="Iteration") 205 | 206 | train_loss = 0.0 207 | for i, batch in enumerate(epoch_iterator): 208 | batch = tuple(t.to(self.args.device) for t in batch) 209 | inputs = {'input_ids': batch[0], 210 | 'attention_mask': batch[1], 211 | 'labels': batch[3]} 212 | 213 | # XLM, DistilBERT and RoBERTa don't use segment_ids 214 | if self.args.model_type != 'distilbert': 215 | inputs['token_type_ids'] = batch[2] if self.args.model_type in ['bert', 'xlnet'] else None 216 | 217 | batch_loss, _ = self.loss_batch(inputs, 218 | optimizer=optimizer, 219 | step=i) 220 | train_loss += batch_loss 221 | 222 | if self.writer: 223 | self.writer.add_scalar('batch_loss', batch_loss, epoch*len(train_dataloader) + i + 1) 224 | 225 | # compute the average loss (batch loss) 226 | epoch_loss = train_loss / len(train_dataloader) 227 | 228 | # update scheduler 229 | self.scheduler.step() 230 | 231 | return epoch_loss 232 | 233 | def evaluate(self, eval_dataloader, print_report=False): 234 | """Evaluate the model. 
235 | """ 236 | self.model.eval() # set the model to evaluation mode 237 | with torch.no_grad(): 238 | pred_class, label_class = [], [] 239 | eval_loss, eval_corrects = 0.0, 0.0 240 | for _, batch in enumerate(tqdm(eval_dataloader, desc="Iteration")): 241 | batch = tuple(t.to(self.args.device) for t in batch) 242 | inputs = {'input_ids': batch[0], 243 | 'attention_mask': batch[1], 244 | 'labels': batch[3]} 245 | if self.args.model_type != 'distilbert': 246 | inputs['token_type_ids'] = batch[2] if self.args.model_type in ['bert', 'xlnet'] else None 247 | 248 | batch_loss, outputs = self.loss_batch(inputs, optimizer=None) 249 | _, preds = torch.max(outputs, 1) # preds: [batch_size] 250 | 251 | # save predictions 252 | pred_class += preds.tolist() 253 | label_class += batch[3].tolist() 254 | 255 | # update loss & accuracy 256 | eval_loss += batch_loss 257 | eval_corrects += torch.sum(preds == (batch[3])).double() 258 | 259 | avg_loss = eval_loss / len(eval_dataloader) 260 | avg_acc = eval_corrects / len(eval_dataloader.dataset) 261 | 262 | macro_f1 = f1_score(label_class, pred_class, average='macro') 263 | 264 | if print_report: 265 | self.logger.info('\n') 266 | # target_names = ['False', 'Partly Ture', 'True'] 267 | self.logger.info(classification_report(label_class, pred_class)) 268 | return avg_loss, avg_acc, macro_f1 269 | 270 | def fit(self, train_dataloader, eval_dataloader): 271 | """Model training and evaluation. 272 | """ 273 | best_f1 = 0. 274 | num_epochs = self.args.num_epochs 275 | 276 | for epoch in range(num_epochs): 277 | self.logger.info('Epoch {}/{}'.format(epoch + 1, num_epochs)) 278 | 279 | # training 280 | train_loss = self.train_epoch(train_dataloader, self.optimizer, epoch) 281 | self.logger.info("Traing Loss: {}".format(train_loss)) 282 | 283 | # evaluation, only on the master node 284 | eval_loss, eval_acc, macro_f1 = self.evaluate(eval_dataloader, True) 285 | self.logger.info("Evaluation:") 286 | self.logger.info("- loss: {}".format(eval_loss)) 287 | self.logger.info("- acc: {}".format(eval_acc)) 288 | self.logger.info("- macro F1: {}".format(macro_f1)) 289 | 290 | # monitor loss and accuracy 291 | if self.writer: 292 | self.writer.add_scalar('epoch_loss', train_loss, epoch) 293 | self.writer.add_scalar('eval_loss', eval_loss, epoch) 294 | self.writer.add_scalar('eval_acc', eval_acc, epoch) 295 | self.writer.add_scalar('lr', self.scheduler.get_lr()[0], epoch) 296 | 297 | # save the model 298 | if macro_f1 >= best_f1: 299 | best_f1 = macro_f1 300 | self.logger.info("New best score!") 301 | self.save_model(self.args.model_dir) 302 | 303 | def test(self, test_dataloader): 304 | """Test the model on unlabeled dataset. 
305 | """ 306 | self.model.eval() # set the model to evaluation mode 307 | pred_class = [] 308 | for _, batch in enumerate(tqdm(test_dataloader, desc="Iteration")): 309 | batch = tuple(t.to(self.args.device) for t in batch) 310 | inputs = {'input_ids': batch[0], 311 | 'attention_mask': batch[1], 312 | 'labels': None} 313 | if self.args.model_type != 'distilbert': 314 | inputs['token_type_ids'] = batch[2] if self.args.model_type in ['bert', 'xlnet'] else None 315 | outputs = self.model(**inputs) 316 | _, preds = torch.max(outputs[0].detach(), 1) # preds: [batch_size] 317 | pred_class += preds.tolist() 318 | return pred_class 319 | -------------------------------------------------------------------------------- /model/run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import torch 6 | import argparse 7 | 8 | from os.path import join 9 | from datetime import datetime 10 | from utils import get_logger, set_seed 11 | from model import Model, MODEL_CLASSES, ALL_MODELS 12 | from data_utils import load_and_cache_examples, processors 13 | from torch.utils.data import DataLoader, SequentialSampler 14 | from extract_sentences import make_data 15 | 16 | 17 | # These are the file paths where the validation/test set will be mounted (read only) 18 | # into your Docker container. 19 | METADATA_FILEPATH = '' 20 | ARTICLES_FILEPATH = '' 21 | 22 | # This is the filepath where the predictions should be written to. 23 | OUTPUT_DIR = './' 24 | 25 | 26 | def write_results(output_dir, guids, preds): 27 | output_eval_file = os.path.join(output_dir, "predictions.txt") 28 | with open(output_eval_file, "w") as writer: 29 | for i, p in zip(guids, preds): 30 | writer.write("%d,%d\n" % (int(i), p)) 31 | 32 | 33 | def main(): 34 | # directory for training outputs 35 | output_dir = "results/{:%Y%m%d_%H%M%S}/".format(datetime.now()) 36 | 37 | # required parameters 38 | parser = argparse.ArgumentParser() 39 | 40 | parser.add_argument("--model_type", default=None, type=str, required=True, 41 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 42 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 43 | help="Path to pre-trained model or shortcut name: " + ", ".join(ALL_MODELS)) 44 | parser.add_argument("--checkpoint", default='', type=str, required=True, 45 | help="where to load pre-trained model.") 46 | parser.add_argument("--max_seq_length", default=512, type=int, required=True, 47 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 48 | "than this will be truncated, sequences shorter will be padded.") 49 | parser.add_argument("--use_pretrained", action='store_true', default=True, 50 | help="If use pre-trained model weights.") 51 | 52 | # other parameters 53 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 54 | help="Number of updates steps to accumulate before performing a backward/update pass.") 55 | parser.add_argument("--config_name", default="", type=str, 56 | help="Pretrained config name or path if not the same as model_name") 57 | parser.add_argument("--tokenizer_name", default="", type=str, 58 | help="Pretrained tokenizer name or path if not the same as model_name") 59 | parser.add_argument("--num_epochs", default=30, type=int, 60 | help="Total number of training epochs to perform.") 61 | parser.add_argument("--task_name", default='lpc', type=str, 62 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) 63 | parser.add_argument("--weight_decay", default=0.0, type=float, 64 | help="Weight deay if we apply some.") 65 | parser.add_argument("--learning_rate", default=2e-5, type=float, 66 | help="The initial learning rate for Adam.") 67 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 68 | help="Epsilon for Adam optimizer.") 69 | parser.add_argument("--warmup_steps", default=3, type=int, 70 | help="Linear warmup over warmup_steps.") 71 | parser.add_argument("--per_gpu_eval_batch_size", default=4, type=int, 72 | help="Batch size for evaluation.") 73 | parser.add_argument("--no_cuda", default=False, type=bool, 74 | help="Do not use cuda.") 75 | parser.add_argument("--do_lower_case", default=True, type=bool, 76 | help="Do lower case.") 77 | parser.add_argument("--seed", default=610, type=int, 78 | help="Random seed.") 79 | parser.add_argument("--num_labels", default=3, type=int, 80 | help="Classification label number.") 81 | parser.add_argument("--scheduler", default='warmup', type=str, 82 | help="Which type of scheduler to use.") 83 | parser.add_argument("--local_rank", type=int, default=-1, 84 | help="For distributed training: local_rank") 85 | parser.add_argument('--overwrite_cache', action='store_true', default=False, 86 | help="Overwrite the cached training and evaluation sets") 87 | parser.add_argument('--write_summary', default=True, type=bool, 88 | help="If write summary into tensorboard.") 89 | parser.add_argument('--fp16', action='store_true', default=False, 90 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 91 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 92 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
93 | "See details at https://nvidia.github.io/apex/amp.html") 94 | 95 | # data directory 96 | parser.add_argument("--data_dir", default='../data/TF-IDF', type=str, 97 | help="data directory where pickle dataset is stored.") 98 | parser.add_argument("--output_dir", default=output_dir, type=str, 99 | help="output directory for model, log file and summary.") 100 | parser.add_argument("--log_path", default=join(output_dir, "log.txt"), type=str, 101 | help="Path to log.txt.") 102 | parser.add_argument("--summary_path", default=join(output_dir, "summary"), type=str, 103 | help="Path to summary file.") 104 | 105 | args = parser.parse_args() 106 | 107 | if not os.path.exists(args.output_dir): 108 | os.makedirs(args.output_dir) 109 | 110 | if not os.path.exists(args.data_dir): 111 | os.makedirs(args.data_dir) 112 | 113 | args.logger = get_logger(args.log_path) 114 | 115 | # Setup CUDA, GPU & distributed training 116 | args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 117 | args.n_gpu = torch.cuda.device_count() 118 | args.logger.info("- device: {}, n_gpu: {}".format(args.device, args.n_gpu)) 119 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 120 | 121 | # set seed 122 | set_seed(args.seed) 123 | 124 | # build model 125 | args.logger.info("Build model...") 126 | model = Model(args) 127 | 128 | # make data 129 | make_data(ARTICLES_FILEPATH, METADATA_FILEPATH, args.data_dir, file_name='dev') 130 | 131 | # build dataset 132 | args.logger.info("Loading dataset...") 133 | eval_dataset, guids = load_and_cache_examples(args, args.task_name, model.tokenizer, evaluate=True) 134 | eval_sampler = SequentialSampler(eval_dataset) 135 | eval_dataloader = DataLoader(eval_dataset, 136 | sampler=eval_sampler, 137 | batch_size=args.eval_batch_size) 138 | 139 | # training 140 | args.logger.info("Start testing:") 141 | preds = model.test(eval_dataloader) 142 | assert len(preds) == len(guids), "Prediction list and GUID list length do NOT equal!!!" 
143 | 144 | # write results 145 | args.logger.info("Write prediction results:") 146 | write_results(OUTPUT_DIR, guids, preds) 147 | args.logger.info("Save results at: {}".format(os.path.join(OUTPUT_DIR, 'predictions.txt'))) 148 | 149 | 150 | if __name__ == '__main__': 151 | main() 152 | -------------------------------------------------------------------------------- /model/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | import torch 6 | import argparse 7 | 8 | from os.path import join 9 | from datetime import datetime 10 | from utils import get_logger, set_seed 11 | from model import Model, MODEL_CLASSES, ALL_MODELS 12 | from data_utils import load_and_cache_examples, processors 13 | from torch.utils.data import DataLoader, SequentialSampler 14 | 15 | 16 | def write_results(output_dir, guids, preds): 17 | output_eval_file = os.path.join(output_dir, "predictions.txt") 18 | with open(output_eval_file, "w") as writer: 19 | for i, p in zip(guids, preds): 20 | writer.write("%d,%d\n" % (int(i), p)) 21 | 22 | 23 | def main(): 24 | # directory for training outputs 25 | output_dir = "results/{:%Y%m%d_%H%M%S}/".format(datetime.now()) 26 | 27 | # required parameters 28 | parser = argparse.ArgumentParser() 29 | 30 | parser.add_argument("--model_type", default=None, type=str, required=True, 31 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 32 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 33 | help="Path to pre-trained model or shortcut name: " + ", ".join(ALL_MODELS)) 34 | parser.add_argument("--checkpoint", default='', type=str, required=True, 35 | help="where to load pre-trained model.") 36 | parser.add_argument("--max_seq_length", default=512, type=int, required=True, 37 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 38 | "than this will be truncated, sequences shorter will be padded.") 39 | parser.add_argument("--use_pretrained", action='store_true', default=True, 40 | help="If use pre-trained model weights.") 41 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1, 42 | help="Number of updates steps to accumulate before performing a backward/update pass.") 43 | 44 | # other parameters 45 | parser.add_argument("--config_name", default="", type=str, 46 | help="Pretrained config name or path if not the same as model_name") 47 | parser.add_argument("--tokenizer_name", default="", type=str, 48 | help="Pretrained tokenizer name or path if not the same as model_name") 49 | parser.add_argument("--num_epochs", default=30, type=int, 50 | help="Total number of training epochs to perform.") 51 | parser.add_argument("--task_name", default='lpc', type=str, 52 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) 53 | parser.add_argument("--weight_decay", default=0.0, type=float, 54 | help="Weight deay if we apply some.") 55 | parser.add_argument("--learning_rate", default=2e-5, type=float, 56 | help="The initial learning rate for Adam.") 57 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, 58 | help="Epsilon for Adam optimizer.") 59 | parser.add_argument("--warmup_steps", default=3, type=int, 60 | help="Linear warmup over warmup_steps.") 61 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, 62 | help="Batch size for evaluation.") 63 | parser.add_argument("--no_cuda", default=False, type=bool, 64 | help="Do not use cuda.") 65 | parser.add_argument("--do_lower_case", default=True, type=bool, 66 | help="Do lower case.") 67 | parser.add_argument("--seed", default=610, type=int, 68 | help="Random seed.") 69 | parser.add_argument("--num_labels", default=3, type=int, 70 | help="Classification label number.") 71 | parser.add_argument("--scheduler", default='warmup', type=str, 72 | help="Which type of scheduler to use.") 73 | parser.add_argument("--local_rank", type=int, default=-1, 74 | help="For distributed training: local_rank") 75 | parser.add_argument('--overwrite_cache', action='store_true', default=False, 76 | help="Overwrite the cached training and evaluation sets") 77 | parser.add_argument('--write_summary', default=True, type=bool, 78 | help="If write summary into tensorboard.") 79 | parser.add_argument('--fp16', action='store_true', default=False, 80 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") 81 | parser.add_argument('--fp16_opt_level', type=str, default='O1', 82 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
83 | "See details at https://nvidia.github.io/apex/amp.html") 84 | 85 | # data directory 86 | parser.add_argument("--data_dir", default='../data/TF-IDF', type=str, 87 | help="data directory where pickle dataset is stored.") 88 | parser.add_argument("--output_dir", default=output_dir, type=str, 89 | help="output directory for model, log file and summary.") 90 | parser.add_argument("--log_path", default=join(output_dir, "log.txt"), type=str, 91 | help="Path to log.txt.") 92 | parser.add_argument("--summary_path", default=join(output_dir, "summary"), type=str, 93 | help="Path to summary file.") 94 | parser.add_argument("--results_dir", default="./", type=str, 95 | help="Path to write prediction resutls.") 96 | 97 | args = parser.parse_args() 98 | 99 | if not os.path.exists(args.output_dir): 100 | os.makedirs(args.output_dir) 101 | 102 | args.logger = get_logger(args.log_path) 103 | 104 | # Setup CUDA, GPU & distributed training 105 | args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 106 | args.n_gpu = torch.cuda.device_count() 107 | args.logger.info("- device: {}, n_gpu: {}".format(args.device, args.n_gpu)) 108 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 109 | 110 | # set seed 111 | set_seed(args.seed) 112 | 113 | # build model 114 | args.logger.info("Build model...") 115 | model = Model(args) 116 | 117 | # build dataset 118 | args.logger.info("Loading dataset...") 119 | eval_dataset, guids = load_and_cache_examples(args, args.task_name, model.tokenizer, evaluate=True) 120 | eval_sampler = SequentialSampler(eval_dataset) 121 | eval_dataloader = DataLoader(eval_dataset, 122 | sampler=eval_sampler, 123 | batch_size=args.eval_batch_size) 124 | 125 | # training 126 | args.logger.info("Start testing:") 127 | model.evaluate(eval_dataloader, True) 128 | # preds = model.test(eval_dataloader) 129 | # assert len(preds) == len(guids), "Prediction list and GUID list length do NOT equal!!!" 130 | 131 | # write results 132 | # args.logger.info("Write prediction results:") 133 | # write_results(args.results_dir, guids, preds) 134 | # args.logger.info("Save results at: {}/predictons.txt".format(args.results_dir)) 135 | 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import time 4 | import sys 5 | import logging 6 | import numpy as np 7 | import random 8 | import torch 9 | 10 | 11 | def set_seed(seed, n_gpu=1): 12 | random.seed(seed) 13 | np.random.seed(seed) 14 | torch.manual_seed(seed) 15 | if n_gpu > 0: 16 | torch.cuda.manual_seed_all(seed) 17 | 18 | 19 | def get_logger(filename): 20 | """Return a logger instance that writes in filename 21 | Args: 22 | filename: (string) path to log.txt 23 | Returns: 24 | logger: (instance of logger) 25 | """ 26 | logger = logging.getLogger('logger') 27 | logger.setLevel(logging.DEBUG) 28 | logging.basicConfig(format='%(message)s', level=logging.DEBUG) 29 | handler = logging.FileHandler(filename) 30 | handler.setLevel(logging.DEBUG) 31 | handler.setFormatter(logging.Formatter( 32 | '%(asctime)s:%(levelname)s: %(message)s')) 33 | logging.getLogger().addHandler(handler) 34 | 35 | return logger 36 | 37 | 38 | class Progbar(object): 39 | """Progbar class copied from keras (https://github.com/fchollet/keras/) 40 | Displays a progress bar. 
41 | Small edit: added strict arg to update 42 | # Arguments 43 | target: Total number of steps expected. 44 | interval: Minimum visual progress update interval (in seconds). 45 | """ 46 | 47 | def __init__(self, target, width=30, verbose=1): 48 | self.width = width 49 | self.target = target 50 | self.sum_values = {} 51 | self.unique_values = [] 52 | self.start = time.time() 53 | self.total_width = 0 54 | self.seen_so_far = 0 55 | self.verbose = verbose 56 | 57 | def update(self, current, values=[], exact=[], strict=[]): 58 | """ 59 | Updates the progress bar. 60 | # Arguments 61 | current: Index of current step. 62 | values: List of tuples (name, value_for_last_step). 63 | The progress bar will display averages for these values. 64 | exact: List of tuples (name, value_for_last_step). 65 | The progress bar will display these values directly. strict: List of tuples (name, value); stored as-is and shown without averaging. 66 | """ 67 | 68 | for k, v in values: 69 | if k not in self.sum_values: 70 | self.sum_values[k] = [v * (current - self.seen_so_far), 71 | current - self.seen_so_far] 72 | self.unique_values.append(k) 73 | else: 74 | self.sum_values[k][0] += v * (current - self.seen_so_far) 75 | self.sum_values[k][1] += (current - self.seen_so_far) 76 | for k, v in exact: 77 | if k not in self.sum_values: 78 | self.unique_values.append(k) 79 | self.sum_values[k] = [v, 1] 80 | 81 | for k, v in strict: 82 | if k not in self.sum_values: 83 | self.unique_values.append(k) 84 | self.sum_values[k] = v 85 | 86 | self.seen_so_far = current 87 | 88 | now = time.time() 89 | if self.verbose == 1: 90 | prev_total_width = self.total_width 91 | sys.stdout.write("\b" * prev_total_width) 92 | sys.stdout.write("\r") 93 | 94 | numdigits = int(np.floor(np.log10(self.target))) + 1 95 | barstr = '%%%dd/%%%dd [' % (numdigits, numdigits) 96 | bar = barstr % (current, self.target) 97 | prog = float(current)/self.target 98 | prog_width = int(self.width*prog) 99 | if prog_width > 0: 100 | bar += ('='*(prog_width-1)) 101 | if current < self.target: 102 | bar += '>' 103 | else: 104 | bar += '=' 105 | bar += ('.'*(self.width-prog_width)) 106 | bar += ']' 107 | sys.stdout.write(bar) 108 | self.total_width = len(bar) 109 | 110 | if current: 111 | time_per_unit = (now - self.start) / current 112 | else: 113 | time_per_unit = 0 114 | eta = time_per_unit*(self.target - current) 115 | info = '' 116 | if current < self.target: 117 | info += ' - ETA: %ds' % eta 118 | else: 119 | info += ' - %ds' % (now - self.start) 120 | for k in self.unique_values: 121 | if type(self.sum_values[k]) is list: 122 | info += ' - %s: %.4f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1])) 123 | else: 124 | info += ' - %s: %s' % (k, self.sum_values[k]) 125 | 126 | self.total_width += len(info) 127 | if prev_total_width > self.total_width: 128 | info += ((prev_total_width-self.total_width) * " ") 129 | 130 | sys.stdout.write(info) 131 | sys.stdout.flush() 132 | 133 | if current >= self.target: 134 | sys.stdout.write("\n") 135 | 136 | if self.verbose == 2: 137 | if current >= self.target: 138 | info = '%ds' % (now - self.start) 139 | for k in self.unique_values: 140 | info += ' - %s: %.4f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1])) 141 | sys.stdout.write(info + "\n") 142 | 143 | def add(self, n, values=[]): 144 | self.update(self.seen_so_far+n, values) 145 | --------------------------------------------------------------------------------