├── README.md
├── finetune_wi_wiktionary
│   ├── finetune.py
│   ├── finetune.sh
│   ├── ftCollator.py
│   └── ftTrainer.py
├── finetune_wo_wiktionary
│   ├── finetune.py
│   ├── finetune.sh
│   ├── ftCollator.py
│   └── ftTrainer.py
├── preprocess_datasets
│   ├── get_description.py
│   ├── load_dataset.py
│   ├── load_preprocess.sh
│   ├── preprocess.py
│   └── select_word.py
└── preprocess_wiktionary
    ├── construct_wiktionary.py
    └── download_wiktionary.sh

/README.md:
--------------------------------------------------------------------------------
1 | 
2 | ## Dict-BERT: Enhancing Language Model Pre-training with Dictionary
3 | 
4 | ## Introduction
5 | 
6 | -- This is the PyTorch implementation of our [ACL 2022](https://www.2022.aclweb.org/) paper "*Dict-BERT: Enhancing Language Model Pre-training with Dictionary*" [\[PDF\]](https://arxiv.org/abs/2110.06490).
7 | In this paper, we propose DictBERT, a novel pre-trained language model that leverages rare word definitions from English dictionaries (e.g., Wiktionary). DictBERT is based on the BERT architecture and is trained under the same settings as BERT. Please refer to our paper for more details.
8 | 
9 | ## Install the packages
10 | 
11 | Python version >= 3.6
12 | 
13 | 
14 | ```
15 | transformers==4.7.0
16 | datasets==1.8.0
17 | torch==1.8.0
18 | ```
19 | 
20 | You also need to install `dataclasses`, `scipy`, `sklearn`, and `nltk`.
21 | 
22 | 
23 | 
24 | ## Preprocess the data
25 | 
26 | -- Download Wiktionary
27 | 
28 | ```bash
29 | cd preprocess_wiktionary
30 | bash download_wiktionary.sh
31 | ```
32 | 
33 | -- Download the GLUE benchmark
34 | ```bash
35 | cd preprocess_datasets
36 | bash load_preprocess.sh
37 | ```
38 | 
39 | ## Download the checkpoint
40 | 
41 | -- Hugging Face Hub [\[link\]](https://huggingface.co/wyu1/DictBERT) (a minimal loading sketch is given in the appendix at the end of this README)
42 | 
43 | ```
44 | git lfs install
45 | git clone https://huggingface.co/wyu1/DictBERT
46 | ```
47 | 
48 | ## Run experiments on GLUE
49 | 
50 | -- Without dictionary
51 | 
52 | ```bash
53 | cd finetune_wo_wiktionary
54 | bash finetune.sh
55 | ```
56 | 
57 | -- With dictionary
58 | 
59 | ```bash
60 | cd finetune_wi_wiktionary
61 | bash finetune.sh
62 | ```
63 | 
64 | 
65 | ## Citation
66 | 
67 | ```
68 | @inproceedings{yu2022dict,
69 |   title={Dict-BERT: Enhancing Language Model Pre-training with Dictionary},
70 |   author={Yu, Wenhao and Zhu, Chenguang and Fang, Yuwei and Yu, Donghan and Wang, Shuohang and Xu, Yichong and Zeng, Michael and Jiang, Meng},
71 |   booktitle={Findings of the Association for Computational Linguistics: ACL 2022},
72 |   pages={1907--1918},
73 |   year={2022}
74 | }
75 | ```
76 | 
77 | Please kindly cite our paper if you find it and the code helpful.
78 | 
79 | ## Acknowledgements
80 | 
81 | Many thanks to the GitHub repository of [Transformers](https://github.com/huggingface/transformers). Part of our code is adapted from their code.
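## Appendix: loading the checkpoint (sketch)

-- A minimal sketch (not part of the original scripts) of loading the released checkpoint with the same `BertTokenizer` / `BertForSequenceClassification` classes that `finetune.py` uses. It assumes the Hub repo `wyu1/DictBERT` ships tokenizer files alongside the weights; `num_labels=2` and the example sentence are purely illustrative.

```python
# Minimal sketch: load the released DictBERT checkpoint from the Hugging Face Hub.
# Assumes the Hub repo ships tokenizer files with the weights; num_labels and the
# example sentence are illustrative placeholders, not taken from the original scripts.
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("wyu1/DictBERT")
model = BertForSequenceClassification.from_pretrained("wyu1/DictBERT", num_labels=2)

inputs = tokenizer("The movie was surprisingly good.", return_tensors="pt")
logits = model(**inputs).logits
print(logits.shape)  # torch.Size([1, num_labels])
```

For the GLUE experiments themselves, use `finetune.sh` as shown above; it additionally passes the task-specific dictionary file (`vocab.90.json`) so that rare-word definitions are appended to the input during fine-tuning.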
82 | -------------------------------------------------------------------------------- /finetune_wi_wiktionary/finetune.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import json 3 | import random 4 | import logging 5 | from dataclasses import dataclass, field 6 | from collections import defaultdict 7 | from typing import Optional 8 | 9 | import numpy as np 10 | from datasets import load_dataset, load_metric 11 | 12 | from transformers import ( 13 | BertConfig, 14 | BertForSequenceClassification, 15 | BertTokenizer, 16 | 17 | EvalPrediction, 18 | HfArgumentParser, 19 | PretrainedConfig, 20 | TrainingArguments, 21 | set_seed, 22 | ) 23 | 24 | from transformers.trainer_utils import is_main_process 25 | from transformers.utils import check_min_version 26 | 27 | from ftTrainer import ftTrainer 28 | from ftCollator import ( 29 | default_data_collator, 30 | DataCollatorWithPadding, 31 | ) 32 | 33 | check_min_version("4.7.0.dev0") 34 | 35 | task_to_keys = { 36 | "cola": ("sentence", None), 37 | "mnli": ("premise", "hypothesis"), 38 | "mrpc": ("sentence1", "sentence2"), 39 | "qnli": ("question", "sentence"), 40 | "qqp": ("question1", "question2"), 41 | "rte": ("sentence1", "sentence2"), 42 | "sst2": ("sentence", None), 43 | "stsb": ("sentence1", "sentence2"), 44 | "wnli": ("sentence1", "sentence2"), 45 | } 46 | 47 | logger = logging.getLogger(__name__) 48 | logger.setLevel(logging.INFO) 49 | 50 | @dataclass 51 | class DataTrainingArguments: 52 | """ 53 | Arguments pertaining to what data we are going to input our model for training and eval. 54 | 55 | Using `HfArgumentParser` we can turn this class 56 | into argparse arguments to be able to specify them on 57 | the command line. 58 | """ 59 | 60 | task_name: Optional[str] = field( 61 | default=None, 62 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 63 | ) 64 | dataset_config_name: Optional[str] = field( 65 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 66 | ) 67 | max_seq_length: int = field( 68 | default=128, 69 | metadata={ 70 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 71 | "than this will be truncated, sequences shorter will be padded." 72 | }, 73 | ) 74 | overwrite_cache: bool = field( 75 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 76 | ) 77 | pad_to_max_length: bool = field( 78 | default=False, 79 | metadata={ 80 | "help": "Whether to pad all samples to `max_seq_length`. " 81 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 82 | }, 83 | ) 84 | max_train_ratios: Optional[float] = field( 85 | default=None, 86 | metadata={ 87 | "help": "For debugging purposes or quicker training, truncate the ratio of training examples to this " 88 | "value if set." 89 | }, 90 | ) 91 | max_train_samples: Optional[int] = field( 92 | default=None, 93 | metadata={ 94 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 95 | "value if set." 96 | }, 97 | ) 98 | max_eval_samples: Optional[int] = field( 99 | default=None, 100 | metadata={ 101 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 102 | "value if set." 
103 | }, 104 | ) 105 | max_predict_samples: Optional[int] = field( 106 | default=None, 107 | metadata={ 108 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 109 | "value if set." 110 | }, 111 | ) 112 | train_file: Optional[str] = field( 113 | default=None, metadata={"help": "A csv or a json file containing the training data."} 114 | ) 115 | validation_file: Optional[str] = field( 116 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 117 | ) 118 | test_file: Optional[str] = field( 119 | default=None, metadata={"help": "A csv or a json file containing the test data."}) 120 | dict_file: Optional[str] = field( 121 | default=None, metadata={"help": "A json file containing the dictionary entry with 'word, input_ids, ...'"}) 122 | 123 | def __post_init__(self): 124 | if self.task_name is not None: 125 | self.task_name = self.task_name.lower() 126 | # if self.task_name not in task_to_keys.keys(): 127 | # self.task_name = None 128 | elif self.train_file is None or self.validation_file is None: 129 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 130 | else: 131 | train_extension = self.train_file.split(".")[-1] 132 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 133 | validation_extension = self.validation_file.split(".")[-1] 134 | assert ( 135 | validation_extension == train_extension 136 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 137 | 138 | 139 | @dataclass 140 | class ModelArguments: 141 | """ 142 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 143 | """ 144 | 145 | model_name_or_path: str = field( 146 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 147 | ) 148 | config_name: Optional[str] = field( 149 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 150 | ) 151 | tokenizer_name: Optional[str] = field( 152 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 153 | ) 154 | cache_dir: Optional[str] = field( 155 | default=None, 156 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 157 | ) 158 | use_fast_tokenizer: bool = field( 159 | default=True, 160 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 161 | ) 162 | model_revision: str = field( 163 | default="main", 164 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 165 | ) 166 | use_auth_token: bool = field( 167 | default=False, 168 | metadata={ 169 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 170 | "with private models)." 171 | }, 172 | ) 173 | 174 | def tokenize_dataset(datasets, data_args, model, tokenizer, is_regression, num_labels, label_list): 175 | 176 | # Some models have set the order of the labels to use, so let's make sure we do use it. 177 | label_to_id = None 178 | if ( 179 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 180 | and data_args.task_name in task_to_keys.keys() 181 | and not is_regression 182 | ): 183 | # Some have all caps in their config, some don't. 
184 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 185 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 186 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} 187 | else: 188 | logger.warning( 189 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 190 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 191 | "\nIgnoring the model labels as a result.", 192 | ) 193 | elif data_args.task_name not in task_to_keys.keys() and not is_regression: 194 | label_to_id = {v: i for i, v in enumerate(label_list)} 195 | 196 | if data_args.max_seq_length > tokenizer.model_max_length: 197 | logger.warning( 198 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 199 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 200 | ) 201 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 202 | 203 | # Preprocessing the datasets 204 | if data_args.task_name in task_to_keys.keys(): 205 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 206 | else: 207 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 208 | non_label_column_names = [name for name in datasets["train"].column_names if name != "label"] 209 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 210 | sentence1_key, sentence2_key = "sentence1", "sentence2" 211 | else: 212 | sentence1_key, sentence2_key = 'text', None 213 | 214 | # Padding strategy 215 | if data_args.pad_to_max_length: 216 | padding = "max_length" 217 | else: 218 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 219 | padding = False 220 | 221 | use_dict = True if data_args.dict_file else False 222 | if use_dict: 223 | word2dict = defaultdict(dict) 224 | with open(data_args.dict_file, 'r') as f: 225 | for line in f.readlines(): 226 | line = json.loads(line) 227 | text = line['word'] + ' is ' + line['text'] + ' [SEP]' 228 | line['input_ids'] = tokenizer(text, add_special_tokens=False)['input_ids'] 229 | word2dict[line['word']] = line 230 | 231 | def post_process_with_dict(features): 232 | 233 | batch_input_ids, batch_token_type_ids = [], [] 234 | batch_rows, batch_columns, batch_segments = [], [], [] 235 | 236 | # for input_ids in features['input_ids']: 237 | for input_ids, token_type_ids in zip(features['input_ids'], features['token_type_ids']): 238 | # token_type_ids = [0] * len(input_ids) 239 | assert len(input_ids) <= data_args.max_seq_length 240 | if len(input_ids) == data_args.max_seq_length: 241 | batch_input_ids.append(input_ids) 242 | batch_token_type_ids.append(token_type_ids) 243 | batch_rows.append([]) 244 | batch_columns.append([]) 245 | batch_segments.append([1] * len(input_ids)) 246 | continue 247 | 248 | input_tokens = [tokenizer._convert_id_to_token(id) for id in input_ids] 249 | 250 | word_indexes = [] 251 | input_words = ' '.join(input_tokens).replace(' ##', '').split() 252 | for (i, token) in enumerate(input_tokens): 253 | if token.startswith("##"): 254 | word_indexes[-1].append(i) 255 | else: 256 | word_indexes.append([i]) 257 | 258 | assert len(input_words) == len(word_indexes) 259 | 260 | rows, columns = [], [] 261 | segments = [1] * len(token_type_ids) 262 | current_segment = 1 263 | 264 | word2seg = defaultdict(int) 265 | for index, word in 
zip(word_indexes, input_words): 266 | wordIndict = word2dict.get(word) 267 | if wordIndict: 268 | if word2seg.get(word): 269 | rows += index # already a list of index 270 | columns += [word2seg[word]] * len(index) 271 | 272 | else: # not in ... 273 | def_ids = wordIndict['input_ids'] 274 | if len(input_ids) + len(def_ids) >= data_args.max_seq_length: 275 | print(max_seq_length, 'exceed!') 276 | break # if exceed the max_length 277 | 278 | rows += index # already a list of index 279 | input_ids += def_ids 280 | 281 | current_segment += 1 282 | columns += [current_segment] * len(index) 283 | segments += [current_segment] * len(def_ids) 284 | token_type_ids += [1] * len(def_ids) 285 | 286 | word2seg[word] = current_segment 287 | 288 | assert len(input_ids) == len(token_type_ids) == len(segments) 289 | 290 | batch_input_ids.append(input_ids) 291 | batch_token_type_ids.append(token_type_ids) 292 | batch_rows.append(rows) 293 | batch_columns.append(columns) 294 | batch_segments.append(segments) 295 | 296 | return {'input_ids': batch_input_ids, 297 | 'rows': batch_rows, 298 | 'columns': batch_columns, 299 | 'segments': batch_segments, 300 | 'token_type_ids': batch_token_type_ids, 301 | } 302 | 303 | def preprocess_function(examples): 304 | 305 | # Tokenize the texts 306 | args = ((examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])) 307 | tokenized_text = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 308 | # Map labels to IDs (not necessary for GLUE tasks) 309 | 310 | if use_dict: 311 | tokenized_text = post_process_with_dict(tokenized_text) 312 | 313 | if label_to_id is not None and "label" in examples: 314 | tokenized_text["labels"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 315 | 316 | return tokenized_text 317 | 318 | datasets = datasets.map(preprocess_function, 319 | batched=True, 320 | load_from_cache_file=not data_args.overwrite_cache 321 | ) 322 | return datasets 323 | 324 | 325 | def main(): 326 | # See all possible arguments in src/transformers/training_args.py 327 | # or by passing the --help flag to this script. 328 | # We now keep distinct sets of args, for a cleaner separation of concerns. 329 | 330 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 331 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 332 | # If we pass only one argument to the script and it's the path to a json file, 333 | # let's parse it to get our arguments. 334 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 335 | else: 336 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 337 | 338 | # Setup logging 339 | logging.basicConfig( 340 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 341 | datefmt="%m/%d/%Y %H:%M:%S", 342 | handlers=[logging.StreamHandler(sys.stdout)], 343 | ) 344 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 345 | 346 | # Log on each process the small summary: 347 | logger.warning( 348 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 349 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 350 | ) 351 | 352 | logger.info(f"Training/evaluation parameters {training_args}") 353 | 354 | # Set seed before initializing model. 
355 | set_seed(training_args.seed) 356 | 357 | if data_args.task_name in task_to_keys.keys() and not data_args.train_file: 358 | # Downloading and loading a dataset from the hub. 359 | datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir) 360 | else: 361 | 362 | if data_args.task_name == 'mnli': 363 | validation_matched_file = data_args.validation_file.replace('validation', 'validation_matched') 364 | validation_mismatched_file = data_args.validation_file.replace('validation', 'validation_mismatched') 365 | data_files = {"train": data_args.train_file, 366 | "validation_matched": validation_matched_file, 367 | "validation_mismatched": validation_mismatched_file, 368 | } 369 | else: 370 | # Loading a dataset from your local files. 371 | # CSV/JSON training and evaluation files are needed. 372 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 373 | 374 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 375 | # when you use `do_predict` without specifying a GLUE benchmark task. 376 | if training_args.do_predict: 377 | if data_args.test_file is not None: 378 | train_extension = data_args.train_file.split(".")[-1] 379 | test_extension = data_args.test_file.split(".")[-1] 380 | assert ( 381 | test_extension == train_extension 382 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 383 | if data_args.task_name == 'mnli': 384 | data_files['test_matched'] = data_args.test_file.replace('test', 'test_matched') 385 | data_files['test_mismatched'] = data_args.test_file.replace('test', 'test_mismatched') 386 | else: 387 | data_files["test"] = data_args.test_file 388 | else: 389 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 390 | 391 | for key in data_files.keys(): 392 | logger.info(f"load a local file for {key}: {data_files[key]}") 393 | 394 | if data_args.train_file.endswith(".csv"): 395 | # Loading a dataset from local csv files 396 | datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir) 397 | else: 398 | # Loading a dataset from local json files 399 | datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir) 400 | # See more about loading any type of standard or custom dataset at 401 | # https://huggingface.co/docs/datasets/loading_datasets.html. 402 | 403 | # Labels 404 | label_list = None 405 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 406 | is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"] or data_args.task_name == "stsb" 407 | if is_regression: 408 | num_labels = 1 409 | else: 410 | # A useful fast method: 411 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 412 | label_list = datasets["train"].unique("label") 413 | label_list.sort() # Let's sort it for determinism 414 | num_labels = len(label_list) 415 | 416 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 417 | # download model & vocab. 
418 | config = BertConfig.from_pretrained( 419 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 420 | num_labels=num_labels, 421 | finetuning_task=data_args.task_name, 422 | cache_dir=model_args.cache_dir, 423 | revision=model_args.model_revision, 424 | use_auth_token=True if model_args.use_auth_token else None, 425 | ) 426 | tokenizer = BertTokenizer.from_pretrained( 427 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 428 | cache_dir=model_args.cache_dir, 429 | use_fast=model_args.use_fast_tokenizer, 430 | revision=model_args.model_revision, 431 | use_auth_token=True if model_args.use_auth_token else None, 432 | ) 433 | model = BertForSequenceClassification.from_pretrained( 434 | model_args.model_name_or_path, 435 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 436 | config=config, 437 | cache_dir=model_args.cache_dir, 438 | revision=model_args.model_revision, 439 | use_auth_token=True if model_args.use_auth_token else None, 440 | ) 441 | 442 | datasets = tokenize_dataset(datasets, data_args, model, tokenizer, 443 | is_regression, num_labels, label_list) 444 | logger.warning(datasets) 445 | 446 | if training_args.do_train: 447 | if "train" not in datasets: 448 | raise ValueError("--do_train requires a train dataset") 449 | train_dataset = datasets["train"] 450 | if data_args.max_train_ratios is not None: 451 | data_args.max_train_samples = int(len(train_dataset) * data_args.max_train_ratios) 452 | if data_args.max_train_samples is not None: 453 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 454 | 455 | if training_args.do_eval: 456 | if "validation" not in datasets and "validation_matched" not in datasets: 457 | raise ValueError("--do_eval requires a validation dataset") 458 | eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] 459 | if data_args.max_eval_samples is not None: 460 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 461 | 462 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: 463 | if "test" not in datasets and "test_matched" not in datasets: 464 | raise ValueError("--do_predict requires a test dataset") 465 | predict_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] 466 | if data_args.max_predict_samples is not None: 467 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 468 | 469 | # Log a few random samples from the training set: 470 | if training_args.do_train: 471 | for index in random.sample(range(len(train_dataset)), 1): 472 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 473 | 474 | # Get the metric function 475 | if data_args.task_name in task_to_keys.keys(): 476 | metric = load_metric("glue", data_args.task_name) 477 | else: 478 | metric_acc = load_metric("accuracy") 479 | metric_f1 = load_metric("f1") 480 | 481 | # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a 482 | # predictions and label_ids field) and has to return a dictionary string to float. 
483 | def compute_metrics(p: EvalPrediction): 484 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 485 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 486 | if data_args.task_name in task_to_keys.keys(): 487 | result = metric.compute(predictions=preds, references=p.label_ids) 488 | if len(result) > 1: 489 | result["combined_score"] = np.mean(list(result.values())).item() 490 | return result 491 | elif is_regression: 492 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 493 | else: 494 | result = metric_acc.compute(predictions=preds, references=p.label_ids) 495 | result['mi_f1'] = metric_f1.compute(predictions=preds, references=p.label_ids, average="micro")['f1'] 496 | result['ma_f1'] = metric_f1.compute(predictions=preds, references=p.label_ids, average="macro")['f1'] 497 | return result 498 | 499 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 500 | if data_args.pad_to_max_length: 501 | data_collator = default_data_collator 502 | elif training_args.fp16: 503 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 504 | else: 505 | data_collator = None 506 | 507 | training_args.task_name = data_args.task_name 508 | # Initialize our Trainer 509 | trainer = ftTrainer( 510 | model=model, 511 | args=training_args, 512 | dataset=datasets, 513 | train_dataset=train_dataset if training_args.do_train else None, 514 | eval_dataset=eval_dataset if training_args.do_eval else None, 515 | predict_dataset=predict_dataset if training_args.do_predict else None, 516 | label_list=label_list if training_args.do_predict else None, 517 | compute_metrics=compute_metrics, 518 | tokenizer=tokenizer, 519 | data_collator=data_collator, 520 | ) 521 | 522 | # Training 523 | if training_args.do_train: 524 | 525 | trainer.train() 526 | trainer.save_state() 527 | 528 | if training_args.do_predict: 529 | 530 | logger.info("*** Predict ***") 531 | 532 | # Loop to handle MNLI double evaluation (matched, mis-matched) 533 | tasks = [data_args.task_name] 534 | predict_datasets = [predict_dataset] 535 | if data_args.task_name == "mnli": 536 | tasks.append("mnli-mm") 537 | predict_datasets.append(datasets["test_mismatched"]) 538 | 539 | for predict_dataset, task in zip(predict_datasets, tasks): 540 | # Removing the `label` columns because it contains -1 and Trainer won't like that. 
541 |             predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions
542 |             predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
543 | 
544 |             output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.tsv")
545 |             if trainer.is_world_process_zero():
546 |                 with open(output_predict_file, "w") as writer:
547 |                     logger.info(f"***** Predict results {task} *****")
548 |                     writer.write("index\tprediction\n")
549 |                     for index, item in enumerate(predictions):
550 |                         if is_regression:
551 |                             writer.write(f"{index}\t{item:3.3f}\n")
552 |                         else:
553 |                             item = label_list[item]
554 |                             writer.write(f"{index}\t{item}\n")
555 | 
556 | 
557 | if __name__ == "__main__":
558 |     main()
559 | 
--------------------------------------------------------------------------------
/finetune_wi_wiktionary/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | export task=cola
4 | export input_dir=../glue_datasets
5 | export outfolder=output-${task}-dict
6 | # export model_name=bert-base-uncased
7 | export model_name=wyu1/DictBERT
8 | 
9 | 
10 | if [[ ${task} =~ $'cola' ]] || [[ ${task} =~ $'sst2' ]]
11 | then
12 |     epoch=10; max_length=128
13 | elif [[ ${task} =~ $'qqp' ]] || [[ ${task} =~ $'mnli' ]] || [[ ${task} =~ $'qnli' ]]
14 | then
15 |     epoch=5; max_length=128
16 | elif [[ ${task} =~ $'rte' ]]
17 | then
18 |     epoch=10; max_length=256
19 | elif [[ ${task} =~ $'mrpc' ]] || [[ ${task} =~ $'stsb' ]]
20 | then
21 |     epoch=5; max_length=256
22 | fi
23 | 
24 | python -u finetune.py \
25 |     --model_name_or_path $model_name \
26 |     --task_name $task \
27 |     --train_file ${input_dir}/${task}/train.prc.json \
28 |     --validation_file ${input_dir}/${task}/validation.prc.json \
29 |     --test_file ${input_dir}/${task}/test.prc.json \
30 |     --dict_file ${input_dir}/${task}/vocab.90.json \
31 |     --do_train \
32 |     --do_eval \
33 |     --do_predict \
34 |     --max_seq_length $max_length \
35 |     --per_device_train_batch_size 32 \
36 |     --per_device_eval_batch_size 32 \
37 |     --remove_unused_columns False \
38 |     --learning_rate 2e-5 \
39 |     --num_train_epochs $epoch \
40 |     --fp16 \
41 |     --fp16_opt_level O2 \
42 |     --output_dir $outfolder/${task}/
--------------------------------------------------------------------------------
/finetune_wi_wiktionary/ftCollator.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from dataclasses import dataclass
3 | from typing import Any, Callable, Dict, List, NewType, Optional, Union
4 | 
5 | import torch
6 | import torch.nn.functional as F
7 | from transformers.file_utils import PaddingStrategy
8 | from transformers.modeling_utils import PreTrainedModel
9 | from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
10 | 
11 | InputDataClass = NewType("InputDataClass", Any)
12 | DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
13 | 
14 | logger = logging.getLogger(__name__)
15 | logger.setLevel(logging.INFO)
16 | 
17 | @dataclass
18 | class DataCollatorWithPadding:
19 | 
20 |     tokenizer: PreTrainedTokenizerBase
21 |     padding: Union[bool, str, PaddingStrategy] = True
22 |     max_length: Optional[int] = None
23 |     pad_to_multiple_of: Optional[int] = None
24 | 
25 |     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]):
26 | 
27 |         assert isinstance(features[0], (dict, BatchEncoding))
28 | 
29 |         input_ids = [f['input_ids'] for f in
features] 30 | token_type_ids = [f['token_type_ids'] for f in features] 31 | segments = [f['segments'] for f in features] 32 | 33 | batch_input_ids = _collate_batch(input_ids, self.tokenizer) 34 | batch_token_type_ids = _collate_batch(token_type_ids, self.tokenizer, padding_token_id=0) 35 | batch_segments = _collate_batch(segments, self.tokenizer, padding_token_id=0) 36 | 37 | if features[0].get('attention_mask'): 38 | attention_mask = [f["attention_mask"] for f in features] 39 | batch_attention_mask = _collate_batch(attention_mask, self.tokenizer, padding_token_id=0) 40 | raise TypeError('Should not be used!') 41 | 42 | else: # No attention mask in training, since create it as follows 43 | segments = [f['segments'] for f in features] 44 | batch_segments = _collate_batch(segments, self.tokenizer, padding_token_id=0) 45 | 46 | nums, rows, columns = [], [], [] 47 | for idx, f in enumerate(features): 48 | rows += f['rows'] 49 | columns += f['columns'] 50 | nums += len(f['rows']) * [idx] 51 | 52 | batch_one_hot = F.one_hot(batch_segments) 53 | batch_one_hot_T = batch_one_hot.clone().transpose(1, 2) 54 | 55 | batch_one_hot[nums, rows, columns] = 1 56 | batch_attention_mask = batch_one_hot @ batch_one_hot_T 57 | 58 | if features[0].get('labels') is not None: 59 | batch_labels = torch.tensor([f['labels'] for f in features]) 60 | elif features[0].get('label_ids') is not None: 61 | batch_labels = torch.tensor([f['label_ids'] for f in features]) 62 | elif features[0].get('label') is not None: 63 | batch_labels = torch.tensor([f['label'] for f in features]) 64 | else: 65 | return {"input_ids": batch_input_ids, 66 | "attention_mask": batch_attention_mask, 67 | "token_type_ids": batch_token_type_ids, 68 | } 69 | 70 | return {"input_ids": batch_input_ids, 71 | "labels": batch_labels, 72 | "attention_mask": batch_attention_mask, 73 | "token_type_ids": batch_token_type_ids, 74 | } 75 | 76 | 77 | def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None, padding_token_id: Optional[int] = None): 78 | """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" 79 | # Tensorize if necessary. 80 | if isinstance(examples[0], (list, tuple)): 81 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 82 | 83 | # Check if padding is necessary. 84 | length_of_first = examples[0].size(0) 85 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 86 | if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): 87 | return torch.stack(examples, dim=0) 88 | 89 | # If yes, check if we have a `pad_token`. 90 | if tokenizer._pad_token is None: 91 | raise ValueError( 92 | "You are attempting to pad samples but the tokenizer you are using" 93 | f" ({tokenizer.__class__.__name__}) does not have a pad token." 94 | ) 95 | 96 | # Creating the full tensor and filling it with our data. 
97 | max_length = max(x.size(0) for x in examples) 98 | if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 99 | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of 100 | 101 | if padding_token_id is not None: 102 | result = examples[0].new_full([len(examples), max_length], padding_token_id) 103 | else: 104 | result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) 105 | 106 | for i, example in enumerate(examples): 107 | if tokenizer.padding_side == "right": 108 | result[i, : example.shape[0]] = example 109 | else: 110 | result[i, -example.shape[0] :] = example 111 | return result 112 | 113 | 114 | def tolist(x: Union[List[Any], torch.Tensor]): 115 | return x.tolist() if isinstance(x, torch.Tensor) else x 116 | 117 | 118 | @dataclass 119 | class DataCollatorForSeq2Seq: 120 | 121 | tokenizer: PreTrainedTokenizerBase 122 | model: Optional[PreTrainedModel] = None 123 | padding: Union[bool, str, PaddingStrategy] = True 124 | max_length: Optional[int] = None 125 | pad_to_multiple_of: Optional[int] = None 126 | label_pad_token_id: int = -100 127 | 128 | def __call__(self, features): 129 | labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None 130 | # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the 131 | # same length to return tensors. 132 | if labels is not None: 133 | max_label_length = max(len(l) for l in labels) 134 | padding_side = self.tokenizer.padding_side 135 | for feature in features: 136 | remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) 137 | feature["labels"] = ( 138 | feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] 139 | ) 140 | 141 | features = self.tokenizer.pad( 142 | features, 143 | padding=self.padding, 144 | max_length=self.max_length, 145 | pad_to_multiple_of=self.pad_to_multiple_of, 146 | return_tensors="pt", 147 | ) 148 | 149 | # prepare decoder_input_ids 150 | if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"): 151 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"]) 152 | features["decoder_input_ids"] = decoder_input_ids 153 | 154 | return features 155 | 156 | 157 | def default_data_collator(features: List[InputDataClass]): 158 | """ 159 | Very simple data collator that simply collates batches of dict-like objects and performs special handling for 160 | potential keys named: 161 | 162 | - ``label``: handles a single value (int or float) per object 163 | - ``label_ids``: handles a list of values per object 164 | 165 | Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs 166 | to the model. See glue and ner for example of how it's useful. 167 | """ 168 | 169 | # In this function we'll make the assumption that all `features` in the batch 170 | # have the same attributes. 171 | # So we will look at the first element as a proxy for what attributes exist 172 | # on the whole batch. 173 | if not isinstance(features[0], (dict, BatchEncoding)): 174 | features = [vars(f) for f in features] 175 | 176 | first = features[0] 177 | batch = {} 178 | 179 | # Special handling for labels. 180 | # Ensure that tensor is created with the correct type 181 | # (it should be automatically the case, but let's make sure of it.) 
182 | if "label" in first and first["label"] is not None: 183 | label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"] 184 | dtype = torch.long if isinstance(label, int) else torch.float 185 | batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) 186 | elif "label_ids" in first and first["label_ids"] is not None: 187 | if isinstance(first["label_ids"], torch.Tensor): 188 | batch["labels"] = torch.stack([f["label_ids"] for f in features]) 189 | else: 190 | dtype = torch.long if type(first["label_ids"][0]) is int else torch.float 191 | batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype) 192 | 193 | # Handling of all other possible keys. 194 | # Again, we will use the first element to figure out which key/values are not None for this model. 195 | for k, v in first.items(): 196 | if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): 197 | if isinstance(v, torch.Tensor): 198 | batch[k] = torch.stack([f[k] for f in features]) 199 | else: 200 | batch[k] = torch.tensor([f[k] for f in features]) 201 | 202 | return batch 203 | -------------------------------------------------------------------------------- /finetune_wi_wiktionary/ftTrainer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import logging 4 | import os 5 | import sys 6 | import time 7 | import warnings 8 | from logging import StreamHandler 9 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 10 | 11 | from tqdm.auto import tqdm 12 | import datetime 13 | 14 | from transformers.integrations import ( # isort: split 15 | hp_params, 16 | is_fairscale_available, 17 | get_reporting_integration_callbacks, 18 | ) 19 | 20 | from transformers.deepspeed import deepspeed_init 21 | 22 | import numpy as np 23 | import torch 24 | from packaging import version 25 | from torch import nn 26 | from torch.utils.data.dataloader import DataLoader 27 | from torch.utils.data.dataset import Dataset 28 | from torch.utils.data.distributed import DistributedSampler 29 | 30 | from transformers import __version__ 31 | from transformers import Trainer 32 | from transformers.configuration_utils import PretrainedConfig 33 | from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator 34 | from transformers.debug_utils import DebugOption, DebugUnderflowOverflow 35 | from transformers.dependency_versions_check import dep_version_check 36 | from transformers.file_utils import ( 37 | CONFIG_NAME, 38 | WEIGHTS_NAME, 39 | is_apex_available, 40 | is_sagemaker_dp_enabled, 41 | is_sagemaker_mp_enabled, 42 | is_training_run_on_sagemaker, 43 | ) 44 | from transformers.modeling_utils import PreTrainedModel 45 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 46 | from transformers.trainer_callback import ( 47 | CallbackHandler, 48 | DefaultFlowCallback, 49 | PrinterCallback, 50 | ProgressCallback, 51 | TrainerCallback, 52 | TrainerControl, 53 | TrainerState, 54 | ) 55 | 56 | from transformers.trainer_pt_utils import ( 57 | IterableDatasetShard, 58 | LabelSmoother, 59 | ) 60 | from transformers.trainer_utils import ( 61 | EvalPrediction, 62 | ShardedDDPOption, 63 | TrainerMemoryTracker, 64 | TrainOutput, 65 | get_last_checkpoint, 66 | set_seed, 67 | speed_metrics, 68 | ) 69 | from transformers.training_args import TrainingArguments 70 | from transformers.utils.modeling_auto_mapping import 
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES 71 | 72 | _is_native_amp_available = False 73 | 74 | DEFAULT_CALLBACKS = [DefaultFlowCallback] 75 | DEFAULT_PROGRESS_CALLBACK = ProgressCallback 76 | 77 | if is_apex_available(): 78 | from apex import amp 79 | 80 | if version.parse(torch.__version__) >= version.parse("1.6"): 81 | _is_native_amp_available = True 82 | from torch.cuda.amp import autocast 83 | 84 | if is_fairscale_available(): 85 | dep_version_check("fairscale") 86 | import fairscale 87 | from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP 88 | from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP 89 | from fairscale.optim.grad_scaler import ShardedGradScaler 90 | 91 | if is_sagemaker_dp_enabled(): 92 | import smdistributed.dataparallel.torch.distributed as dist 93 | from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP 94 | else: 95 | import torch.distributed as dist 96 | 97 | if is_sagemaker_mp_enabled(): 98 | import smdistributed.modelparallel.torch as smp 99 | 100 | from .trainer_pt_utils import smp_forward_backward 101 | 102 | if is_training_run_on_sagemaker(): 103 | logging.add_handler(StreamHandler(sys.stdout)) 104 | 105 | 106 | if TYPE_CHECKING: 107 | import optuna 108 | 109 | glue_tasks = ["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"] 110 | 111 | task_to_metrics = { 112 | # GLEU benchmark 113 | "cola": "matthews_correlation", 114 | "mnli": "accuracy", 115 | "mnli-mm": "accuracy", 116 | "mrpc": "f1", 117 | "qnli": "accuracy", 118 | "qqp": "f1", 119 | "rte": "accuracy", 120 | "sst2": "accuracy", 121 | "stsb": "spearmanr", 122 | "wnli": "accuracy", 123 | } 124 | 125 | logger = logging.getLogger(__name__) 126 | logger.setLevel(logging.INFO) 127 | 128 | class ftTrainer(Trainer): 129 | 130 | from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state 131 | 132 | def __init__( 133 | self, 134 | model: Union[PreTrainedModel, torch.nn.Module] = None, 135 | args: TrainingArguments = None, 136 | data_collator: Optional[DataCollator] = None, 137 | dataset: Optional[Dataset] = None, 138 | train_dataset: Optional[Dataset] = None, 139 | eval_dataset: Optional[Dataset] = None, 140 | predict_dataset: Optional[Dataset] = None, 141 | label_list: Optional[Dataset] = None, 142 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 143 | model_init: Callable[[], PreTrainedModel] = None, 144 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 145 | callbacks: Optional[List[TrainerCallback]] = None, 146 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), 147 | ): 148 | if args is None: 149 | output_dir = "tmp_trainer" 150 | logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") 151 | args = TrainingArguments(output_dir=output_dir) 152 | self.args = args 153 | # Seed must be set before instantiating the model when using model 154 | set_seed(self.args.seed) 155 | self.hp_name = None 156 | self.deepspeed = None 157 | self.is_in_train = False 158 | 159 | # memory metrics - must set up as early as possible 160 | self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) 161 | self._memory_tracker.start() 162 | 163 | # force device and distributed setup init explicitly 164 | args._setup_devices 165 | 166 | if model is None: 167 | if model_init is not None: 168 | self.model_init = model_init 169 | model = self.call_model_init() 170 | 
else: 171 | raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") 172 | else: 173 | if model_init is not None: 174 | warnings.warn( 175 | "`Trainer` requires either a `model` or `model_init` argument, but not both. " 176 | "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.", 177 | FutureWarning, 178 | ) 179 | self.model_init = model_init 180 | 181 | if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: 182 | self.is_model_parallel = True 183 | else: 184 | self.is_model_parallel = False 185 | 186 | # Setup Sharded DDP training 187 | self.sharded_ddp = None 188 | if len(args.sharded_ddp) > 0: 189 | if args.deepspeed: 190 | raise ValueError( 191 | "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." 192 | ) 193 | 194 | if args.local_rank == -1: 195 | raise ValueError("Using sharded DDP only works in distributed training.") 196 | elif not is_fairscale_available(): 197 | raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") 198 | elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: 199 | raise ImportError( 200 | "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " 201 | f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." 202 | ) 203 | elif ShardedDDPOption.SIMPLE in args.sharded_ddp: 204 | self.sharded_ddp = ShardedDDPOption.SIMPLE 205 | elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: 206 | self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 207 | elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: 208 | self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 209 | 210 | # one place to sort out whether to place the model on device or not 211 | # postpone switching model to cuda when: 212 | # 1. MP - since we are trying to fit a much bigger than 1 gpu model 213 | # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, 214 | # and we only use deepspeed for training at the moment 215 | # 3. full fp16 eval - since the model needs to be half'ed first 216 | # 4. 
Sharded DDP - same as MP 217 | self.place_model_on_device = args.place_model_on_device 218 | if ( 219 | self.is_model_parallel 220 | or args.deepspeed 221 | or (args.fp16_full_eval and not args.do_train) 222 | or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) 223 | ): 224 | self.place_model_on_device = False 225 | 226 | default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) 227 | self.data_collator = data_collator if data_collator is not None else default_collator 228 | self.dataset = dataset 229 | self.train_dataset = train_dataset 230 | self.eval_dataset = eval_dataset 231 | self.predict_dataset = predict_dataset 232 | self.label_list = label_list 233 | self.tokenizer = tokenizer 234 | 235 | if self.place_model_on_device: 236 | model = model.to(args.device) 237 | 238 | # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs 239 | if self.is_model_parallel: 240 | self.args._n_gpu = 1 241 | 242 | # later use `self.model is self.model_wrapped` to check if it's wrapped or not 243 | self.model_wrapped = model 244 | self.model = model 245 | 246 | self.compute_metrics = compute_metrics 247 | self.optimizer, self.lr_scheduler = optimizers 248 | if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): 249 | raise RuntimeError( 250 | "Passing a `model_init` is incompatible with providing the `optimizers` argument." 251 | "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." 252 | ) 253 | default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) 254 | callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks 255 | self.callback_handler = CallbackHandler( 256 | callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler 257 | ) 258 | self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) 259 | 260 | # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. 
261 | self._loggers_initialized = False 262 | 263 | # Create output directory if needed 264 | if self.is_world_process_zero(): 265 | os.makedirs(self.args.output_dir, exist_ok=True) 266 | if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): 267 | raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") 268 | 269 | if args.max_steps > 0: 270 | logger.info("max_steps is given, it will override any value given in num_train_epochs") 271 | 272 | if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0: 273 | raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") 274 | 275 | self._signature_columns = None 276 | 277 | # Mixed precision setup 278 | self.use_apex = False 279 | self.use_amp = False 280 | self.fp16_backend = None 281 | 282 | if args.fp16: 283 | if args.fp16_backend == "auto": 284 | self.fp16_backend = "amp" if _is_native_amp_available else "apex" 285 | else: 286 | self.fp16_backend = args.fp16_backend 287 | logger.info(f"Using {self.fp16_backend} fp16 backend") 288 | 289 | if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16 290 | if self.fp16_backend == "amp": 291 | self.use_amp = True 292 | if is_sagemaker_mp_enabled(): 293 | self.scaler = smp.amp.GradScaler() 294 | elif self.sharded_ddp is not None: 295 | self.scaler = ShardedGradScaler() 296 | else: 297 | self.scaler = torch.cuda.amp.GradScaler() 298 | else: 299 | if not is_apex_available(): 300 | raise ImportError( 301 | "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex." 302 | ) 303 | self.use_apex = True 304 | 305 | # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. 306 | if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0: 307 | raise ValueError( 308 | "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " 309 | "along 'max_grad_norm': 0 in your hyperparameters." 
310 | ) 311 | 312 | # Label smoothing 313 | if self.args.label_smoothing_factor != 0: 314 | self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) 315 | else: 316 | self.label_smoother = None 317 | 318 | self.state = TrainerState() 319 | self.control = TrainerControl() 320 | # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then 321 | # returned to 0 every time flos need to be logged 322 | self.current_flos = 0 323 | self.hp_search_backend = None 324 | self.use_tune_checkpoints = False 325 | default_label_names = ( 326 | ["start_positions", "end_positions"] 327 | if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values() 328 | else ["labels"] 329 | ) 330 | self.label_names = default_label_names if self.args.label_names is None else self.args.label_names 331 | self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) 332 | 333 | # very last 334 | self._memory_tracker.stop_and_update_metrics() 335 | 336 | 337 | def train( 338 | self, resume_from_checkpoint: Optional[Union[str, bool]] = None, 339 | trial: Union["optuna.Trial", Dict[str, Any]] = None, **kwargs, 340 | ): 341 | 342 | # memory metrics - must set up as early as possible 343 | self._memory_tracker.start() 344 | 345 | args = self.args 346 | 347 | self.is_in_train = True 348 | 349 | # do_train is not a reliable argument, as it might not be set and .train() still called, so 350 | # the following is a workaround: 351 | if args.fp16_full_eval and not args.do_train: 352 | self.model = self.model.to(args.device) 353 | 354 | if "model_path" in kwargs: 355 | resume_from_checkpoint = kwargs.pop("model_path") 356 | warnings.warn( 357 | "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " 358 | "instead.", 359 | FutureWarning, 360 | ) 361 | if len(kwargs) > 0: 362 | raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") 363 | # This might change the seed so needs to run first. 364 | self._hp_search_setup(trial) 365 | 366 | # Model re-init 367 | model_reloaded = False 368 | if self.model_init is not None: 369 | # Seed must be set before instantiating the model when using model_init. 
370 | set_seed(args.seed) 371 | self.model = self.call_model_init(trial) 372 | model_reloaded = True 373 | # Reinitializes optimizer and scheduler 374 | self.optimizer, self.lr_scheduler = None, None 375 | 376 | # Load potential model checkpoint 377 | if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: 378 | resume_from_checkpoint = get_last_checkpoint(args.output_dir) 379 | if resume_from_checkpoint is None: 380 | raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") 381 | 382 | if resume_from_checkpoint is not None: 383 | if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): 384 | raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") 385 | 386 | logger.info(f"Loading model from {resume_from_checkpoint}).") 387 | 388 | if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)): 389 | config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME)) 390 | checkpoint_version = config.transformers_version 391 | if checkpoint_version is not None and checkpoint_version != __version__: 392 | logger.warn( 393 | f"You are resuming training from a checkpoint trained with {checkpoint_version} of " 394 | f"Transformers but your current version is {__version__}. This is not recommended and could " 395 | "yield to errors or unwanted behaviors." 396 | ) 397 | 398 | if args.deepspeed: 399 | # will be resumed in deepspeed_init 400 | pass 401 | else: 402 | # We load the model state dict on the CPU to avoid an OOM error. 403 | state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") 404 | # If the model is on the GPU, it still works! 405 | self._load_state_dict_in_model(state_dict) 406 | 407 | # If model was re-initialized, put it on the right device and update self.model_wrapped 408 | if model_reloaded: 409 | if self.place_model_on_device: 410 | self.model = self.model.to(args.device) 411 | self.model_wrapped = self.model 412 | 413 | # Keeping track whether we can can len() on the dataset or not 414 | train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized) 415 | 416 | # Data loader and number of training steps 417 | train_dataloader = self.get_train_dataloader() 418 | 419 | # Setting up training control variables: 420 | # number of training epochs: num_train_epochs 421 | # number of training steps per epoch: num_update_steps_per_epoch 422 | # total number of training steps to execute: max_steps 423 | if train_dataset_is_sized: 424 | num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps 425 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) 426 | if args.max_steps > 0: 427 | max_steps = args.max_steps 428 | num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( 429 | args.max_steps % num_update_steps_per_epoch > 0 430 | ) 431 | else: 432 | max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) 433 | num_train_epochs = math.ceil(args.num_train_epochs) 434 | else: 435 | # see __init__. 
max_steps is set when the dataset has no __len__ 436 | max_steps = args.max_steps 437 | num_train_epochs = int(args.num_train_epochs) 438 | num_update_steps_per_epoch = max_steps 439 | 440 | if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: 441 | debug_overflow = DebugUnderflowOverflow(self.model) # noqa 442 | 443 | delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE 444 | if args.deepspeed: 445 | deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( 446 | self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint 447 | ) 448 | self.model = deepspeed_engine.module 449 | self.model_wrapped = deepspeed_engine 450 | self.deepspeed = deepspeed_engine 451 | self.optimizer = optimizer 452 | self.lr_scheduler = lr_scheduler 453 | elif not delay_optimizer_creation: 454 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 455 | 456 | self.state = TrainerState() 457 | self.state.is_hyper_param_search = trial is not None 458 | 459 | model = self._wrap_model(self.model_wrapped) 460 | 461 | # for the rest of this function `model` is the outside model, whether it was wrapped or not 462 | if model is not self.model: 463 | self.model_wrapped = model 464 | 465 | if delay_optimizer_creation: 466 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 467 | 468 | # Check if saved optimizer or scheduler states exist 469 | self._load_optimizer_and_scheduler(resume_from_checkpoint) 470 | 471 | # Train! 472 | if args.local_rank != -1: 473 | world_size = dist.get_world_size() 474 | else: 475 | world_size = 1 476 | 477 | total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size 478 | num_examples = ( 479 | self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps 480 | ) 481 | 482 | logger.info("***** Running training *****") 483 | logger.info(f" Num examples = {num_examples}") 484 | logger.info(f" Num Epochs = {num_train_epochs}") 485 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 486 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") 487 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 488 | logger.info(f" Total optimization steps = {max_steps}") 489 | 490 | self.state.epoch = 0 491 | start_time = time.time() 492 | epochs_trained = 0 493 | steps_trained_in_current_epoch = 0 494 | steps_trained_progress_bar = None 495 | 496 | # Check if continuing training from a checkpoint 497 | if resume_from_checkpoint is not None and os.path.isfile( 498 | os.path.join(resume_from_checkpoint, "trainer_state.json") 499 | ): 500 | self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json")) 501 | epochs_trained = self.state.global_step // num_update_steps_per_epoch 502 | if not args.ignore_data_skip: 503 | steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) 504 | steps_trained_in_current_epoch *= args.gradient_accumulation_steps 505 | else: 506 | steps_trained_in_current_epoch = 0 507 | 508 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 509 | logger.info(f" Continuing training from epoch {epochs_trained}") 510 | logger.info(f" Continuing training from global step {self.state.global_step}") 511 | if not args.ignore_data_skip: 512 | logger.info( 513 | f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " 514 | "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " 515 | "flag to your launch command, but you will resume the training on data already seen by your model." 516 | ) 517 | if self.is_local_process_zero() and not args.disable_tqdm: 518 | steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) 519 | steps_trained_progress_bar.set_description("Skipping the first batches") 520 | 521 | # Update the references 522 | self.callback_handler.model = self.model 523 | self.callback_handler.optimizer = self.optimizer 524 | self.callback_handler.lr_scheduler = self.lr_scheduler 525 | self.callback_handler.train_dataloader = train_dataloader 526 | self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None 527 | self.state.trial_params = hp_params(trial) if trial is not None else None 528 | # This should be the same if the state has been saved but in case the training arguments changed, it's safer 529 | # to set this after the load. 530 | self.state.max_steps = max_steps 531 | self.state.num_train_epochs = num_train_epochs 532 | self.state.is_local_process_zero = self.is_local_process_zero() 533 | self.state.is_world_process_zero = self.is_world_process_zero() 534 | 535 | # tr_loss is a tensor to avoid synchronization of TPUs through .item() 536 | tr_loss = torch.tensor(0.0).to(args.device) 537 | # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses 538 | self._total_loss_scalar = 0.0 539 | self._globalstep_last_logged = self.state.global_step 540 | model.zero_grad() 541 | 542 | self.control = self.callback_handler.on_train_begin(args, self.state, self.control) 543 | 544 | # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. 545 | if not args.ignore_data_skip: 546 | for epoch in range(epochs_trained): 547 | # We just need to begin an iteration to create the randomization of the sampler. 
548 | for _ in train_dataloader: 549 | break 550 | 551 | best_metric, best_epoch = 0, 0 552 | for epoch in range(epochs_trained, num_train_epochs): 553 | if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): 554 | train_dataloader.sampler.set_epoch(epoch) 555 | elif isinstance(train_dataloader.dataset, IterableDatasetShard): 556 | train_dataloader.dataset.set_epoch(epoch) 557 | 558 | epoch_iterator = train_dataloader 559 | 560 | # Reset the past mems state at the beginning of each epoch if necessary. 561 | if args.past_index >= 0: 562 | self._past = None 563 | 564 | steps_in_epoch = ( 565 | len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps 566 | ) 567 | self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) 568 | 569 | for step, inputs in enumerate(epoch_iterator): 570 | 571 | starttime = datetime.datetime.now() 572 | # Skip past any already trained steps if resuming training 573 | if steps_trained_in_current_epoch > 0: 574 | steps_trained_in_current_epoch -= 1 575 | if steps_trained_progress_bar is not None: 576 | steps_trained_progress_bar.update(1) 577 | if steps_trained_in_current_epoch == 0: 578 | self._load_rng_state(resume_from_checkpoint) 579 | continue 580 | elif steps_trained_progress_bar is not None: 581 | steps_trained_progress_bar.close() 582 | steps_trained_progress_bar = None 583 | 584 | if step % args.gradient_accumulation_steps == 0: 585 | self.control = self.callback_handler.on_step_begin(args, self.state, self.control) 586 | 587 | if ( 588 | ((step + 1) % args.gradient_accumulation_steps != 0) 589 | and args.local_rank != -1 590 | and args._no_sync_in_gradient_accumulation 591 | ): 592 | # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. 
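# Under DDP with gradient accumulation, `model.no_sync()` suppresses the gradient all-reduce
# for the intermediate micro-batches; gradients are only synchronized on the accumulation
# step that is actually followed by `optimizer.step()`.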
593 | with model.no_sync(): 594 | tr_loss += self.training_step(model, inputs) 595 | else: 596 | tr_loss += self.training_step(model, inputs) 597 | self.current_flos += float(self.floating_point_ops(inputs)) 598 | 599 | # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps 600 | if self.deepspeed: 601 | self.deepspeed.step() 602 | 603 | if (step + 1) % args.gradient_accumulation_steps == 0 or ( 604 | # last step in epoch but step is always smaller than gradient_accumulation_steps 605 | steps_in_epoch <= args.gradient_accumulation_steps 606 | and (step + 1) == steps_in_epoch 607 | ): 608 | # Gradient clipping 609 | if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: 610 | # deepspeed does its own clipping 611 | 612 | if self.use_amp: 613 | # AMP: gradients need unscaling 614 | self.scaler.unscale_(self.optimizer) 615 | 616 | if hasattr(self.optimizer, "clip_grad_norm"): 617 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 618 | self.optimizer.clip_grad_norm(args.max_grad_norm) 619 | elif hasattr(model, "clip_grad_norm_"): 620 | # Some models (like FullyShardedDDP) have a specific way to do gradient clipping 621 | model.clip_grad_norm_(args.max_grad_norm) 622 | else: 623 | # Revert to normal clipping otherwise, handling Apex or full precision 624 | torch.nn.utils.clip_grad_norm_( 625 | amp.master_params(self.optimizer) if self.use_apex else model.parameters(), 626 | args.max_grad_norm, 627 | ) 628 | 629 | # Optimizer step 630 | optimizer_was_run = True 631 | if self.deepspeed: 632 | pass # called outside the loop 633 | elif self.use_amp: 634 | scale_before = self.scaler.get_scale() 635 | self.scaler.step(self.optimizer) 636 | self.scaler.update() 637 | scale_after = self.scaler.get_scale() 638 | optimizer_was_run = scale_before <= scale_after 639 | else: 640 | self.optimizer.step() 641 | 642 | if optimizer_was_run and not self.deepspeed: 643 | self.lr_scheduler.step() 644 | 645 | model.zero_grad() 646 | self.state.global_step += 1 647 | self.state.epoch = epoch + (step + 1) / steps_in_epoch 648 | 649 | self.control = self.callback_handler.on_step_end(args, self.state, self.control) 650 | self._maybe_log_save_evaluate(tr_loss, model, trial, epoch) 651 | 652 | if self.control.should_epoch_stop or self.control.should_training_stop: 653 | break 654 | 655 | # TODO eval after epoch 656 | # self._maybe_log_save_evaluate(tr_loss, model, trial, epoch) 657 | epoch_metric = self._evaluate_after_each_epoch(tr_loss, model, trial, epoch+1) 658 | 659 | if epoch_metric >= best_metric: 660 | best_metric = epoch_metric 661 | best_epoch = epoch + 1 662 | 663 | self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) 664 | if self.control.should_training_stop: 665 | break 666 | 667 | if args.past_index and hasattr(self, "_past"): 668 | # Clean the state at the end of training 669 | delattr(self, "_past") 670 | 671 | self.state.best_metric = {task_to_metrics[self.args.task_name]: best_metric} 672 | self.state.best_model_checkpoint = {'epoch': best_epoch} 673 | 674 | metrics = speed_metrics("train", start_time, self.state.max_steps) 675 | self.store_flos() 676 | metrics["total_flos"] = self.state.total_flos 677 | self.log(metrics) 678 | 679 | self.control = self.callback_handler.on_train_end(args, self.state, self.control) 680 | self._total_loss_scalar += tr_loss.item() 681 | 682 | self.is_in_train = False 683 | 
self._memory_tracker.stop_and_update_metrics(metrics) 684 | 685 | return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics) 686 | 687 | def _evaluate_after_each_epoch(self, tr_loss, model, trial, epoch): 688 | 689 | ''' logging training information ''' 690 | logs: Dict[str, float] = {} 691 | tr_loss_scalar = tr_loss.item() 692 | tr_loss -= tr_loss 693 | logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) 694 | logs["learning_rate"] = self._get_learning_rate() 695 | self._total_loss_scalar += tr_loss_scalar 696 | self._globalstep_last_logged = self.state.global_step 697 | self.log(logs) 698 | 699 | ''' Evaluating and logging validation information ''' 700 | tasks = [self.args.task_name] 701 | eval_datasets = [self.eval_dataset] 702 | if self.args.task_name == "mnli": 703 | tasks.append("mnli-mm") 704 | eval_datasets.append(self.dataset["validation_mismatched"]) 705 | 706 | for eval_dataset, task in zip(eval_datasets, tasks): 707 | eval_metrics = self.evaluate(eval_dataset=eval_dataset, metric_key_prefix='eval') 708 | 709 | self.log_metrics(f"Epoch{epoch} {task} Dev", metrics=eval_metrics) 710 | self.save_metrics(f"epoch{epoch}_{task}_dev", metrics=eval_metrics, combined=False) 711 | 712 | if self.args.do_predict: 713 | # self._save_checkpoint(model, trial, metrics=metrics) 714 | if self.args.task_name not in glue_tasks: 715 | ''' Evaluating and logging testing information ''' 716 | tasks = [self.args.task_name] 717 | predict_datasets = [self.predict_dataset] 718 | if self.args.task_name == "mnli": 719 | tasks.append("mnli-mm") 720 | predict_datasets.append(self.dataset["validation_mismatched"]) 721 | 722 | for predict_dataset, task in zip(predict_datasets, tasks): 723 | test_metrics = self.evaluate(eval_dataset=predict_dataset, metric_key_prefix='test') 724 | 725 | self.log_metrics(f"Epoch{epoch} {task} Test", metrics=test_metrics) 726 | self.save_metrics(f"epoch{epoch}_{task}_test", metrics=test_metrics, combined=False) 727 | else: 728 | tasks = [self.args.task_name] 729 | predict_datasets = [self.predict_dataset] 730 | if self.args.task_name == "mnli": 731 | tasks.append("mnli-mm") 732 | predict_datasets.append(self.dataset["test_mismatched"]) 733 | 734 | is_regression = self.args.task_name == "stsb" 735 | 736 | for predict_dataset, task in zip(predict_datasets, tasks): 737 | 738 | try: # Removing the `label` columns because it contains -1 and Trainer won't like that. 739 | predict_dataset.remove_columns_("label") 740 | except: # Except the keyerror happens when column "label" has already been removed. 
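# The bare except above swallows the error raised when the `label` column was already
# dropped on a previous epoch (datasets raises a ValueError/KeyError for a missing column).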
741 | pass 742 | 743 | predictions = self.predict(predict_dataset, metric_key_prefix="predict").predictions 744 | predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) 745 | 746 | output_predict_file = os.path.join(self.args.output_dir, f"epoch{epoch}_{task}_test_results.tsv") 747 | 748 | with open(output_predict_file, "w") as writer: 749 | logger.info(f"***** Predict results {task} *****") 750 | writer.write("index\tprediction\n") 751 | for index, item in enumerate(predictions): 752 | if is_regression: 753 | writer.write(f"{index}\t{item:3.3f}\n") 754 | elif task in ["mnli", "mnli-mm", "qnli", "rte"]: 755 | item = self.label_list[item] 756 | writer.write(f"{index}\t{item}\n") 757 | else: 758 | writer.write(f"{index}\t{item}\n") 759 | 760 | primary_metric = task_to_metrics[self.args.task_name] 761 | if self.args.task_name in glue_tasks: 762 | return eval_metrics['eval_' + primary_metric] 763 | else: 764 | return test_metrics['test_' + primary_metric] 765 | 766 | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: 767 | 768 | model.train() 769 | inputs = self._prepare_inputs(inputs) 770 | 771 | if is_sagemaker_mp_enabled(): 772 | scaler = self.scaler if self.use_amp else None 773 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler) 774 | return loss_mb.reduce_mean().detach().to(self.args.device) 775 | 776 | if self.use_amp: 777 | with autocast(): 778 | loss = self.compute_loss(model, inputs) 779 | else: 780 | loss = self.compute_loss(model, inputs) 781 | 782 | if self.args.n_gpu > 1: 783 | loss = loss.mean() # mean() to average on multi-gpu parallel training 784 | 785 | if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: 786 | # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` 787 | loss = loss / self.args.gradient_accumulation_steps 788 | 789 | if self.use_amp: 790 | self.scaler.scale(loss).backward() 791 | elif self.use_apex: 792 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 793 | scaled_loss.backward() 794 | elif self.deepspeed: 795 | # loss gets scaled under gradient_accumulation_steps in deepspeed 796 | loss = self.deepspeed.backward(loss) 797 | else: 798 | loss.backward() 799 | 800 | return loss.detach() 801 | 802 | def compute_loss(self, model, inputs, return_outputs=False): 803 | 804 | if self.label_smoother is not None and "labels" in inputs: 805 | labels = inputs.pop("labels") 806 | else: 807 | labels = None 808 | 809 | outputs = model(**inputs) 810 | # Save past state if it exists 811 | # TODO: this needs to be fixed and made cleaner later. 812 | if self.args.past_index >= 0: 813 | self._past = outputs[self.args.past_index] 814 | 815 | if labels is not None: 816 | loss = self.label_smoother(outputs, labels) 817 | else: 818 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 
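# When labels stay in `inputs`, the model computes the loss itself; indexing by the "loss"
# key (ModelOutput/dict) or by position 0 (plain tuple) handles either return format.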
819 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 820 | 821 | return (loss, outputs) if return_outputs else loss 822 | -------------------------------------------------------------------------------- /finetune_wo_wiktionary/finetune.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | import sys 5 | from dataclasses import dataclass, field 6 | from typing import Optional 7 | 8 | import numpy as np 9 | from datasets import load_dataset, load_metric 10 | 11 | import transformers 12 | from transformers import ( 13 | AutoConfig, 14 | AutoModelForSequenceClassification, 15 | AutoTokenizer, 16 | 17 | EvalPrediction, 18 | HfArgumentParser, 19 | PretrainedConfig, 20 | TrainingArguments, 21 | set_seed, 22 | ) 23 | 24 | from transformers.trainer_utils import is_main_process 25 | from transformers.utils import check_min_version 26 | 27 | from ftTrainer import ftTrainer 28 | from ftCollator import ( 29 | default_data_collator, 30 | DataCollatorWithPadding, 31 | ) 32 | 33 | check_min_version("4.7.0.dev0") 34 | 35 | task_to_keys = { 36 | "cola": ("sentence", None), 37 | "mnli": ("premise", "hypothesis"), 38 | "mrpc": ("sentence1", "sentence2"), 39 | "qnli": ("question", "sentence"), 40 | "qqp": ("question1", "question2"), 41 | "rte": ("sentence1", "sentence2"), 42 | "sst2": ("sentence", None), 43 | "stsb": ("sentence1", "sentence2"), 44 | "wnli": ("sentence1", "sentence2"), 45 | } 46 | 47 | logger = logging.getLogger(__name__) 48 | logger.setLevel(logging.INFO) 49 | 50 | @dataclass 51 | class DataTrainingArguments: 52 | """ 53 | Arguments pertaining to what data we are going to input our model for training and eval. 54 | 55 | Using `HfArgumentParser` we can turn this class 56 | into argparse arguments to be able to specify them on 57 | the command line. 58 | """ 59 | 60 | task_name: Optional[str] = field( 61 | default=None, 62 | metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, 63 | ) 64 | dataset_config_name: Optional[str] = field( 65 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 66 | ) 67 | max_seq_length: int = field( 68 | default=128, 69 | metadata={ 70 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 71 | "than this will be truncated, sequences shorter will be padded." 72 | }, 73 | ) 74 | overwrite_cache: bool = field( 75 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 76 | ) 77 | pad_to_max_length: bool = field( 78 | default=False, 79 | metadata={ 80 | "help": "Whether to pad all samples to `max_seq_length`. " 81 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 82 | }, 83 | ) 84 | max_train_ratios: Optional[float] = field( 85 | default=None, 86 | metadata={ 87 | "help": "For debugging purposes or quicker training, truncate the ratio of training examples to this " 88 | "value if set." 89 | }, 90 | ) 91 | max_train_samples: Optional[int] = field( 92 | default=None, 93 | metadata={ 94 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 95 | "value if set." 96 | }, 97 | ) 98 | max_eval_samples: Optional[int] = field( 99 | default=None, 100 | metadata={ 101 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 102 | "value if set." 
103 | }, 104 | ) 105 | max_predict_samples: Optional[int] = field( 106 | default=None, 107 | metadata={ 108 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 109 | "value if set." 110 | }, 111 | ) 112 | train_file: Optional[str] = field( 113 | default=None, metadata={"help": "A csv or a json file containing the training data."} 114 | ) 115 | validation_file: Optional[str] = field( 116 | default=None, metadata={"help": "A csv or a json file containing the validation data."} 117 | ) 118 | test_file: Optional[str] = field( 119 | default=None, metadata={"help": "A csv or a json file containing the test data."}) 120 | 121 | def __post_init__(self): 122 | if self.task_name is not None: 123 | self.task_name = self.task_name.lower() 124 | # if self.task_name not in task_to_keys.keys(): 125 | # self.task_name = None 126 | elif self.train_file is None or self.validation_file is None: 127 | raise ValueError("Need either a GLUE task, a training/validation file or a dataset name.") 128 | else: 129 | train_extension = self.train_file.split(".")[-1] 130 | assert train_extension in ["csv", "json"], "`train_file` should be a csv or a json file." 131 | validation_extension = self.validation_file.split(".")[-1] 132 | assert ( 133 | validation_extension == train_extension 134 | ), "`validation_file` should have the same extension (csv or json) as `train_file`." 135 | 136 | 137 | @dataclass 138 | class ModelArguments: 139 | """ 140 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 141 | """ 142 | 143 | model_name_or_path: str = field( 144 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 145 | ) 146 | config_name: Optional[str] = field( 147 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 148 | ) 149 | tokenizer_name: Optional[str] = field( 150 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 151 | ) 152 | cache_dir: Optional[str] = field( 153 | default=None, 154 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 155 | ) 156 | use_fast_tokenizer: bool = field( 157 | default=True, 158 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 159 | ) 160 | model_revision: str = field( 161 | default="main", 162 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 163 | ) 164 | use_auth_token: bool = field( 165 | default=False, 166 | metadata={ 167 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 168 | "with private models)." 169 | }, 170 | ) 171 | 172 | def tokenize_dataset(datasets, data_args, model, tokenizer, is_regression, num_labels, label_list): 173 | 174 | # Some models have set the order of the labels to use, so let's make sure we do use it. 175 | label_to_id = None 176 | if ( 177 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 178 | and data_args.task_name in task_to_keys.keys() 179 | and not is_regression 180 | ): 181 | # Some have all caps in their config, some don't. 
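# Lower-case the label names from the checkpoint's config so that e.g. {"ENTAILMENT": 0}
# can be matched against the dataset's label list and the pretrained label order reused.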
182 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 183 | if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): 184 | label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} 185 | else: 186 | logger.warning( 187 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 188 | f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." 189 | "\nIgnoring the model labels as a result.", 190 | ) 191 | elif data_args.task_name not in task_to_keys.keys() and not is_regression: 192 | label_to_id = {v: i for i, v in enumerate(label_list)} 193 | 194 | if data_args.max_seq_length > tokenizer.model_max_length: 195 | logger.warning( 196 | f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" 197 | f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 198 | ) 199 | max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) 200 | 201 | # Preprocessing the datasets 202 | if data_args.task_name in task_to_keys.keys(): 203 | sentence1_key, sentence2_key = task_to_keys[data_args.task_name] 204 | else: 205 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 206 | non_label_column_names = [name for name in datasets["train"].column_names if name != "label"] 207 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 208 | sentence1_key, sentence2_key = "sentence1", "sentence2" 209 | else: 210 | sentence1_key, sentence2_key = 'text', None 211 | 212 | # Padding strategy 213 | if data_args.pad_to_max_length: 214 | padding = "max_length" 215 | else: 216 | # We will pad later, dynamically at batch creation, to the max sequence length in each batch 217 | padding = False 218 | 219 | def preprocess_function(examples): 220 | 221 | # Tokenize the texts 222 | args = ( 223 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 224 | ) 225 | result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) 226 | # Map labels to IDs (not necessary for GLUE tasks) 227 | if label_to_id is not None and "label" in examples: 228 | result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] 229 | return result 230 | 231 | datasets = datasets.map(preprocess_function, 232 | batched=True, 233 | load_from_cache_file=not data_args.overwrite_cache 234 | ) 235 | 236 | return datasets 237 | 238 | def main(): 239 | # See all possible arguments in src/transformers/training_args.py 240 | # or by passing the --help flag to this script. 241 | # We now keep distinct sets of args, for a cleaner separation of concerns. 242 | 243 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 244 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 245 | # If we pass only one argument to the script and it's the path to a json file, 246 | # let's parse it to get our arguments. 
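# e.g. `python finetune.py args.json`, where args.json is a flat JSON dict holding the
# ModelArguments, DataTrainingArguments and TrainingArguments fields.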
247 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 248 | else: 249 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 250 | 251 | # Setup logging 252 | logging.basicConfig( 253 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 254 | datefmt="%m/%d/%Y %H:%M:%S", 255 | handlers=[logging.StreamHandler(sys.stdout)], 256 | ) 257 | logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) 258 | 259 | # Log on each process the small summary: 260 | logger.warning( 261 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 262 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 263 | ) 264 | # Set the verbosity to info of the Transformers logger (on main process only): 265 | if is_main_process(training_args.local_rank): 266 | transformers.utils.logging.set_verbosity_info() 267 | transformers.utils.logging.enable_default_handler() 268 | transformers.utils.logging.enable_explicit_format() 269 | logger.info(f"Training/evaluation parameters {training_args}") 270 | 271 | # Set seed before initializing model. 272 | set_seed(training_args.seed) 273 | 274 | if data_args.task_name in task_to_keys.keys(): 275 | # Downloading and loading a dataset from the hub. 276 | datasets = load_dataset("glue", data_args.task_name, cache_dir=model_args.cache_dir) 277 | else: 278 | # Loading a dataset from your local files. 279 | # CSV/JSON training and evaluation files are needed. 280 | data_files = {"train": data_args.train_file, "validation": data_args.validation_file} 281 | 282 | # Get the test dataset: you can provide your own CSV/JSON test file (see below) 283 | # when you use `do_predict` without specifying a GLUE benchmark task. 284 | if training_args.do_predict: 285 | if data_args.test_file is not None: 286 | train_extension = data_args.train_file.split(".")[-1] 287 | test_extension = data_args.test_file.split(".")[-1] 288 | assert ( 289 | test_extension == train_extension 290 | ), "`test_file` should have the same extension (csv or json) as `train_file`." 291 | data_files["test"] = data_args.test_file 292 | else: 293 | raise ValueError("Need either a GLUE task or a test file for `do_predict`.") 294 | 295 | for key in data_files.keys(): 296 | logger.info(f"load a local file for {key}: {data_files[key]}") 297 | 298 | if data_args.train_file.endswith(".csv"): 299 | # Loading a dataset from local csv files 300 | datasets = load_dataset("csv", data_files=data_files, cache_dir=model_args.cache_dir) 301 | else: 302 | # Loading a dataset from local json files 303 | datasets = load_dataset("json", data_files=data_files, cache_dir=model_args.cache_dir) 304 | # See more about loading any type of standard or custom dataset at 305 | # https://huggingface.co/docs/datasets/loading_datasets.html. 306 | 307 | # Labels 308 | label_list = None 309 | if data_args.task_name in task_to_keys.keys(): 310 | is_regression = data_args.task_name == "stsb" 311 | if not is_regression: 312 | label_list = datasets["train"].features["label"].names 313 | num_labels = len(label_list) 314 | else: 315 | num_labels = 1 316 | else: 317 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 
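# For non-GLUE datasets the problem type is inferred from the label column: float labels
# mean regression (a single output), anything else is treated as classification over the
# sorted set of unique label values.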
318 | is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"] 319 | if is_regression: 320 | num_labels = 1 321 | else: 322 | # A useful fast method: 323 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 324 | label_list = datasets["train"].unique("label") 325 | label_list.sort() # Let's sort it for determinism 326 | num_labels = len(label_list) 327 | 328 | # Load pretrained model and tokenizer 329 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 330 | # download model & vocab. 331 | config = AutoConfig.from_pretrained( 332 | model_args.model_name_or_path, 333 | num_labels=num_labels, 334 | finetuning_task=data_args.task_name, 335 | cache_dir=model_args.cache_dir, 336 | revision=model_args.model_revision, 337 | use_auth_token=True if model_args.use_auth_token else None, 338 | ) 339 | tokenizer = AutoTokenizer.from_pretrained( 340 | model_args.model_name_or_path, 341 | cache_dir=model_args.cache_dir, 342 | use_fast=model_args.use_fast_tokenizer, 343 | revision=model_args.model_revision, 344 | use_auth_token=True if model_args.use_auth_token else None, 345 | ) 346 | model = AutoModelForSequenceClassification.from_pretrained( 347 | model_args.model_name_or_path, 348 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 349 | config=config, 350 | cache_dir=model_args.cache_dir, 351 | revision=model_args.model_revision, 352 | use_auth_token=True if model_args.use_auth_token else None, 353 | ) 354 | 355 | datasets = tokenize_dataset(datasets, data_args, model, tokenizer, is_regression, num_labels, label_list) 356 | 357 | if training_args.do_train: 358 | if "train" not in datasets: 359 | raise ValueError("--do_train requires a train dataset") 360 | train_dataset = datasets["train"] 361 | if data_args.max_train_ratios is not None: 362 | data_args.max_train_samples = int(len(train_dataset) * data_args.max_train_ratios) 363 | if data_args.max_train_samples is not None: 364 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 365 | 366 | if training_args.do_eval: 367 | if "validation" not in datasets and "validation_matched" not in datasets: 368 | raise ValueError("--do_eval requires a validation dataset") 369 | eval_dataset = datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] 370 | if data_args.max_eval_samples is not None: 371 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 372 | 373 | if training_args.do_predict or data_args.task_name is not None or data_args.test_file is not None: 374 | if "test" not in datasets and "test_matched" not in datasets: 375 | raise ValueError("--do_predict requires a test dataset") 376 | predict_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] 377 | if data_args.max_predict_samples is not None: 378 | predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) 379 | 380 | # Log a few random samples from the training set: 381 | if training_args.do_train: 382 | for index in random.sample(range(len(train_dataset)), 3): 383 | logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") 384 | 385 | # Get the metric function 386 | if data_args.task_name in task_to_keys.keys(): 387 | metric = load_metric("glue", data_args.task_name) 388 | else: 389 | metric_acc = load_metric("accuracy") 390 | metric_f1 = load_metric("f1") 391 | 392 | # You can define your custom compute_metrics function. 
It takes an `EvalPrediction` object (a namedtuple with a 393 | # predictions and label_ids field) and has to return a dictionary string to float. 394 | def compute_metrics(p: EvalPrediction): 395 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 396 | preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) 397 | if data_args.task_name in task_to_keys.keys(): 398 | result = metric.compute(predictions=preds, references=p.label_ids) 399 | if len(result) > 1: 400 | result["combined_score"] = np.mean(list(result.values())).item() 401 | return result 402 | elif is_regression: 403 | return {"mse": ((preds - p.label_ids) ** 2).mean().item()} 404 | else: 405 | result = metric_acc.compute(predictions=preds, references=p.label_ids) 406 | result['mi_f1'] = metric_f1.compute(predictions=preds, references=p.label_ids, average="micro")['f1'] 407 | result['ma_f1'] = metric_f1.compute(predictions=preds, references=p.label_ids, average="macro")['f1'] 408 | return result 409 | 410 | # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 411 | if data_args.pad_to_max_length: 412 | data_collator = default_data_collator 413 | elif training_args.fp16: 414 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) 415 | else: 416 | data_collator = None 417 | 418 | training_args.task_name = data_args.task_name 419 | # Initialize our Trainer 420 | trainer = ftTrainer( 421 | model=model, 422 | args=training_args, 423 | dataset=datasets, 424 | train_dataset=train_dataset if training_args.do_train else None, 425 | eval_dataset=eval_dataset if training_args.do_eval else None, 426 | predict_dataset=predict_dataset if training_args.do_predict else None, 427 | label_list=label_list if training_args.do_predict else None, 428 | compute_metrics=compute_metrics, 429 | tokenizer=tokenizer, 430 | data_collator=data_collator, 431 | ) 432 | 433 | # Training 434 | if training_args.do_train: 435 | 436 | trainer.train() 437 | trainer.save_state() 438 | 439 | if training_args.do_predict: 440 | 441 | logger.info("*** Predict ***") 442 | 443 | # Loop to handle MNLI double evaluation (matched, mis-matched) 444 | tasks = [data_args.task_name] 445 | predict_datasets = [predict_dataset] 446 | if data_args.task_name == "mnli": 447 | tasks.append("mnli-mm") 448 | predict_datasets.append(datasets["test_mismatched"]) 449 | 450 | for predict_dataset, task in zip(predict_datasets, tasks): 451 | # Removing the `label` columns because it contains -1 and Trainer won't like that. 
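# Note: the upstream run_glue.py drops the placeholder -1 labels at this point, e.g.
# `predict_dataset = predict_dataset.remove_columns("label")`, so the Trainer does not try
# to score the unlabeled GLUE test split.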
452 | predictions = trainer.predict(predict_dataset, metric_key_prefix="predict").predictions 453 | predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) 454 | 455 | output_predict_file = os.path.join(training_args.output_dir, f"predict_results_{task}.tsv") 456 | if trainer.is_world_process_zero(): 457 | with open(output_predict_file, "w") as writer: 458 | logger.info(f"***** Predict results {task} *****") 459 | writer.write("index\tprediction\n") 460 | for index, item in enumerate(predictions): 461 | if is_regression: 462 | writer.write(f"{index}\t{item:3.3f}\n") 463 | else: 464 | item = label_list[item] 465 | writer.write(f"{index}\t{item}\n") 466 | 467 | 468 | if __name__ == "__main__": 469 | main() 470 | -------------------------------------------------------------------------------- /finetune_wo_wiktionary/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export task=cola 4 | export outfolder=output-${task} 5 | # export model_name=bert-base-uncased 6 | export model_name=wyu1/DictBERT 7 | 8 | 9 | if [[ ${task} =~ $'cola' ]] || [[ ${task} =~ $'sst2' ]] 10 | then 11 | epoch=10; max_length=128 12 | elif [[ ${task} =~ $'qqp' ]] || [[ ${task} =~ $'mnli' ]] || [[ ${task} =~ $'qnli' ]] 13 | then 14 | epoch=5; max_length=128 15 | elif [[ ${OUTPUT_NAME} =~ $'rte' ]] 16 | then 17 | epoch=10; max_length=256 18 | elif [[ ${OUTPUT_NAME} =~ $'mrpc' ]] || [[ ${task} =~ $'stsb' ]] 19 | then 20 | epoch=5; max_length=256 21 | fi 22 | 23 | 24 | python -u finetune.py \ 25 | --model_name_or_path $model_name \ 26 | --task_name $task \ 27 | --do_train \ 28 | --do_eval \ 29 | --do_predict \ 30 | --max_seq_length $max_length \ 31 | --per_device_train_batch_size 32 \ 32 | --per_device_eval_batch_size 32 \ 33 | --learning_rate 2e-5 \ 34 | --num_train_epochs $epoch \ 35 | --output_dir $outfolder/$task/ 36 | -------------------------------------------------------------------------------- /finetune_wo_wiktionary/ftCollator.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Callable, Dict, List, NewType, Optional, Union 3 | 4 | import torch 5 | from transformers.file_utils import PaddingStrategy 6 | from transformers.modeling_utils import PreTrainedModel 7 | from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase 8 | 9 | InputDataClass = NewType("InputDataClass", Any) 10 | DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]]) 11 | 12 | 13 | @dataclass 14 | class DataCollatorWithPadding: 15 | 16 | tokenizer: PreTrainedTokenizerBase 17 | padding: Union[bool, str, PaddingStrategy] = True 18 | max_length: Optional[int] = None 19 | pad_to_multiple_of: Optional[int] = None 20 | 21 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: 22 | 23 | batch = self.tokenizer.pad( 24 | features, 25 | padding=self.padding, 26 | max_length=self.max_length, 27 | pad_to_multiple_of=self.pad_to_multiple_of, 28 | return_tensors="pt", 29 | ) 30 | 31 | if "label" in batch: 32 | batch["labels"] = batch["label"] 33 | del batch["label"] 34 | if "label_ids" in batch: 35 | batch["labels"] = batch["label_ids"] 36 | del batch["label_ids"] 37 | 38 | return batch 39 | 40 | 41 | def _collate_batch(examples, tokenizer, pad_to_multiple_of: Optional[int] = None): 42 | """Collate `examples` into a batch, using the 
information in `tokenizer` for padding if necessary.""" 43 | # Tensorize if necessary. 44 | if isinstance(examples[0], (list, tuple)): 45 | examples = [torch.tensor(e, dtype=torch.long) for e in examples] 46 | 47 | # Check if padding is necessary. 48 | length_of_first = examples[0].size(0) 49 | are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) 50 | if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): 51 | return torch.stack(examples, dim=0) 52 | 53 | # If yes, check if we have a `pad_token`. 54 | if tokenizer._pad_token is None: 55 | raise ValueError( 56 | "You are attempting to pad samples but the tokenizer you are using" 57 | f" ({tokenizer.__class__.__name__}) does not have a pad token." 58 | ) 59 | 60 | # Creating the full tensor and filling it with our data. 61 | max_length = max(x.size(0) for x in examples) 62 | if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): 63 | max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of 64 | result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id) 65 | for i, example in enumerate(examples): 66 | if tokenizer.padding_side == "right": 67 | result[i, : example.shape[0]] = example 68 | else: 69 | result[i, -example.shape[0] :] = example 70 | return result 71 | 72 | 73 | def tolist(x: Union[List[Any], torch.Tensor]): 74 | return x.tolist() if isinstance(x, torch.Tensor) else x 75 | 76 | 77 | @dataclass 78 | class DataCollatorForSeq2Seq: 79 | 80 | tokenizer: PreTrainedTokenizerBase 81 | model: Optional[PreTrainedModel] = None 82 | padding: Union[bool, str, PaddingStrategy] = True 83 | max_length: Optional[int] = None 84 | pad_to_multiple_of: Optional[int] = None 85 | label_pad_token_id: int = -100 86 | 87 | def __call__(self, features): 88 | labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None 89 | # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the 90 | # same length to return tensors. 91 | if labels is not None: 92 | max_label_length = max(len(l) for l in labels) 93 | padding_side = self.tokenizer.padding_side 94 | for feature in features: 95 | remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) 96 | feature["labels"] = ( 97 | feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] 98 | ) 99 | 100 | features = self.tokenizer.pad( 101 | features, 102 | padding=self.padding, 103 | max_length=self.max_length, 104 | pad_to_multiple_of=self.pad_to_multiple_of, 105 | return_tensors="pt", 106 | ) 107 | 108 | # prepare decoder_input_ids 109 | if self.model is not None and hasattr(self.model, "prepare_decoder_input_ids_from_labels"): 110 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"]) 111 | features["decoder_input_ids"] = decoder_input_ids 112 | 113 | return features 114 | 115 | 116 | def default_data_collator(features: List[InputDataClass]) -> Dict[str, torch.Tensor]: 117 | """ 118 | Very simple data collator that simply collates batches of dict-like objects and performs special handling for 119 | potential keys named: 120 | 121 | - ``label``: handles a single value (int or float) per object 122 | - ``label_ids``: handles a list of values per object 123 | 124 | Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs 125 | to the model. 
See glue and ner for example of how it's useful. 126 | """ 127 | 128 | # In this function we'll make the assumption that all `features` in the batch 129 | # have the same attributes. 130 | # So we will look at the first element as a proxy for what attributes exist 131 | # on the whole batch. 132 | if not isinstance(features[0], (dict, BatchEncoding)): 133 | features = [vars(f) for f in features] 134 | 135 | first = features[0] 136 | batch = {} 137 | 138 | # Special handling for labels. 139 | # Ensure that tensor is created with the correct type 140 | # (it should be automatically the case, but let's make sure of it.) 141 | if "label" in first and first["label"] is not None: 142 | label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"] 143 | dtype = torch.long if isinstance(label, int) else torch.float 144 | batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype) 145 | elif "label_ids" in first and first["label_ids"] is not None: 146 | if isinstance(first["label_ids"], torch.Tensor): 147 | batch["labels"] = torch.stack([f["label_ids"] for f in features]) 148 | else: 149 | dtype = torch.long if type(first["label_ids"][0]) is int else torch.float 150 | batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype) 151 | 152 | # Handling of all other possible keys. 153 | # Again, we will use the first element to figure out which key/values are not None for this model. 154 | for k, v in first.items(): 155 | if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): 156 | if isinstance(v, torch.Tensor): 157 | batch[k] = torch.stack([f[k] for f in features]) 158 | else: 159 | batch[k] = torch.tensor([f[k] for f in features]) 160 | 161 | print(batch['input_ids'].shape) 162 | return batch 163 | -------------------------------------------------------------------------------- /finetune_wo_wiktionary/ftTrainer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import math 3 | import logging 4 | import os 5 | import sys 6 | import time 7 | import warnings 8 | from logging import StreamHandler 9 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 10 | 11 | from tqdm.auto import tqdm 12 | import datetime 13 | 14 | from transformers.integrations import ( # isort: split 15 | hp_params, 16 | is_fairscale_available, 17 | get_reporting_integration_callbacks, 18 | ) 19 | 20 | from transformers.deepspeed import deepspeed_init 21 | 22 | import numpy as np 23 | import torch 24 | from packaging import version 25 | from torch import nn 26 | from torch.utils.data.dataloader import DataLoader 27 | from torch.utils.data.dataset import Dataset 28 | from torch.utils.data.distributed import DistributedSampler 29 | 30 | from transformers import __version__ 31 | from transformers import Trainer 32 | from transformers.configuration_utils import PretrainedConfig 33 | from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator 34 | from transformers.debug_utils import DebugOption, DebugUnderflowOverflow 35 | from transformers.dependency_versions_check import dep_version_check 36 | from transformers.file_utils import ( 37 | CONFIG_NAME, 38 | WEIGHTS_NAME, 39 | is_apex_available, 40 | is_datasets_available, 41 | is_sagemaker_dp_enabled, 42 | is_sagemaker_mp_enabled, 43 | is_training_run_on_sagemaker, 44 | ) 45 | from transformers.modeling_utils import PreTrainedModel 46 | from 
transformers.tokenization_utils_base import PreTrainedTokenizerBase 47 | from transformers.trainer_callback import ( 48 | CallbackHandler, 49 | DefaultFlowCallback, 50 | PrinterCallback, 51 | ProgressCallback, 52 | TrainerCallback, 53 | TrainerControl, 54 | TrainerState, 55 | ) 56 | 57 | from transformers.trainer_pt_utils import ( 58 | IterableDatasetShard, 59 | LabelSmoother, 60 | ) 61 | from transformers.trainer_utils import ( 62 | EvalPrediction, 63 | ShardedDDPOption, 64 | TrainerMemoryTracker, 65 | TrainOutput, 66 | get_last_checkpoint, 67 | set_seed, 68 | speed_metrics, 69 | ) 70 | from transformers.training_args import TrainingArguments 71 | from transformers.utils.modeling_auto_mapping import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES 72 | 73 | _is_native_amp_available = False 74 | 75 | DEFAULT_CALLBACKS = [DefaultFlowCallback] 76 | DEFAULT_PROGRESS_CALLBACK = ProgressCallback 77 | 78 | if is_apex_available(): 79 | from apex import amp 80 | 81 | if version.parse(torch.__version__) >= version.parse("1.6"): 82 | _is_native_amp_available = True 83 | from torch.cuda.amp import autocast 84 | 85 | if is_fairscale_available(): 86 | dep_version_check("fairscale") 87 | import fairscale 88 | from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP 89 | from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP 90 | from fairscale.optim.grad_scaler import ShardedGradScaler 91 | 92 | if is_sagemaker_dp_enabled(): 93 | import smdistributed.dataparallel.torch.distributed as dist 94 | from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP 95 | else: 96 | import torch.distributed as dist 97 | 98 | if is_sagemaker_mp_enabled(): 99 | import smdistributed.modelparallel.torch as smp 100 | 101 | from .trainer_pt_utils import smp_forward_backward 102 | 103 | if is_training_run_on_sagemaker(): 104 | logging.add_handler(StreamHandler(sys.stdout)) 105 | 106 | 107 | if TYPE_CHECKING: 108 | import optuna 109 | 110 | glue_tasks = ["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"] 111 | 112 | task_to_metrics = { 113 | "cola": "matthews_correlation", 114 | "mnli": "accuracy", 115 | "mnli-mm": "accuracy", 116 | "mrpc": "f1", 117 | "qnli": "accuracy", 118 | "qqp": "f1", 119 | "rte": "accuracy", 120 | "sst2": "accuracy", 121 | "stsb": "spearmanr", 122 | "wnli": "accuracy", 123 | } 124 | 125 | logger = logging.getLogger(__name__) 126 | logger.setLevel(logging.INFO) 127 | 128 | class ftTrainer(Trainer): 129 | 130 | from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state 131 | 132 | def __init__( 133 | self, 134 | model: Union[PreTrainedModel, torch.nn.Module] = None, 135 | args: TrainingArguments = None, 136 | data_collator: Optional[DataCollator] = None, 137 | dataset: Optional[Dataset] = None, 138 | train_dataset: Optional[Dataset] = None, 139 | eval_dataset: Optional[Dataset] = None, 140 | predict_dataset: Optional[Dataset] = None, 141 | label_list: Optional[Dataset] = None, 142 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 143 | model_init: Callable[[], PreTrainedModel] = None, 144 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 145 | callbacks: Optional[List[TrainerCallback]] = None, 146 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), 147 | ): 148 | if args is None: 149 | output_dir = "tmp_trainer" 150 | logger.info(f"No `TrainingArguments` passed, using 
`output_dir={output_dir}`.") 151 | args = TrainingArguments(output_dir=output_dir) 152 | self.args = args 153 | # Seed must be set before instantiating the model when using model 154 | set_seed(self.args.seed) 155 | self.hp_name = None 156 | self.deepspeed = None 157 | self.is_in_train = False 158 | 159 | # memory metrics - must set up as early as possible 160 | self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) 161 | self._memory_tracker.start() 162 | 163 | # force device and distributed setup init explicitly 164 | args._setup_devices 165 | 166 | if model is None: 167 | if model_init is not None: 168 | self.model_init = model_init 169 | model = self.call_model_init() 170 | else: 171 | raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") 172 | else: 173 | if model_init is not None: 174 | warnings.warn( 175 | "`Trainer` requires either a `model` or `model_init` argument, but not both. " 176 | "`model_init` will overwrite your model when calling the `train` method. This will become a fatal error in the next release.", 177 | FutureWarning, 178 | ) 179 | self.model_init = model_init 180 | 181 | if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: 182 | self.is_model_parallel = True 183 | else: 184 | self.is_model_parallel = False 185 | 186 | # Setup Sharded DDP training 187 | self.sharded_ddp = None 188 | if len(args.sharded_ddp) > 0: 189 | if args.deepspeed: 190 | raise ValueError( 191 | "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." 192 | ) 193 | 194 | if args.local_rank == -1: 195 | raise ValueError("Using sharded DDP only works in distributed training.") 196 | elif not is_fairscale_available(): 197 | raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") 198 | elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: 199 | raise ImportError( 200 | "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " 201 | f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." 202 | ) 203 | elif ShardedDDPOption.SIMPLE in args.sharded_ddp: 204 | self.sharded_ddp = ShardedDDPOption.SIMPLE 205 | elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: 206 | self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 207 | elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: 208 | self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 209 | 210 | # one place to sort out whether to place the model on device or not 211 | # postpone switching model to cuda when: 212 | # 1. MP - since we are trying to fit a much bigger than 1 gpu model 213 | # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, 214 | # and we only use deepspeed for training at the moment 215 | # 3. full fp16 eval - since the model needs to be half'ed first 216 | # 4. 
Sharded DDP - same as MP 217 | self.place_model_on_device = args.place_model_on_device 218 | if ( 219 | self.is_model_parallel 220 | or args.deepspeed 221 | or (args.fp16_full_eval and not args.do_train) 222 | or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) 223 | ): 224 | self.place_model_on_device = False 225 | 226 | default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) 227 | self.data_collator = data_collator if data_collator is not None else default_collator 228 | self.dataset = dataset 229 | self.train_dataset = train_dataset 230 | self.eval_dataset = eval_dataset 231 | self.predict_dataset = predict_dataset 232 | self.label_list = label_list 233 | self.tokenizer = tokenizer 234 | 235 | if self.place_model_on_device: 236 | model = model.to(args.device) 237 | 238 | # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs 239 | if self.is_model_parallel: 240 | self.args._n_gpu = 1 241 | 242 | # later use `self.model is self.model_wrapped` to check if it's wrapped or not 243 | self.model_wrapped = model 244 | self.model = model 245 | 246 | self.compute_metrics = compute_metrics 247 | self.optimizer, self.lr_scheduler = optimizers 248 | if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): 249 | raise RuntimeError( 250 | "Passing a `model_init` is incompatible with providing the `optimizers` argument." 251 | "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." 252 | ) 253 | default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) 254 | callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks 255 | self.callback_handler = CallbackHandler( 256 | callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler 257 | ) 258 | self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) 259 | 260 | # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. 
261 | self._loggers_initialized = False 262 | 263 | # Create output directory if needed 264 | if self.is_world_process_zero(): 265 | os.makedirs(self.args.output_dir, exist_ok=True) 266 | if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): 267 | raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") 268 | 269 | if args.max_steps > 0: 270 | logger.info("max_steps is given, it will override any value given in num_train_epochs") 271 | 272 | if train_dataset is not None and not isinstance(train_dataset, collections.abc.Sized) and args.max_steps <= 0: 273 | raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") 274 | 275 | self._signature_columns = None 276 | 277 | # Mixed precision setup 278 | self.use_apex = False 279 | self.use_amp = False 280 | self.fp16_backend = None 281 | 282 | if args.fp16: 283 | if args.fp16_backend == "auto": 284 | self.fp16_backend = "amp" if _is_native_amp_available else "apex" 285 | else: 286 | self.fp16_backend = args.fp16_backend 287 | logger.info(f"Using {self.fp16_backend} fp16 backend") 288 | 289 | if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16 290 | if self.fp16_backend == "amp": 291 | self.use_amp = True 292 | if is_sagemaker_mp_enabled(): 293 | self.scaler = smp.amp.GradScaler() 294 | elif self.sharded_ddp is not None: 295 | self.scaler = ShardedGradScaler() 296 | else: 297 | self.scaler = torch.cuda.amp.GradScaler() 298 | else: 299 | if not is_apex_available(): 300 | raise ImportError( 301 | "Using FP16 with APEX but APEX is not installed, please refer to https://www.github.com/nvidia/apex." 302 | ) 303 | self.use_apex = True 304 | 305 | # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. 306 | if is_sagemaker_mp_enabled() and self.use_amp and args.max_grad_norm is not None and args.max_grad_norm > 0: 307 | raise ValueError( 308 | "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " 309 | "along 'max_grad_norm': 0 in your hyperparameters." 
310 | ) 311 | 312 | # Label smoothing 313 | if self.args.label_smoothing_factor != 0: 314 | self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) 315 | else: 316 | self.label_smoother = None 317 | 318 | self.state = TrainerState() 319 | self.control = TrainerControl() 320 | # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then 321 | # returned to 0 every time flos need to be logged 322 | self.current_flos = 0 323 | self.hp_search_backend = None 324 | self.use_tune_checkpoints = False 325 | default_label_names = ( 326 | ["start_positions", "end_positions"] 327 | if type(self.model).__name__ in MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES.values() 328 | else ["labels"] 329 | ) 330 | self.label_names = default_label_names if self.args.label_names is None else self.args.label_names 331 | self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) 332 | 333 | # very last 334 | self._memory_tracker.stop_and_update_metrics() 335 | 336 | 337 | def train( 338 | self, resume_from_checkpoint: Optional[Union[str, bool]] = None, 339 | trial: Union["optuna.Trial", Dict[str, Any]] = None, **kwargs, 340 | ): 341 | 342 | # memory metrics - must set up as early as possible 343 | self._memory_tracker.start() 344 | 345 | args = self.args 346 | 347 | self.is_in_train = True 348 | 349 | # do_train is not a reliable argument, as it might not be set and .train() still called, so 350 | # the following is a workaround: 351 | if args.fp16_full_eval and not args.do_train: 352 | self.model = self.model.to(args.device) 353 | 354 | if "model_path" in kwargs: 355 | resume_from_checkpoint = kwargs.pop("model_path") 356 | warnings.warn( 357 | "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " 358 | "instead.", 359 | FutureWarning, 360 | ) 361 | if len(kwargs) > 0: 362 | raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") 363 | # This might change the seed so needs to run first. 364 | self._hp_search_setup(trial) 365 | 366 | # Model re-init 367 | model_reloaded = False 368 | if self.model_init is not None: 369 | # Seed must be set before instantiating the model when using model_init. 
370 | set_seed(args.seed) 371 | self.model = self.call_model_init(trial) 372 | model_reloaded = True 373 | # Reinitializes optimizer and scheduler 374 | self.optimizer, self.lr_scheduler = None, None 375 | 376 | # Load potential model checkpoint 377 | if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: 378 | resume_from_checkpoint = get_last_checkpoint(args.output_dir) 379 | if resume_from_checkpoint is None: 380 | raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") 381 | 382 | if resume_from_checkpoint is not None: 383 | if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): 384 | raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") 385 | 386 | logger.info(f"Loading model from {resume_from_checkpoint}).") 387 | 388 | if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)): 389 | config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME)) 390 | checkpoint_version = config.transformers_version 391 | if checkpoint_version is not None and checkpoint_version != __version__: 392 | logger.warn( 393 | f"You are resuming training from a checkpoint trained with {checkpoint_version} of " 394 | f"Transformers but your current version is {__version__}. This is not recommended and could " 395 | "yield to errors or unwanted behaviors." 396 | ) 397 | 398 | if args.deepspeed: 399 | # will be resumed in deepspeed_init 400 | pass 401 | else: 402 | # We load the model state dict on the CPU to avoid an OOM error. 403 | state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") 404 | # If the model is on the GPU, it still works! 405 | self._load_state_dict_in_model(state_dict) 406 | 407 | # If model was re-initialized, put it on the right device and update self.model_wrapped 408 | if model_reloaded: 409 | if self.place_model_on_device: 410 | self.model = self.model.to(args.device) 411 | self.model_wrapped = self.model 412 | 413 | # Keeping track whether we can can len() on the dataset or not 414 | train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized) 415 | 416 | # Data loader and number of training steps 417 | train_dataloader = self.get_train_dataloader() 418 | 419 | # Setting up training control variables: 420 | # number of training epochs: num_train_epochs 421 | # number of training steps per epoch: num_update_steps_per_epoch 422 | # total number of training steps to execute: max_steps 423 | if train_dataset_is_sized: 424 | num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps 425 | num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) 426 | if args.max_steps > 0: 427 | max_steps = args.max_steps 428 | num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( 429 | args.max_steps % num_update_steps_per_epoch > 0 430 | ) 431 | else: 432 | max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) 433 | num_train_epochs = math.ceil(args.num_train_epochs) 434 | else: 435 | # see __init__. 
max_steps is set when the dataset has no __len__ 436 | max_steps = args.max_steps 437 | num_train_epochs = int(args.num_train_epochs) 438 | num_update_steps_per_epoch = max_steps 439 | 440 | if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: 441 | debug_overflow = DebugUnderflowOverflow(self.model) # noqa 442 | 443 | delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE 444 | if args.deepspeed: 445 | deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( 446 | self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint 447 | ) 448 | self.model = deepspeed_engine.module 449 | self.model_wrapped = deepspeed_engine 450 | self.deepspeed = deepspeed_engine 451 | self.optimizer = optimizer 452 | self.lr_scheduler = lr_scheduler 453 | elif not delay_optimizer_creation: 454 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 455 | 456 | self.state = TrainerState() 457 | self.state.is_hyper_param_search = trial is not None 458 | 459 | model = self._wrap_model(self.model_wrapped) 460 | 461 | # for the rest of this function `model` is the outside model, whether it was wrapped or not 462 | if model is not self.model: 463 | self.model_wrapped = model 464 | 465 | if delay_optimizer_creation: 466 | self.create_optimizer_and_scheduler(num_training_steps=max_steps) 467 | 468 | # Check if saved optimizer or scheduler states exist 469 | self._load_optimizer_and_scheduler(resume_from_checkpoint) 470 | 471 | # Train! 472 | if args.local_rank != -1: 473 | world_size = dist.get_world_size() 474 | else: 475 | world_size = 1 476 | 477 | total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size 478 | num_examples = ( 479 | self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps 480 | ) 481 | 482 | logger.info("***** Running training *****") 483 | logger.info(f" Num examples = {num_examples}") 484 | logger.info(f" Num Epochs = {num_train_epochs}") 485 | logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 486 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size}") 487 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 488 | logger.info(f" Total optimization steps = {max_steps}") 489 | 490 | self.state.epoch = 0 491 | start_time = time.time() 492 | epochs_trained = 0 493 | steps_trained_in_current_epoch = 0 494 | steps_trained_progress_bar = None 495 | 496 | # Check if continuing training from a checkpoint 497 | if resume_from_checkpoint is not None and os.path.isfile( 498 | os.path.join(resume_from_checkpoint, "trainer_state.json") 499 | ): 500 | self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, "trainer_state.json")) 501 | epochs_trained = self.state.global_step // num_update_steps_per_epoch 502 | if not args.ignore_data_skip: 503 | steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) 504 | steps_trained_in_current_epoch *= args.gradient_accumulation_steps 505 | else: 506 | steps_trained_in_current_epoch = 0 507 | 508 | logger.info(" Continuing training from checkpoint, will skip to saved global_step") 509 | logger.info(f" Continuing training from epoch {epochs_trained}") 510 | logger.info(f" Continuing training from global step {self.state.global_step}") 511 | if not args.ignore_data_skip: 512 | logger.info( 513 | f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " 514 | "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " 515 | "flag to your launch command, but you will resume the training on data already seen by your model." 516 | ) 517 | if self.is_local_process_zero() and not args.disable_tqdm: 518 | steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) 519 | steps_trained_progress_bar.set_description("Skipping the first batches") 520 | 521 | # Update the references 522 | self.callback_handler.model = self.model 523 | self.callback_handler.optimizer = self.optimizer 524 | self.callback_handler.lr_scheduler = self.lr_scheduler 525 | self.callback_handler.train_dataloader = train_dataloader 526 | self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None 527 | self.state.trial_params = hp_params(trial) if trial is not None else None 528 | # This should be the same if the state has been saved but in case the training arguments changed, it's safer 529 | # to set this after the load. 530 | self.state.max_steps = max_steps 531 | self.state.num_train_epochs = num_train_epochs 532 | self.state.is_local_process_zero = self.is_local_process_zero() 533 | self.state.is_world_process_zero = self.is_world_process_zero() 534 | 535 | # tr_loss is a tensor to avoid synchronization of TPUs through .item() 536 | tr_loss = torch.tensor(0.0).to(args.device) 537 | # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses 538 | self._total_loss_scalar = 0.0 539 | self._globalstep_last_logged = self.state.global_step 540 | model.zero_grad() 541 | 542 | self.control = self.callback_handler.on_train_begin(args, self.state, self.control) 543 | 544 | # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. 545 | if not args.ignore_data_skip: 546 | for epoch in range(epochs_trained): 547 | # We just need to begin an iteration to create the randomization of the sampler. 
548 | for _ in train_dataloader: 549 | break 550 | 551 | best_metric, best_epoch = 0, 0 552 | for epoch in range(epochs_trained, num_train_epochs): 553 | if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): 554 | train_dataloader.sampler.set_epoch(epoch) 555 | elif isinstance(train_dataloader.dataset, IterableDatasetShard): 556 | train_dataloader.dataset.set_epoch(epoch) 557 | 558 | epoch_iterator = train_dataloader 559 | 560 | # Reset the past mems state at the beginning of each epoch if necessary. 561 | if args.past_index >= 0: 562 | self._past = None 563 | 564 | steps_in_epoch = ( 565 | len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps 566 | ) 567 | self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) 568 | 569 | for step, inputs in enumerate(epoch_iterator): 570 | 571 | starttime = datetime.datetime.now() 572 | # Skip past any already trained steps if resuming training 573 | if steps_trained_in_current_epoch > 0: 574 | steps_trained_in_current_epoch -= 1 575 | if steps_trained_progress_bar is not None: 576 | steps_trained_progress_bar.update(1) 577 | if steps_trained_in_current_epoch == 0: 578 | self._load_rng_state(resume_from_checkpoint) 579 | continue 580 | elif steps_trained_progress_bar is not None: 581 | steps_trained_progress_bar.close() 582 | steps_trained_progress_bar = None 583 | 584 | if step % args.gradient_accumulation_steps == 0: 585 | self.control = self.callback_handler.on_step_begin(args, self.state, self.control) 586 | 587 | if ( 588 | ((step + 1) % args.gradient_accumulation_steps != 0) 589 | and args.local_rank != -1 590 | and args._no_sync_in_gradient_accumulation 591 | ): 592 | # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. 
593 | with model.no_sync(): 594 | tr_loss += self.training_step(model, inputs) 595 | else: 596 | tr_loss += self.training_step(model, inputs) 597 | self.current_flos += float(self.floating_point_ops(inputs)) 598 | 599 | # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps 600 | if self.deepspeed: 601 | self.deepspeed.step() 602 | 603 | if (step + 1) % args.gradient_accumulation_steps == 0 or ( 604 | # last step in epoch but step is always smaller than gradient_accumulation_steps 605 | steps_in_epoch <= args.gradient_accumulation_steps 606 | and (step + 1) == steps_in_epoch 607 | ): 608 | # Gradient clipping 609 | if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: 610 | # deepspeed does its own clipping 611 | 612 | if self.use_amp: 613 | # AMP: gradients need unscaling 614 | self.scaler.unscale_(self.optimizer) 615 | 616 | if hasattr(self.optimizer, "clip_grad_norm"): 617 | # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping 618 | self.optimizer.clip_grad_norm(args.max_grad_norm) 619 | elif hasattr(model, "clip_grad_norm_"): 620 | # Some models (like FullyShardedDDP) have a specific way to do gradient clipping 621 | model.clip_grad_norm_(args.max_grad_norm) 622 | else: 623 | # Revert to normal clipping otherwise, handling Apex or full precision 624 | torch.nn.utils.clip_grad_norm_( 625 | amp.master_params(self.optimizer) if self.use_apex else model.parameters(), 626 | args.max_grad_norm, 627 | ) 628 | 629 | # Optimizer step 630 | optimizer_was_run = True 631 | if self.deepspeed: 632 | pass # called outside the loop 633 | elif self.use_amp: 634 | scale_before = self.scaler.get_scale() 635 | self.scaler.step(self.optimizer) 636 | self.scaler.update() 637 | scale_after = self.scaler.get_scale() 638 | optimizer_was_run = scale_before <= scale_after 639 | else: 640 | self.optimizer.step() 641 | 642 | if optimizer_was_run and not self.deepspeed: 643 | self.lr_scheduler.step() 644 | 645 | model.zero_grad() 646 | self.state.global_step += 1 647 | self.state.epoch = epoch + (step + 1) / steps_in_epoch 648 | 649 | self.control = self.callback_handler.on_step_end(args, self.state, self.control) 650 | self._maybe_log_save_evaluate(tr_loss, model, trial, epoch) 651 | 652 | if self.control.should_epoch_stop or self.control.should_training_stop: 653 | break 654 | 655 | # TODO eval after epoch 656 | # self._maybe_log_save_evaluate(tr_loss, model, trial, epoch) 657 | epoch_metric = self._evaluate_after_each_epoch(tr_loss, model, trial, epoch+1) 658 | 659 | if epoch_metric >= best_metric: 660 | best_metric = epoch_metric 661 | best_epoch = epoch + 1 662 | 663 | self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) 664 | if self.control.should_training_stop: 665 | break 666 | 667 | if args.past_index and hasattr(self, "_past"): 668 | # Clean the state at the end of training 669 | delattr(self, "_past") 670 | 671 | self.state.best_metric = {task_to_metrics[self.args.task_name]: best_metric} 672 | self.state.best_model_checkpoint = {'epoch': best_epoch} 673 | 674 | metrics = speed_metrics("train", start_time, self.state.max_steps) 675 | self.store_flos() 676 | metrics["total_flos"] = self.state.total_flos 677 | self.log(metrics) 678 | 679 | self.control = self.callback_handler.on_train_end(args, self.state, self.control) 680 | self._total_loss_scalar += tr_loss.item() 681 | 682 | self.is_in_train = False 683 | 
self._memory_tracker.stop_and_update_metrics(metrics) 684 | 685 | return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics) 686 | 687 | def _evaluate_after_each_epoch(self, tr_loss, model, trial, epoch): 688 | 689 | ''' logging training information ''' 690 | logs: Dict[str, float] = {} 691 | tr_loss_scalar = tr_loss.item() 692 | tr_loss -= tr_loss 693 | logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) 694 | logs["learning_rate"] = self._get_learning_rate() 695 | self._total_loss_scalar += tr_loss_scalar 696 | self._globalstep_last_logged = self.state.global_step 697 | self.log(logs) 698 | 699 | ''' Evaluating and logging validation information ''' 700 | tasks = [self.args.task_name] 701 | eval_datasets = [self.eval_dataset] 702 | if self.args.task_name == "mnli": 703 | tasks.append("mnli-mm") 704 | eval_datasets.append(self.dataset["validation_mismatched"]) 705 | 706 | for eval_dataset, task in zip(eval_datasets, tasks): 707 | eval_metrics = self.evaluate(eval_dataset=eval_dataset, metric_key_prefix='eval') 708 | 709 | self.log_metrics(f"Epoch{epoch} {task} Dev", metrics=eval_metrics) 710 | self.save_metrics(f"epoch{epoch}_{task}_dev", metrics=eval_metrics, combined=False) 711 | 712 | if self.args.do_predict: 713 | # self._save_checkpoint(model, trial, metrics=metrics) 714 | if self.args.task_name not in glue_tasks: 715 | ''' Evaluating and logging testing information ''' 716 | tasks = [self.args.task_name] 717 | predict_datasets = [self.predict_dataset] 718 | if self.args.task_name == "mnli": 719 | tasks.append("mnli-mm") 720 | predict_datasets.append(self.dataset["validation_mismatched"]) 721 | 722 | for predict_dataset, task in zip(predict_datasets, tasks): 723 | test_metrics = self.evaluate(eval_dataset=predict_dataset, metric_key_prefix='test') 724 | 725 | self.log_metrics(f"Epoch{epoch} {task} Test", metrics=test_metrics) 726 | self.save_metrics(f"epoch{epoch}_{task}_test", metrics=test_metrics, combined=False) 727 | else: 728 | tasks = [self.args.task_name] 729 | predict_datasets = [self.predict_dataset] 730 | if self.args.task_name == "mnli": 731 | tasks.append("mnli-mm") 732 | predict_datasets.append(self.dataset["test_mismatched"]) 733 | 734 | is_regression = self.args.task_name == "stsb" 735 | 736 | for predict_dataset, task in zip(predict_datasets, tasks): 737 | 738 | try: # Removing the `label` columns because it contains -1 and Trainer won't like that. 739 | predict_dataset.remove_columns_("label") 740 | except: # Except the keyerror happens when column "label" has already been removed. 
741 | pass 742 | 743 | predictions = self.predict(predict_dataset, metric_key_prefix="predict").predictions 744 | predictions = np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) 745 | 746 | output_predict_file = os.path.join(self.args.output_dir, f"epoch{epoch}_{task}_test_results.tsv") 747 | 748 | with open(output_predict_file, "w") as writer: 749 | logger.info(f"***** Predict results {task} *****") 750 | writer.write("index\tprediction\n") 751 | for index, item in enumerate(predictions): 752 | if is_regression: 753 | writer.write(f"{index}\t{item:3.3f}\n") 754 | elif task in ["mnli", "mnli-mm", "qnli", "rte"]: 755 | item = self.label_list[item] 756 | writer.write(f"{index}\t{item}\n") 757 | else: 758 | writer.write(f"{index}\t{item}\n") 759 | 760 | primary_metric = task_to_metrics[self.args.task_name] 761 | if self.args.task_name in glue_tasks: 762 | return eval_metrics['eval_' + primary_metric] 763 | else: 764 | return test_metrics['test_' + primary_metric] 765 | 766 | def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: 767 | 768 | model.train() 769 | inputs = self._prepare_inputs(inputs) 770 | 771 | if is_sagemaker_mp_enabled(): 772 | scaler = self.scaler if self.use_amp else None 773 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler) 774 | return loss_mb.reduce_mean().detach().to(self.args.device) 775 | 776 | if self.use_amp: 777 | with autocast(): 778 | loss = self.compute_loss(model, inputs) 779 | else: 780 | loss = self.compute_loss(model, inputs) 781 | 782 | if self.args.n_gpu > 1: 783 | loss = loss.mean() # mean() to average on multi-gpu parallel training 784 | 785 | if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: 786 | # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` 787 | loss = loss / self.args.gradient_accumulation_steps 788 | 789 | if self.use_amp: 790 | self.scaler.scale(loss).backward() 791 | elif self.use_apex: 792 | with amp.scale_loss(loss, self.optimizer) as scaled_loss: 793 | scaled_loss.backward() 794 | elif self.deepspeed: 795 | # loss gets scaled under gradient_accumulation_steps in deepspeed 796 | loss = self.deepspeed.backward(loss) 797 | else: 798 | loss.backward() 799 | 800 | return loss.detach() 801 | 802 | def compute_loss(self, model, inputs, return_outputs=False): 803 | 804 | if self.label_smoother is not None and "labels" in inputs: 805 | labels = inputs.pop("labels") 806 | else: 807 | labels = None 808 | 809 | outputs = model(**inputs) 810 | # Save past state if it exists 811 | # TODO: this needs to be fixed and made cleaner later. 812 | if self.args.past_index >= 0: 813 | self._past = outputs[self.args.past_index] 814 | 815 | if labels is not None: 816 | loss = self.label_smoother(outputs, labels) 817 | else: 818 | # We don't use .loss here since the model may return tuples instead of ModelOutput. 819 | loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] 820 | 821 | return (loss, outputs) if return_outputs else loss 822 | -------------------------------------------------------------------------------- /preprocess_datasets/get_description.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import spacy 4 | from nltk import word_tokenize 5 | 6 | import nltk 7 | nltk.download('punkt') 8 | 9 | wikidata_descr = { 10 | 'walmart': 'U.S. 
discount retailer based in Arkansas.', 11 | 'wyoming': 'least populous state of the United States of America', 12 | 'safeway': 'American supermarket chain', 13 | 'mcdonalds': 'American fast food restaurant chain', 14 | 'washington d.c': 'capital city of the United States', 15 | 'espn': 'American pay television sports network', 16 | 'windows 95': 'operating system from Microsoft' 17 | } 18 | my_mapping = { 19 | 'jumping rope': 'jump rope', 20 | 'eden': 'Eden', 21 | 'contemplating': 'contemplate', 22 | 'rehabilitating': 'rehabilitate', 23 | 'catalog': 'catalogue', 24 | 'works': 'work', 25 | 'hoping': 'hope', 26 | 'wetlands': 'wetland', 27 | 'waiting': 'wait', 28 | 'sunglass': 'sunglasses', 29 | 'centre': 'center', 30 | 'bath room': 'bathroom', 31 | 'phd': 'ph.d.', 32 | 'sunglasses': 'sunglasses', 33 | } 34 | 35 | patterns = [ 36 | 'plural of ', 37 | "past participle of ", 38 | "present participle of ", 39 | "third-person singular simple present indicative form of ", 40 | "alternative form of ", 41 | "alternative spelling of ", 42 | "alternative letter-case form of", 43 | "obsolete form of", 44 | "non-oxford british english standard spelling of", 45 | "obsolete spelling", 46 | ] 47 | 48 | nlp=spacy.load('en_core_web_sm') 49 | 50 | 51 | bad_form_of = [] 52 | def lemma_first(qc): 53 | words = nlp(qc) 54 | qc_words = [w.text for w in words] 55 | lemma_word = words[0].lemma_ if words[0].lemma_ != '-PRON-' else words[0].text 56 | if qc_words[0] == lemma_word: 57 | return qc, qc_words 58 | else: 59 | qc_words[0] = lemma_word 60 | qc_new = ' '.join(qc_words) 61 | return qc_new, qc_words 62 | 63 | 64 | def check_my_rules(meaning): 65 | 66 | for p in patterns: 67 | if p in meaning.lower(): 68 | matched_word = meaning.split(p)[-1] 69 | # print(meaning, matched_word) 70 | return matched_word 71 | return None 72 | 73 | 74 | def resolve_meaning(qc, wik_dict, round=0): 75 | 76 | if round > 3: return None 77 | 78 | qc = qc.lower() 79 | if qc in wikidata_descr: 80 | return wikidata_descr[qc] 81 | if qc == '': 82 | return None 83 | if qc in my_mapping: 84 | print('replacing {} with {}'.format(qc, my_mapping[qc])) 85 | qc = my_mapping[qc] 86 | if qc in wik_dict: 87 | for meaning in wik_dict[qc]: 88 | if 'senses' in meaning: 89 | for sense in meaning['senses']: 90 | if 'glosses' in sense: 91 | mstr = '{}'.format(sense['glosses'][0]) 92 | if 'surname' in mstr.lower() or 'given name' in mstr.lower(): 93 | return 'a surname / given name in English.' 
94 | qc_new = check_my_rules(mstr) 95 | if not qc_new: 96 | return mstr 97 | else: 98 | return resolve_meaning(qc_new, wik_dict, round+1) 99 | return None 100 | 101 | 102 | def remove_upprintable_chars(s): 103 | return ''.join(x for x in s if x.isprintable()) 104 | 105 | 106 | def skip_special_tokens(s): 107 | if not bool(re.search('[A-Za-z]', s)): 108 | return False 109 | elif len(s) < 3: 110 | return False 111 | else: 112 | return True 113 | 114 | 115 | def construct_dict_mapping_file(ipath, opath, wik_dict): 116 | with open(ipath, 'r', encoding='utf-8') as f, \ 117 | open(opath, 'w', encoding='utf-8') as g: 118 | for _, line in enumerate(f.readlines()): 119 | vocab, count = line.strip().split('\t') 120 | if count == 'None': continue 121 | if int(count) > 3000: continue 122 | if not skip_special_tokens(vocab): continue 123 | meaning = resolve_meaning(vocab, wik_dict) 124 | 125 | if not meaning: continue 126 | 127 | if not vocab.isprintable(): 128 | vocab = remove_upprintable_chars(vocab) 129 | if not meaning.isprintable(): 130 | meaning = remove_upprintable_chars(meaning) 131 | 132 | meaning = ' '.join(word_tokenize(meaning)) 133 | line = {'word': vocab, 'text': meaning} 134 | g.write(f'{json.dumps(line)}\n') 135 | 136 | 137 | def load_dict(ipath, opath, wikidict): 138 | return construct_dict_mapping_file(ipath, opath, wikidict) 139 | -------------------------------------------------------------------------------- /preprocess_datasets/load_dataset.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import json 3 | from datasets import load_dataset 4 | 5 | for name in ['cola', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb']: 6 | 7 | os.makedirs(os.path.join(os.path.abspath(os.pardir), 'glue_datasets', name), exist_ok=True) 8 | 9 | dataset = load_dataset('glue', name) 10 | print(dataset) 11 | 12 | for k in ['train', 'validation', 'test']: 13 | 14 | out_file = open(os.path.join(os.path.abspath(os.pardir), 'glue_datasets', name, f'{k}.json'), 'w')  # write into the glue_datasets folder created above 15 | 16 | for line in dataset[k]: 17 | line = json.dumps(line) 18 | 19 | out_file.write(f'{line}\n') 20 | 21 | 22 | for name in ['mnli']: 23 | 24 | os.makedirs(os.path.join(os.path.abspath(os.pardir), 'glue_datasets', name), exist_ok=True) 25 | 26 | dataset = load_dataset('glue', name) 27 | print(dataset) 28 | 29 | for k in ['train', 'validation_matched', 'test_matched', 'validation_mismatched', 'test_mismatched']: 30 | 31 | out_file = open(os.path.join(os.path.abspath(os.pardir), 'glue_datasets', name, f'{k}.json'), 'w')  # write into the glue_datasets folder created above 32 | 33 | for line in dataset[k]: 34 | line = json.dumps(line) 35 | 36 | out_file.write(f'{line}\n') 37 | -------------------------------------------------------------------------------- /preprocess_datasets/load_preprocess.sh: -------------------------------------------------------------------------------- 1 | python load_dataset.py 2 | python preprocess.py 3 | python select_word.py-------------------------------------------------------------------------------- /preprocess_datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import json 4 | import string 5 | 6 | # extract_patterns does the following things: 7 | # - special patterns are extracted 8 | # - email addresses 9 | # - urls 10 | # - files 11 | # - tokenize by hyphens in words, light-hearted etc. 12 | # 13 | # !! requires filter_and_cleanup_lines.py. 14 | # !! requires all lowercase input. 
15 | def extract_patterns(line): 16 | email = r'^([a-z0-9_\-\.]+)@((\[[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.)|(([a-z0-9\-]+\.)+))([a-z]{2,4}|[0-9]{1,3})(\]?)$' 17 | # ref: https://www.w3schools.com/php/php_form_url_email.asp 18 | url1 = r'^(?:(?:https?|ftp):\/\/|www\.)[-a-z0-9+&@#\/%?=~_|!:,.;]*[-a-z0-9+&@#\/%=~_|]$' 19 | # simple foo-bar.com cases without the scheme prefix 20 | url2 = r'^[^\s]{3}[^\s]*\.(?:com|net|org|edu|gov)$' 21 | # file: prefix len >=5, suffix given. 22 | file = r'^[a-z_-][a-z0-9_-]{4}[a-z0-9_-]*\.(?:pdf|mp4|mp3|doc|xls|docx|ppt|pptx|wav|wma|csv|tsv|cpp|py|bat|reg|png|jpg|mov|avi|gif|rtf|txt|bmp|mid)$' 23 | newline = '' 24 | for w in line.split(): 25 | w = re.sub(url1, '', w) 26 | w = re.sub(url2, '', w) 27 | w = re.sub(email, '', w) 28 | w = re.sub(file, '', w) 29 | w = ' - '.join(w.split('-')) 30 | newline += ' ' + w 31 | return newline.lstrip() 32 | 33 | 34 | # pre-process 35 | def pre_cleanup(line): 36 | line = line.replace('\t', ' ') # replace tabs with spaces 37 | line = ' '.join(line.strip().split()) # remove redundant spaces 38 | line = re.sub(r'\.{4,}', '...', line) # collapse runs of four or more dots into '...' 39 | line = line.replace('<<', '«').replace('>>', '»') # collapse << and >> into single guillemets 40 | line = re.sub(r' ([,:\.\)\]»])', r'\1', line) # remove space before closing punctuation 41 | line = re.sub(r'([\[\(«]) ', r'\1', line) # remove space after opening punctuation 42 | line = line.replace(',,', ',').replace(',.', '.') # remove redundant punctuations 43 | line = re.sub(r' \*([^\s])', r' \1', line) # remove redundant asterisks 44 | return ' '.join(line.strip().split()) # remove redundant spaces 45 | 46 | 47 | # post_cleanup does the following things: 48 | # - remove all backslashes 49 | # - normalize and remove redundant spaces (including \t, etc.) 50 | # - tokenize by spaces, and start/end puncts, normalize each word 51 | # - puncts in the middle are regarded as a part of the word, for example, 52 | # 1.23, y'all, etc. 53 | # - replace '...' with ' ... ' 54 | def post_cleanup(line): 55 | line = re.sub(r'\\', ' ', line) # remove all backslashes 56 | line = re.sub(r'\s\s+', ' ', line) # remove all redundant spaces 57 | line = re.sub(r'\.\.\.', ' ... ', line) 58 | line = re.sub(r'\.\.', ' .. 
', line) 59 | newline = '' 60 | for w in line.split(): 61 | ls = w.lstrip(string.punctuation) 62 | rs = ls.rstrip(string.punctuation) 63 | lw = len(w) 64 | lstart = lw - len(ls) 65 | rstart = lstart + len(rs) 66 | for i in range(lstart): 67 | newline += ' ' + w[i] 68 | if rs: 69 | newline += ' ' + rs 70 | for i in range(rstart, lw): 71 | newline += ' ' + w[i] 72 | return newline.lstrip() 73 | 74 | 75 | def preprocess_text(line): 76 | line = line.strip().lower() 77 | line = pre_cleanup(line) 78 | line = post_cleanup(line) 79 | return line 80 | 81 | 82 | def get_all_files(path): 83 | if os.path.isfile(path): return [path] 84 | return [f for d in os.listdir(path) 85 | for f in get_all_files(os.path.join(path, d))] 86 | 87 | 88 | task_to_keys = { 89 | "cola": ("sentence"), 90 | "mnli": ("premise", "hypothesis"), 91 | "mrpc": ("sentence1", "sentence2"), 92 | "qnli": ("question", "sentence"), 93 | "qqp": ("question1", "question2"), 94 | "rte": ("sentence1", "sentence2"), 95 | "sst2": ("sentence"), 96 | "stsb": ("sentence1", "sentence2"), 97 | "wnli": ("sentence1", "sentence2"), 98 | } 99 | 100 | def main(folder_path): 101 | 102 | all_files = get_all_files(folder_path) 103 | for idx, file in enumerate(all_files): 104 | if not file.endswith('.json'): continue 105 | if file.endswith('prc.json'): continue 106 | 107 | print(f'Pre-processing the file {file}, in total {idx+1}/{len(all_files)} files.') 108 | 109 | new_file_path = file.replace('.json', '.prc.json') 110 | 111 | with open(file, 'r') as f, open(new_file_path, 'w') as g: 112 | for idx, line in enumerate(f.readlines()): 113 | line = json.loads(line) 114 | 115 | if line.get('sentence'): 116 | line['sentence'] = preprocess_text(line['sentence']) 117 | if line.get('sentence1'): 118 | line['sentence1'] = preprocess_text(line['sentence1']) 119 | if line.get('sentence2'): 120 | line['sentence2'] = preprocess_text(line['sentence2']) 121 | if line.get('question'): 122 | line['question'] = preprocess_text(line['question']) 123 | if line.get('question1'): 124 | line['question1'] = preprocess_text(line['question1']) 125 | if line.get('question2'): 126 | line['question2'] = preprocess_text(line['question2']) 127 | if line.get('premise'): 128 | line['premise'] = preprocess_text(line['premise']) 129 | if line.get('hypothesis'): 130 | line['hypothesis'] = preprocess_text(line['hypothesis']) 131 | 132 | line = json.dumps(line) 133 | g.write(f'{line}\n') 134 | 135 | 136 | if __name__ == '__main__': 137 | main(os.path.join(os.path.abspath(os.pardir), 'glue_datasets')) 138 | -------------------------------------------------------------------------------- /preprocess_datasets/select_word.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from collections import Counter 5 | from get_description import load_dict 6 | 7 | def get_all_files(path): 8 | if os.path.isfile(path): return [path] 9 | return [f for d in os.listdir(path) 10 | for f in get_all_files(os.path.join(path, d))] 11 | 12 | def replace_special_tokens(s): 13 | return s.replace('-', ' ').replace('\'s', ' ').replace('/', ' ').replace('\'', ' ').replace('\"', ' ').replace('?', ' ').split() 14 | 15 | 16 | wiktionary_path = os.path.join( 17 | os.path.abspath(os.pardir), 'preprocess_wiktionary', 'wiktionary.json') 18 | 19 | wiktionary = json.load(open(wiktionary_path, encoding='utf-8')) 20 | print('wiktionary is successfully loaded!') 21 | print('Started!') 22 | 23 | ifolder = os.path.join(os.path.abspath(os.pardir), 
'glue_datasets') 24 | for task in ['cola', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'mnli']: 25 | print(task) 26 | task_files = get_all_files(os.path.join(ifolder, task)) 27 | 28 | vocab_path = os.path.join(ifolder, task, 'vocab.txt') 29 | rare_vocab_path = os.path.join(ifolder, task, 'vocab.rare.txt') 30 | output_dict_path = os.path.join(ifolder, task, 'vocab.90.json') 31 | 32 | vocab_file = open(vocab_path, 'w') 33 | rare_vocab = open(rare_vocab_path, 'w') 34 | 35 | word_collections = [] 36 | for idx, file in enumerate(task_files): 37 | if not file.endswith('prc.json'): continue 38 | 39 | with open(file, 'r', encoding='utf-8') as f: 40 | for idx, line in enumerate(f.readlines()): 41 | line = json.loads(line) 42 | 43 | for key, value in line.items(): 44 | if key in ['idx', 'labels', 'label']: continue 45 | word_collections += replace_special_tokens(value) 46 | 47 | word2freq = Counter(word_collections) 48 | 49 | vocab = sorted(word2freq.items(), key=lambda k: k[1], reverse=True) 50 | threshold = np.sum(list(word2freq.values())) * 0.90 51 | 52 | for word, count in vocab: 53 | vocab_file.write('{}\t{}\n'.format(word, count)) 54 | if threshold - count > 0: 55 | threshold -= count 56 | continue 57 | rare_vocab.write('{}\t{}\n'.format(word, count)) 58 | 59 | load_dict(rare_vocab_path, output_dict_path, wiktionary) 60 | -------------------------------------------------------------------------------- /preprocess_wiktionary/construct_wiktionary.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pdb 3 | from collections import defaultdict 4 | 5 | wik_file = 'kaikki.org-dictionary-English.json' 6 | wik_dict = defaultdict(list) 7 | 8 | with open(wik_file, 'r', encoding='utf-8') as f: 9 | 10 | counter = 0 11 | for line in f: 12 | if counter % 100000 == 0: 13 | print(f'preprocessed {counter} items') 14 | data = json.loads(line) 15 | if 'redirect' in data: 16 | pdb.set_trace() 17 | data.pop('pronunciations', 'no') 18 | data.pop('lang', 'no') 19 | data.pop('translations', 'no') 20 | data.pop('sounds', 'no') 21 | wik_dict[data['word']].append(data) 22 | counter += 1 23 | 24 | # join lower case with upper case 25 | new_wik_dict = {} 26 | for word in wik_dict: 27 | if word.lower() not in new_wik_dict: 28 | new_wik_dict[word.lower()] = wik_dict[word] 29 | elif word.lower() == word: 30 | new_wik_dict[word.lower()] = wik_dict[word] + new_wik_dict[word.lower()] 31 | else: 32 | new_wik_dict[word.lower()] = new_wik_dict[word.lower()] + wik_dict[word] 33 | 34 | wik_dict = new_wik_dict 35 | output_file = 'wiktionary.json' 36 | json.dump(wik_dict, open(output_file, 'w', encoding='utf-8')) -------------------------------------------------------------------------------- /preprocess_wiktionary/download_wiktionary.sh: -------------------------------------------------------------------------------- 1 | # License: cc-by-4.0 2 | 3 | # download the Wiktionary file 4 | curl -O https://kaikki.org/dictionary/English/kaikki.org-dictionary-English.json 5 | 6 | # data example ... 
7 | # {"pos": "noun", "heads": [{"template_name": "en-noun"}], "forms": [{"form": "zymurgies", "tags": ["plural"]}], "word": "zymurgy", "lang": "English", "lang_code": "en", "senses": [{"glosses": ["The chemistry of fermentation with yeasts, especially the science involved in beer and winemaking."], "derived": [{"word": "zymurgic"}, {"word": "zymurgical"}, {"word": "zymurgist"}], "related": [{"word": "zythepsary"}], "categories": ["Beer", "Zymurgy"], "id": "zymurgy-noun"}]} 8 | 9 | # clean the downloaded Wiktionary file 10 | python construct_wiktionary.py 11 | 12 | # remove the original file 13 | rm kaikki.org-dictionary-English.json --------------------------------------------------------------------------------