├── README.md ├── callbacks.py ├── finetune_bart_pl.py ├── finetune_gpt2.py ├── finetune_xlnet.py ├── lightning_base.py ├── manual.pdf ├── util_summarization.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Project 4 | This repository provides the basic implementation of, and an introduction to, using language models for reconstructing implicit knowledge, as described in our paper (Becker et al., 2021). 5 | 6 | ## Important Installation Requirements 7 | PyTorch (conda environment) 8 | 9 | `conda install pytorch torchvision cudatoolkit=10.1 -c pytorch` 10 | 11 | Install TensorFlow datasets and TensorFlow 12 | 13 | `pip install tensorflow-datasets` 14 | 15 | `pip install tensorflow-gpu` 16 | 17 | Install Transformers and PyTorch Lightning 18 | 19 | `pip install transformers` 20 | 21 | `pip install pytorch-lightning` 22 | 23 | ## Data examples 24 | - Source Sentence 1: Not everyone should be obliged to pay the TV & radio licence. 25 | - Source Sentence 2: Particularly the younger generations are no longer dependent on the programming of public broadcasters. 26 | - Concept1: public broadcasters 27 | - Concept2: financed by the TV & radio licence 28 | - Path predicted by CONNECT: (public broadcasters, financed by the TV & radio licence) -> receives action 29 | - Target Sentence: Public broadcasters are financed by the TV & radio licence. 30 | 31 | 32 | ## Models to be fine-tuned 33 | [GPT-2](https://github.com/openai/gpt-2) | [XLNet](https://github.com/zihangdai/xlnet) | [BART](https://github.com/pytorch/fairseq/tree/master/examples/bart) 34 | 35 | Our best-performing language model, BART, fine-tuned on e-SNLI (without constraints; with concepts as constraints; and with commonsense knowledge paths as constraints), can be downloaded from [here](https://drive.google.com/drive/folders/1FBiBlB_-V-6wgfIjUN_0JUn7VsEXO117). 
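Once downloaded, a checkpoint can be loaded directly with the `transformers` API to sanity-check it before running the full generation script described below. The following is a minimal sketch, assuming the downloaded folder is a standard Hugging Face model directory; the local path `bart_esnli_concepts/best_tfmr` is only a placeholder, and the exact input format is given in the next sections.

```python
from transformers import BartForConditionalGeneration, BartTokenizer

# Placeholder path to one of the downloaded fine-tuned BART checkpoints.
model_dir = "bart_esnli_concepts/best_tfmr"

tokenizer = BartTokenizer.from_pretrained(model_dir)
model = BartForConditionalGeneration.from_pretrained(model_dir)
model.eval()

# One test source line, prepared in the BART format described below
# (the two source sentences followed by the two concepts).
source_line = "Not everyone should be obliged to pay the TV & radio licence. ..."

batch = tokenizer([source_line], return_tensors="pt", truncation=True, max_length=80)
output_ids = model.generate(
    batch["input_ids"],
    attention_mask=batch["attention_mask"],
    num_beams=4,
    max_length=20,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```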
36 | 37 | ### Prepare training data (for GPT-2 and XLNet, exclude the target sentence from each line of the test data) 38 | line = [sentence1, sentence2, concept1, concept2, target sentence] 39 | 40 | training: 41 | GPT-2: 42 | source line: line[0]+'<|endoftext|>' + line[1] + '<|endoftext|>' + line[2] + '' + line[3] + '' + line[4] + '\n' 43 | XLNet: 44 | source line: line[0]+'' + line[1] + '' + line[2] + '' + line[3] + '' + line[4] + '\n' 45 | BART: 46 | source line: line[0]+'' + line[1] + '' + line[2] + '' + line[3]+ '\n' 47 | target line: line[4] +'\n' 48 | 49 | testing: 50 | GPT-2: 51 | source line: line[0]+'<|endoftext|>' + line[1] + '<|endoftext|>' + line[2] + '' + line[3] + '\n' 52 | XLNet: 53 | source line: line[0]+'' + line[1] + '' + line[2] + '' + line[3] + '\n' 54 | BART: 55 | source line: line[0]+'' + line[1] + '' + line[2] + '' + line[3]+ '\n' 56 | 57 | target line: line[4] +'\n' 58 | 59 | ### Write the prepared lines into files 60 | 61 | from preprocess import write_gpt2_file, write_xlnet_file, write_bart_file 62 | 63 | For example: 64 | 65 | gpt2_path = 'data/gpt2/' 66 | file = 'train.source' 67 | write_gpt2_file(gpt2_train_lines, gpt2_path + file, mode='train') 68 | 69 | Write the prepared training source lines with pad tokens for the GPT-2 and XLNet models (the BART model pads the data itself during training): 70 | 71 | from transformers import AutoTokenizer 72 | from preprocess import pad_sources 73 | 74 | gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2') 75 | xlnet_tokenizer = AutoTokenizer.from_pretrained('xlnet-large-cased') 76 | 77 | pad_sources(gpt2_tokenizer, new_path, gpt2_train_lines, block_size, model='gpt2') 78 | pad_sources(xlnet_tokenizer, new_path, xlnet_train_lines, block_size, model='xlnet') 79 | 80 | 81 | ## Fine-tuning 82 | GPT-2: 83 | python finetune_gpt2.py \ 84 | --model_name_or_path=gpt2 \ 85 | --model_type=gpt2 \ 86 | --per_device_train_batch_size=8 \ 87 | --per_gpu_train_batch_size=8 \ 88 | --train_data_file=data/esnli/gpt2/train.source \ 89 | --valid_data_file=data/esnli/gpt2/valid.source \ 90 | --output_dir=./finetune_gpt2_esnli \ 91 | --do_train \ 92 | --block_size=96 \ 93 | --save_steps=500 \ 94 | --save_total_limit=1 95 | 96 | XLNet (the block_size has to be even): 97 | python finetune_xlnet.py \ 98 | --model_name_or_path=xlnet-large-cased \ 99 | --model_type=xlnet \ 100 | --per_device_train_batch_size=8 \ 101 | --per_gpu_train_batch_size=8 \ 102 | --train_data_file=data/esnli/xlnet/train.source \ 103 | --output_dir=./finetune_xlnet_esnli_heads_nl \ 104 | --save_steps=500 \ 105 | --block_size=96 \ 106 | --save_total_limit=1 \ 107 | --do_train 108 | 109 | BART: 110 | python finetune_bart_pl.py \ 111 | --model_name_or_path=facebook/bart-large-cnn \ 112 | --tokenizer_name=facebook/bart-large-cnn \ 113 | --learning_rate=3e-5 \ 114 | --gpus=1 \ 115 | --num_train_epochs=3 \ 116 | --max_source_length=80 \ 117 | --max_target_length=20 \ 118 | --train_batch_size=8 \ 119 | --data_dir=../data/esnli/bart/ \ 120 | --output_dir=./finetune_bart_esnli \ 121 | --do_train 122 | 123 | ## Generation 124 | Source lines should be prepared differently for each type of model: 125 | GPT-2: 126 | source line: line[0]+'<|endoftext|>' + line[1] + '<|endoftext|>' + line[2] + '' + line[3] + '\n' 127 | XLNet: 128 | source line: line[0]+'' + line[1] + '' + line[2] + '' + line[3] + '\n' 129 | BART: 130 | source line: line[0]+'' + line[1] + '' + line[2] + '' + line[3]+ '\n' 131 | 132 | 133 | Script: lm_generate.py 134 | 135 | The required arguments for running the generation script: 136 | 
--model_path: where the fine-tuned model directory is stored 137 | --model_type: gpt2, xlnet or bart 138 | --test_src: the path to the test source file 139 | --save_path: where to save the generations 140 | 141 | For example: 142 | GPT-2: 143 | python lm_generate.py \ 144 | --model_path=finetune_gpt2_esnli_corec_path \ 145 | --model_type=gpt2 \ 146 | --test_src=data/ikat/ikat_test_corec_path.gpt2_source \ 147 | --save_path=data/esnli/ikat_test_corec_path.gpt2_pred 148 | 149 | XLNet: 150 | python lm_generate.py \ 151 | --model_path=finetune_xlnet_esnli_corec_path \ 152 | --model_type=xlnet \ 153 | --test_src=data/ikat/ikat_test_corec_path.xlnet_source \ 154 | --save_path=data/esnli/ikat_test_corec_path.xlnet_pred 155 | 156 | BART: 157 | python lm_generate.py \ 158 | --model_path=seq2seq/finetune_bart_esnli_corec_path/best_tfmr \ 159 | --model_type=bart \ 160 | --test_src=data/ikat/ikat_test_corec_path.bart_source \ 161 | --save_path=data/esnli/ikat_test_corec_path.bart_pred 162 | 163 | ### postprocess 164 | from postprocess import process_generations, write_generations 165 | 166 | For example: 167 | 1. get the predicted lines from one prediction file: 168 | path = 'data/esnli/ikat_test_corec_path.gpt2_pred' 169 | lines = [line.strip() for line in open(path).readlines()] 170 | generations = process_generations(lines, model_name='gpt2') 171 | 172 | 2. postprocess and write the generations for the prediction files of one model: 173 | path = 'data/esnli/gpt2/' 174 | new_path = 'generations/esnli/gpt2/' 175 | write_generations(path, new_path, model_name = 'gpt2') 176 | 177 | 178 | If you use our model, please cite: 179 | 180 | Becker, M., Liang, S., and Frank, A. (2021c). Reconstructing Implicit Knowledge with Language Models. Accepted at: Deep Learning Inside Out (DeeLIO): Workshop on Knowledge Extraction and Integration for Deep Learning Architectures. 
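### Putting it all together

As a final step of the pipeline described above, the predictions of all three models can be postprocessed in one loop. This is a small sketch that assumes the directory layout used in the examples and the `write_generations(path, new_path, model_name=...)` signature shown in the postprocess section; adjust the paths to your own setup.

```python
from postprocess import write_generations

# Placeholder directory layout, following the examples above.
for model_name in ["gpt2", "xlnet", "bart"]:
    pred_path = f"data/esnli/{model_name}/"        # prediction files of this model
    out_path = f"generations/esnli/{model_name}/"  # where the cleaned generations are written
    write_generations(pred_path, out_path, model_name=model_name)
```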
181 | 182 | For questions or comments email us: mbecker@cl.uni-heidelberg.de 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 9 | from pytorch_lightning.utilities import rank_zero_only 10 | 11 | 12 | def count_trainable_parameters(model): 13 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 14 | params = sum([np.prod(p.size()) for p in model_parameters]) 15 | return params 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class Seq2SeqLoggingCallback(pl.Callback): 22 | def on_batch_end(self, trainer, pl_module): 23 | lrs = {f"lr_group_{i}": param["lr"] for i, param in enumerate(pl_module.trainer.optimizers[0].param_groups)} 24 | pl_module.logger.log_metrics(lrs) 25 | 26 | @rank_zero_only 27 | def _write_logs( 28 | self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True 29 | ) -> None: 30 | logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****") 31 | metrics = trainer.callback_metrics 32 | trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]}) 33 | # Log results 34 | od = Path(pl_module.hparams.output_dir) 35 | if type_path == "test": 36 | results_file = od / "test_results.txt" 37 | generations_file = od / "test_generations.txt" 38 | else: 39 | # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json 40 | # If people want this it will be easy enough to add back. 
41 | results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" 42 | generations_file = od / f"{type_path}_generations/{trainer.global_step:05d}.txt" 43 | results_file.parent.mkdir(exist_ok=True) 44 | generations_file.parent.mkdir(exist_ok=True) 45 | with open(results_file, "a+") as writer: 46 | for key in sorted(metrics): 47 | if key in ["log", "progress_bar", "preds"]: 48 | continue 49 | val = metrics[key] 50 | if isinstance(val, torch.Tensor): 51 | val = val.item() 52 | msg = f"{key}: {val:.6f}\n" 53 | writer.write(msg) 54 | 55 | if not save_generations: 56 | return 57 | 58 | if "preds" in metrics: 59 | content = "\n".join(metrics["preds"]) 60 | generations_file.open("w+").write(content) 61 | 62 | @rank_zero_only 63 | def on_train_start(self, trainer, pl_module): 64 | try: 65 | npars = pl_module.model.model.num_parameters() 66 | except AttributeError: 67 | npars = pl_module.model.num_parameters() 68 | 69 | n_trainable_pars = count_trainable_parameters(pl_module) 70 | # mp stands for million parameters 71 | trainer.logger.log_metrics({"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}) 72 | 73 | @rank_zero_only 74 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 75 | return self._write_logs(trainer, pl_module, "test") 76 | 77 | 78 | def get_checkpoint_callback(output_dir, metric): 79 | """Saves the best model by validation ROUGE2 score.""" 80 | if metric == "rouge2": 81 | exp = "{val_avg_rouge2:.4f}-{step_count}" 82 | elif metric == "bleu": 83 | exp = "{val_avg_bleu:.4f}-{step_count}" 84 | else: 85 | raise NotImplementedError( 86 | f"seq2seq callbacks only support rouge2 and bleu, got {metric}, You can make your own by adding to this function." 87 | ) 88 | 89 | checkpoint_callback = ModelCheckpoint( 90 | filepath=os.path.join(output_dir, exp), 91 | monitor=f"val_{metric}", 92 | mode="max", 93 | save_top_k=1, 94 | period=0, # maybe save a checkpoint every time val is run, not just end of epoch. 
95 | ) 96 | return checkpoint_callback 97 | 98 | 99 | def get_early_stopping_callback(metric, patience): 100 | return EarlyStopping(monitor=f"val_{metric}", mode="max", patience=patience, verbose=True,) -------------------------------------------------------------------------------- /finetune_bart_pl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import logging 4 | import os 5 | import time 6 | import warnings 7 | from collections import defaultdict 8 | from pathlib import Path 9 | from typing import Dict, List, Tuple 10 | 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import DataLoader 14 | import pytorch_lightning as pl 15 | from lightning_base import BaseTransformer, add_generic_args, generic_train 16 | from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup 17 | 18 | 19 | try: 20 | from .util_summarization import ( 21 | assert_all_frozen, 22 | use_task_specific_params, 23 | lmap, 24 | flatten_list, 25 | pickle_save, 26 | save_git_info, 27 | save_json, 28 | freeze_params, 29 | calculate_rouge, 30 | get_git_info, 31 | ROUGE_KEYS, 32 | calculate_bleu_score, 33 | Seq2SeqDataset, 34 | TranslationDataset, 35 | label_smoothed_nll_loss, 36 | ) 37 | from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback 38 | except ImportError: 39 | from util_summarization import ( 40 | Seq2SeqDataset, 41 | TranslationDataset, 42 | assert_all_frozen, 43 | use_task_specific_params, 44 | lmap, 45 | flatten_list, 46 | pickle_save, 47 | save_git_info, 48 | save_json, 49 | freeze_params, 50 | calculate_rouge, 51 | get_git_info, 52 | ROUGE_KEYS, 53 | calculate_bleu_score, 54 | label_smoothed_nll_loss, 55 | ) 56 | from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback 57 | 58 | logger = logging.getLogger(__name__) 59 | 60 | 61 | class SummarizationModule(BaseTransformer): 62 | mode = "summarization" 63 | loss_names = ["loss"] 64 | metric_names = ROUGE_KEYS 65 | val_metric = "rouge2" 66 | 67 | def __init__(self, hparams, **kwargs): 68 | super().__init__(hparams, num_labels=None, mode=self.mode, **kwargs) 69 | use_task_specific_params(self.model, "summarization") 70 | save_git_info(self.hparams.output_dir) 71 | self.metrics_save_path = Path(self.output_dir) / "metrics.json" 72 | self.hparams_save_path = Path(self.output_dir) / "hparams.pkl" 73 | pickle_save(self.hparams, self.hparams_save_path) 74 | self.step_count = 0 75 | self.metrics = defaultdict(list) 76 | 77 | self.dataset_kwargs: dict = dict( 78 | data_dir=self.hparams.data_dir, 79 | max_source_length=self.hparams.max_source_length, 80 | prefix=self.model.config.prefix or "", 81 | ) 82 | n_observations_per_split = { 83 | "train": self.hparams.n_train, 84 | "val": self.hparams.n_val, 85 | "test": self.hparams.n_test, 86 | } 87 | self.n_obs = {k: v if v >= 0 else None for k, v in n_observations_per_split.items()} 88 | 89 | self.target_lens = { 90 | "train": self.hparams.max_target_length, 91 | "val": self.hparams.val_max_target_length, 92 | "test": self.hparams.test_max_target_length, 93 | } 94 | assert self.target_lens["train"] <= self.target_lens["val"], f"target_lens: {self.target_lens}" 95 | assert self.target_lens["train"] <= self.target_lens["test"], f"target_lens: {self.target_lens}" 96 | 97 | if self.hparams.freeze_embeds: 98 | self.freeze_embeds() 99 | if self.hparams.freeze_encoder: 100 | 
freeze_params(self.model.get_encoder()) 101 | #assert_all_frozen(self.model.get_encoder()) 102 | 103 | self.hparams.git_sha = get_git_info()["repo_sha"] 104 | self.num_workers = hparams.num_workers 105 | self.decoder_start_token_id = None 106 | if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer): 107 | self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang] 108 | self.model.config.decoder_start_token_id = self.decoder_start_token_id 109 | if isinstance(self.tokenizer, MBartTokenizer) or isinstance(self.tokenizer, MarianTokenizer): 110 | self.dataset_class = TranslationDataset 111 | else: 112 | self.dataset_class = Seq2SeqDataset 113 | 114 | def freeze_embeds(self): 115 | """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5.""" 116 | try: 117 | freeze_params(self.model.model.shared) 118 | for d in [self.model.model.encoder, self.model.model.decoder]: 119 | freeze_params(d.embed_positions) 120 | freeze_params(d.embed_tokens) 121 | except AttributeError: 122 | freeze_params(self.model.shared) 123 | for d in [self.model.encoder, self.model.decoder]: 124 | freeze_params(d.embed_tokens) 125 | 126 | def forward(self, input_ids, **kwargs): 127 | return self.model(input_ids, **kwargs) 128 | 129 | def ids_to_clean_text(self, generated_ids: List[int]): 130 | gen_text = self.tokenizer.batch_decode( 131 | generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True 132 | ) 133 | return lmap(str.strip, gen_text) 134 | 135 | def _step(self, batch: dict) -> Tuple: 136 | pad_token_id = self.tokenizer.pad_token_id 137 | source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"] 138 | 139 | if isinstance(self.model, T5ForConditionalGeneration): 140 | decoder_input_ids = self.model._shift_right(target_ids) 141 | lm_labels = target_ids 142 | else: 143 | decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line? 144 | lm_labels = target_ids[:, 1:].clone() # why clone? 
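            # Teacher forcing for BART-style models: the decoder input is the target shifted right
            # (tokens [0..n-2]) and the labels are the target shifted left (tokens [1..n-1]).
            # .contiguous() gives the sliced decoder input a compact memory layout, and .clone()
            # copies the label slice into its own storage so it no longer shares memory with target_ids.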
145 | 146 | outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False) 147 | 148 | if self.hparams.label_smoothing == 0: 149 | # Same behavior as modeling_bart.py 150 | loss_fct = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id) 151 | lm_logits = outputs[0] 152 | assert lm_logits.shape[-1] == self.model.config.vocab_size 153 | loss = loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), lm_labels.view(-1)) 154 | else: 155 | lprobs = torch.nn.functional.log_softmax(outputs[0], dim=-1) 156 | loss, nll_loss = label_smoothed_nll_loss( 157 | lprobs, lm_labels, self.hparams.label_smoothing, ignore_index=pad_token_id 158 | ) 159 | return (loss,) 160 | 161 | @property 162 | def pad(self) -> int: 163 | return self.tokenizer.pad_token_id 164 | 165 | def training_step(self, batch, batch_idx) -> Dict: 166 | loss_tensors = self._step(batch) 167 | 168 | logs = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} 169 | # tokens per batch 170 | logs["tpb"] = batch["input_ids"].ne(self.pad).sum() + batch["decoder_input_ids"].ne(self.pad).sum() 171 | return {"loss": loss_tensors[0], "log": logs} 172 | 173 | def validation_step(self, batch, batch_idx) -> Dict: 174 | return self._generative_step(batch) 175 | 176 | def validation_epoch_end(self, outputs, prefix="val") -> Dict: 177 | self.step_count += 1 178 | losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names} 179 | loss = losses["loss"] 180 | rouges = {k: np.array([x[k] for x in outputs]).mean() for k in self.metric_names + ["gen_time", "gen_len"]} 181 | rouge_tensor: torch.FloatTensor = torch.tensor(rouges[self.val_metric]).type_as(loss) 182 | rouges.update({k: v.item() for k, v in losses.items()}) 183 | losses.update(rouges) 184 | metrics = {f"{prefix}_avg_{k}": x for k, x in losses.items()} 185 | metrics["step_count"] = self.step_count 186 | self.save_metrics(metrics, prefix) # writes to self.metrics_save_path 187 | preds = flatten_list([x["preds"] for x in outputs]) 188 | return {"log": metrics, "preds": preds, f"{prefix}_loss": loss, f"{prefix}_{self.val_metric}": rouge_tensor} 189 | 190 | def save_metrics(self, latest_metrics, type_path) -> None: 191 | self.metrics[type_path].append(latest_metrics) 192 | save_json(self.metrics, self.metrics_save_path) 193 | 194 | def calc_generative_metrics(self, preds, target) -> Dict: 195 | return calculate_rouge(preds, target) 196 | 197 | def _generative_step(self, batch: dict) -> dict: 198 | t0 = time.time() 199 | generated_ids = self.model.generate( 200 | batch["input_ids"], 201 | attention_mask=batch["attention_mask"], 202 | use_cache=True, 203 | decoder_start_token_id=self.decoder_start_token_id, 204 | ) 205 | gen_time = (time.time() - t0) / batch["input_ids"].shape[0] 206 | preds: List[str] = self.ids_to_clean_text(generated_ids) 207 | target: List[str] = self.ids_to_clean_text(batch["decoder_input_ids"]) 208 | loss_tensors = self._step(batch) 209 | base_metrics = {name: loss for name, loss in zip(self.loss_names, loss_tensors)} 210 | rouge: Dict = self.calc_generative_metrics(preds, target) 211 | summ_len = np.mean(lmap(len, generated_ids)) 212 | base_metrics.update(gen_time=gen_time, gen_len=summ_len, preds=preds, target=target, **rouge) 213 | return base_metrics 214 | 215 | def test_step(self, batch, batch_idx): 216 | return self._generative_step(batch) 217 | 218 | def test_epoch_end(self, outputs): 219 | return self.validation_epoch_end(outputs, prefix="test") 220 | 221 | def get_dataset(self, type_path) -> 
Seq2SeqDataset: 222 | n_obs = self.n_obs[type_path] 223 | max_target_length = self.target_lens[type_path] 224 | dataset = self.dataset_class( 225 | self.tokenizer, 226 | type_path=type_path, 227 | n_obs=n_obs, 228 | max_target_length=max_target_length, 229 | **self.dataset_kwargs, 230 | ) 231 | return dataset 232 | 233 | def get_dataloader(self, type_path: str, batch_size: int, shuffle: bool = False) -> DataLoader: 234 | dataset = self.get_dataset(type_path) 235 | sampler = None 236 | if self.hparams.sortish_sampler and type_path == "train": 237 | assert self.hparams.gpus <= 1 # TODO: assert earlier 238 | sampler = dataset.make_sortish_sampler(batch_size) 239 | shuffle = False 240 | 241 | dataloader = DataLoader( 242 | dataset, 243 | batch_size=batch_size, 244 | collate_fn=dataset.collate_fn, 245 | shuffle=shuffle, 246 | num_workers=self.num_workers, 247 | sampler=sampler, 248 | ) 249 | return dataloader 250 | 251 | def train_dataloader(self) -> DataLoader: 252 | dataloader = self.get_dataloader("train", batch_size=self.hparams.train_batch_size, shuffle=True) 253 | t_total = ( 254 | (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.gpus))) 255 | // self.hparams.accumulate_grad_batches 256 | * float(self.hparams.max_epochs) 257 | ) 258 | scheduler = get_linear_schedule_with_warmup( 259 | self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total 260 | ) 261 | if max(scheduler.get_last_lr()) > 0: 262 | warnings.warn("All learning rates are 0") 263 | self.lr_scheduler = scheduler 264 | return dataloader 265 | 266 | def val_dataloader(self) -> DataLoader: 267 | return self.get_dataloader("val", batch_size=self.hparams.eval_batch_size) 268 | 269 | def test_dataloader(self) -> DataLoader: 270 | return self.get_dataloader("test", batch_size=self.hparams.eval_batch_size) 271 | 272 | @staticmethod 273 | def add_model_specific_args(parser, root_dir): 274 | BaseTransformer.add_model_specific_args(parser, root_dir) 275 | add_generic_args(parser, root_dir) 276 | parser.add_argument( 277 | "--max_source_length", 278 | default=156, 279 | type=int, 280 | help="The maximum total input sequence length after tokenization. Sequences longer " 281 | "than this will be truncated, sequences shorter will be padded.", 282 | ) 283 | parser.add_argument( 284 | "--max_target_length", 285 | default=20, 286 | type=int, 287 | help="The maximum total input sequence length after tokenization. Sequences longer " 288 | "than this will be truncated, sequences shorter will be padded.", 289 | ) 290 | parser.add_argument( 291 | "--val_max_target_length", 292 | default=20, # these defaults are optimized for CNNDM. For xsum, see README.md. 293 | type=int, 294 | help="The maximum total input sequence length after tokenization. Sequences longer " 295 | "than this will be truncated, sequences shorter will be padded.", 296 | ) 297 | parser.add_argument( 298 | "--test_max_target_length", 299 | default=20, 300 | type=int, 301 | help="The maximum total input sequence length after tokenization. Sequences longer " 302 | "than this will be truncated, sequences shorter will be padded.", 303 | ) 304 | parser.add_argument( 305 | "--data_dir", 306 | type=str, 307 | required=True, 308 | help="The input data dir. 
Should contain train.source, train.target, val.source, val.target, test.source, test.target", 309 | ) 310 | parser.add_argument("--freeze_encoder", action="store_true") 311 | parser.add_argument("--freeze_embeds", action="store_true") 312 | parser.add_argument("--sortish_sampler", action="store_true", default=False) 313 | parser.add_argument("--logger_name", type=str, choices=["default", "wandb", "wandb_shared"], default="default") 314 | parser.add_argument("--n_train", type=int, default=-1, required=False, help="# examples. -1 means use all.") 315 | parser.add_argument("--n_val", type=int, default=500, required=False, help="# examples. -1 means use all.") 316 | parser.add_argument("--n_test", type=int, default=-1, required=False, help="# examples. -1 means use all.") 317 | parser.add_argument( 318 | "--task", type=str, default="summarization", required=False, help="# examples. -1 means use all." 319 | ) 320 | parser.add_argument("--label_smoothing", type=float, default=0.0, required=False) 321 | parser.add_argument("--src_lang", type=str, default="", required=False) 322 | parser.add_argument("--tgt_lang", type=str, default="", required=False) 323 | parser.add_argument( 324 | "--early_stopping_patience", 325 | type=int, 326 | default=-1, 327 | required=False, 328 | help="-1 means never early stop. early_stopping_patience is measured in validation checks, not epochs. So val_check_interval will effect it.", 329 | ) 330 | return parser 331 | 332 | 333 | class TranslationModule(SummarizationModule): 334 | mode = "translation" 335 | loss_names = ["loss"] 336 | metric_names = ["bleu"] 337 | val_metric = "bleu" 338 | 339 | def __init__(self, hparams, **kwargs): 340 | super().__init__(hparams, **kwargs) 341 | self.dataset_kwargs["src_lang"] = hparams.src_lang 342 | self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang 343 | 344 | def calc_generative_metrics(self, preds, target) -> dict: 345 | return calculate_bleu_score(preds, target) 346 | 347 | 348 | def main(args, model=None) -> SummarizationModule: 349 | Path(args.output_dir).mkdir(exist_ok=True) 350 | if len(os.listdir(args.output_dir)) > 3 and args.do_train: 351 | raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) 352 | if model is None: 353 | if args.task == "summarization": 354 | model: SummarizationModule = SummarizationModule(args) 355 | else: 356 | model: SummarizationModule = TranslationModule(args) 357 | 358 | dataset = Path(args.data_dir).name 359 | if ( 360 | args.logger_name == "default" 361 | or args.fast_dev_run 362 | or str(args.output_dir).startswith("/tmp") 363 | or str(args.output_dir).startswith("/var") 364 | ): 365 | logger = True # don't pollute wandb logs unnecessarily 366 | elif args.logger_name == "wandb": 367 | from pytorch_lightning.loggers import WandbLogger 368 | 369 | project = os.environ.get("WANDB_PROJECT", dataset) 370 | logger = WandbLogger(name=model.output_dir.name, project=project) 371 | 372 | elif args.logger_name == "wandb_shared": 373 | from pytorch_lightning.loggers import WandbLogger 374 | 375 | logger = WandbLogger(name=model.output_dir.name, project=f"hf_{dataset}") 376 | 377 | if args.early_stopping_patience >= 0: 378 | es_callback = get_early_stopping_callback(model.val_metric, args.early_stopping_patience) 379 | else: 380 | es_callback = False 381 | trainer: pl.Trainer = generic_train( 382 | model, 383 | args, 384 | logging_callback=Seq2SeqLoggingCallback(), 385 | checkpoint_callback=get_checkpoint_callback(args.output_dir, model.val_metric), 386 | 
early_stopping_callback=es_callback, 387 | logger=logger, 388 | # TODO: early stopping callback seems messed up 389 | ) 390 | pickle_save(model.hparams, model.output_dir / "hparams.pkl") 391 | if not args.do_predict: 392 | return model 393 | 394 | model.hparams.test_checkpoint = "" 395 | checkpoints = list(sorted(glob.glob(os.path.join(args.output_dir, "*.ckpt"), recursive=True))) 396 | if checkpoints: 397 | model.hparams.test_checkpoint = checkpoints[-1] 398 | trainer.resume_from_checkpoint = checkpoints[-1] 399 | trainer.logger.log_hyperparams(model.hparams) 400 | 401 | # test() without a model tests using the best checkpoint automatically 402 | trainer.test() 403 | return model 404 | 405 | 406 | if __name__ == "__main__": 407 | parser = argparse.ArgumentParser() 408 | parser = pl.Trainer.add_argparse_args(parser) 409 | parser = SummarizationModule.add_model_specific_args(parser, os.getcwd()) 410 | 411 | args = parser.parse_args() 412 | main(args) 413 | -------------------------------------------------------------------------------- /finetune_gpt2.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | from dataclasses import dataclass, field 5 | from typing import Optional 6 | 7 | from transformers import ( 8 | CONFIG_MAPPING, 9 | MODEL_WITH_LM_HEAD_MAPPING, 10 | AutoConfig, 11 | AutoModelWithLMHead, 12 | AutoTokenizer, 13 | DataCollatorForLanguageModeling, 14 | HfArgumentParser, 15 | LineByLineTextDataset, 16 | PreTrainedTokenizer, 17 | TextDataset, 18 | Trainer, 19 | TrainingArguments, 20 | set_seed, 21 | ) 22 | 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys()) 28 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 29 | 30 | 31 | @dataclass 32 | class ModelArguments: 33 | """ 34 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 35 | """ 36 | 37 | model_name_or_path: Optional[str] = field( 38 | default=None, 39 | metadata={ 40 | "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch." 41 | }, 42 | ) 43 | model_type: Optional[str] = field( 44 | default=None, 45 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 46 | ) 47 | config_name: Optional[str] = field( 48 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 49 | ) 50 | tokenizer_name: Optional[str] = field( 51 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 52 | ) 53 | cache_dir: Optional[str] = field( 54 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 55 | ) 56 | 57 | 58 | @dataclass 59 | class DataTrainingArguments: 60 | """ 61 | Arguments pertaining to what data we are going to input our model for training and eval. 
62 | """ 63 | 64 | train_data_file: Optional[str] = field( 65 | default=None, metadata={"help": "The input training data file (a text file)."} 66 | ) 67 | eval_data_file: Optional[str] = field( 68 | default=None, 69 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 70 | ) 71 | line_by_line: bool = field( 72 | default=True, 73 | metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, 74 | ) 75 | 76 | mlm: bool = field( 77 | default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."} 78 | ) 79 | mlm_probability: float = field( 80 | default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} 81 | ) 82 | plm_probability: float = field( 83 | default=1 / 6, 84 | metadata={ 85 | "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling." 86 | }, 87 | ) 88 | max_span_length: int = field( 89 | default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."} 90 | ) 91 | # esnli data src length = 75, trg length = 35 92 | block_size: int = field( 93 | default=128, 94 | metadata={ 95 | "help": "Optional input sequence length after tokenization." 96 | "The training dataset will be truncated in block of this size for training." 97 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 98 | }, 99 | ) 100 | overwrite_cache: bool = field( 101 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 102 | ) 103 | 104 | 105 | def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False): 106 | file_path = args.eval_data_file if evaluate else args.train_data_file 107 | if args.line_by_line: 108 | print("get dataset block size ", args.block_size) 109 | return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size) 110 | else: 111 | return TextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache 112 | ) 113 | 114 | 115 | def main(): 116 | # See all possible arguments in src/transformers/training_args.py 117 | # or by passing the --help flag to this script. 118 | # We now keep distinct sets of args, for a cleaner separation of concerns. 119 | 120 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 121 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 122 | training_args.per_device_train_batch_size = 2 123 | training_args.per_device_eval_batch_size = 2 124 | 125 | 126 | if data_args.eval_data_file is None and training_args.do_eval: 127 | raise ValueError( 128 | "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " 129 | "or remove the --do_eval argument." 130 | ) 131 | 132 | if ( 133 | os.path.exists(training_args.output_dir) 134 | and os.listdir(training_args.output_dir) 135 | and training_args.do_train 136 | and not training_args.overwrite_output_dir 137 | ): 138 | raise ValueError( 139 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
140 | ) 141 | 142 | # Setup logging 143 | logging.basicConfig( 144 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 145 | datefmt="%m/%d/%Y %H:%M:%S", 146 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 147 | ) 148 | logger.warning( 149 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 150 | training_args.local_rank, 151 | training_args.device, 152 | training_args.n_gpu, 153 | bool(training_args.local_rank != -1), 154 | training_args.fp16, 155 | ) 156 | logger.info("Training/evaluation parameters %s", training_args) 157 | 158 | # Set seed 159 | set_seed(training_args.seed) 160 | 161 | # Load pretrained model and tokenizer 162 | # 163 | # Distributed training: 164 | # The .from_pretrained methods guarantee that only one local process can concurrently 165 | # download model & vocab. 166 | 167 | if model_args.config_name: 168 | config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) 169 | elif model_args.model_name_or_path: 170 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 171 | else: 172 | config = CONFIG_MAPPING[model_args.model_type]() 173 | logger.warning("You are instantiating a new config instance from scratch.") 174 | 175 | if model_args.tokenizer_name: 176 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir) 177 | elif model_args.model_name_or_path: 178 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 179 | else: 180 | raise ValueError( 181 | "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it," 182 | "and load it from here, using --tokenizer_name" 183 | ) 184 | 185 | if model_args.model_name_or_path: 186 | model = AutoModelWithLMHead.from_pretrained( 187 | model_args.model_name_or_path, 188 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 189 | config=config, 190 | cache_dir=model_args.cache_dir, 191 | ) 192 | else: 193 | logger.info("Training new model from scratch") 194 | model = AutoModelWithLMHead.from_config(config) 195 | 196 | model.resize_token_embeddings(len(tokenizer)) 197 | 198 | if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm: 199 | raise ValueError( 200 | "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the" 201 | "--mlm flag (masked language modeling)." 
202 | ) 203 | 204 | if data_args.block_size <= 0: 205 | print("tokenizer max len", tokenizer.max_len) 206 | data_args.block_size = tokenizer.max_len 207 | # Our input block size will be the max possible for the model 208 | else: 209 | data_args.block_size = min(data_args.block_size, tokenizer.max_len) 210 | 211 | # Get datasets 212 | 213 | train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None 214 | eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None 215 | #if config.model_type == "xlnet": 216 | # data_collator = DataCollatorForPermutationLanguageModeling( 217 | # tokenizer=tokenizer, plm_probability=data_args.plm_probability, max_span_length=data_args.max_span_length, 218 | # ) 219 | #else: 220 | data_collator = DataCollatorForLanguageModeling( 221 | tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability 222 | ) 223 | 224 | # Initialize our Trainer 225 | trainer = Trainer( 226 | model=model, 227 | args=training_args, 228 | data_collator=data_collator, 229 | train_dataset=train_dataset, 230 | eval_dataset=eval_dataset, 231 | prediction_loss_only=True, 232 | ) 233 | 234 | # Training 235 | if training_args.do_train: 236 | model_path = ( 237 | model_args.model_name_or_path 238 | if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path) 239 | else None 240 | ) 241 | trainer.train(model_path=model_path) 242 | trainer.save_model() 243 | # For convenience, we also re-save the tokenizer to the same directory, 244 | # so that you can share your model easily on huggingface.co/models =) 245 | if trainer.is_world_master(): 246 | tokenizer.save_pretrained(training_args.output_dir) 247 | 248 | # Evaluation 249 | results = {} 250 | if training_args.do_eval: 251 | logger.info("*** Evaluate ***") 252 | 253 | eval_output = trainer.evaluate() 254 | 255 | perplexity = math.exp(eval_output["eval_loss"]) 256 | result = {"perplexity": perplexity} 257 | 258 | output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt") 259 | if trainer.is_world_master(): 260 | with open(output_eval_file, "w") as writer: 261 | logger.info("***** Eval results *****") 262 | for key in sorted(result.keys()): 263 | logger.info(" %s = %s", key, str(result[key])) 264 | writer.write("%s = %s\n" % (key, str(result[key]))) 265 | 266 | results.update(result) 267 | 268 | return results 269 | 270 | 271 | def _mp_fn(index): 272 | # For xla_spawn (TPUs) 273 | main() 274 | 275 | 276 | if __name__ == "__main__": 277 | main() 278 | -------------------------------------------------------------------------------- /finetune_xlnet.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | from dataclasses import dataclass, field 5 | from typing import Optional 6 | 7 | from transformers import ( 8 | CONFIG_MAPPING, 9 | MODEL_WITH_LM_HEAD_MAPPING, 10 | AutoConfig, 11 | AutoModelWithLMHead, 12 | AutoTokenizer, 13 | DataCollatorForLanguageModeling, 14 | DataCollatorForPermutationLanguageModeling, 15 | HfArgumentParser, 16 | LineByLineTextDataset, 17 | PreTrainedTokenizer, 18 | TextDataset, 19 | Trainer, 20 | TrainingArguments, 21 | set_seed, 22 | ) 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys()) 29 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 30 | 31 | 32 | @dataclass 33 | class 
ModelArguments: 34 | """ 35 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 36 | """ 37 | 38 | model_name_or_path: Optional[str] = field( 39 | default=None, 40 | metadata={ 41 | "help": "The model checkpoint for weights initialization. Leave None if you want to train a model from scratch." 42 | }, 43 | ) 44 | model_type: Optional[str] = field( 45 | default=None, 46 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 47 | ) 48 | config_name: Optional[str] = field( 49 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 50 | ) 51 | tokenizer_name: Optional[str] = field( 52 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 53 | ) 54 | cache_dir: Optional[str] = field( 55 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 56 | ) 57 | 58 | 59 | @dataclass 60 | class DataTrainingArguments: 61 | """ 62 | Arguments pertaining to what data we are going to input our model for training and eval. 63 | """ 64 | 65 | train_data_file: Optional[str] = field( 66 | default=None, metadata={"help": "The input training data file (a text file)."} 67 | ) 68 | eval_data_file: Optional[str] = field( 69 | default=None, 70 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 71 | ) 72 | line_by_line: bool = field( 73 | default=True, 74 | metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, 75 | ) 76 | 77 | mlm: bool = field( 78 | default=False, metadata={"help": "Train with masked-language modeling loss instead of language modeling."} 79 | ) 80 | mlm_probability: float = field( 81 | default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} 82 | ) 83 | plm_probability: float = field( 84 | default=1 / 6, 85 | metadata={ 86 | "help": "Ratio of length of a span of masked tokens to surrounding context length for permutation language modeling." 87 | }, 88 | ) 89 | max_span_length: int = field( 90 | default=5, metadata={"help": "Maximum length of a span of masked tokens for permutation language modeling."} 91 | ) 92 | block_size: int = field( 93 | default=152, 94 | metadata={ 95 | "help": "Optional input sequence length after tokenization." 96 | "The training dataset will be truncated in block of this size for training." 97 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 98 | }, 99 | ) 100 | overwrite_cache: bool = field( 101 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 102 | ) 103 | 104 | def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False): 105 | file_path = args.eval_data_file if evaluate else args.train_data_file 106 | if args.line_by_line: 107 | print("get dataset block size ", args.block_size) 108 | return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size) 109 | else: 110 | return TextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size, overwrite_cache=args.overwrite_cache 111 | ) 112 | 113 | 114 | def main(): 115 | # See all possible arguments in src/transformers/training_args.py 116 | # or by passing the --help flag to this script. 117 | # We now keep distinct sets of args, for a cleaner separation of concerns. 
118 | 119 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 120 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 121 | training_args.per_device_train_batch_size = 2 122 | training_args.per_device_eval_batch_size = 2 123 | 124 | 125 | if data_args.eval_data_file is None and training_args.do_eval: 126 | raise ValueError( 127 | "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " 128 | "or remove the --do_eval argument." 129 | ) 130 | 131 | if ( 132 | os.path.exists(training_args.output_dir) 133 | and os.listdir(training_args.output_dir) 134 | and training_args.do_train 135 | and not training_args.overwrite_output_dir 136 | ): 137 | raise ValueError( 138 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 139 | ) 140 | 141 | # Setup logging 142 | logging.basicConfig( 143 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 144 | datefmt="%m/%d/%Y %H:%M:%S", 145 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 146 | ) 147 | logger.warning( 148 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 149 | training_args.local_rank, 150 | training_args.device, 151 | training_args.n_gpu, 152 | bool(training_args.local_rank != -1), 153 | training_args.fp16, 154 | ) 155 | logger.info("Training/evaluation parameters %s", training_args) 156 | 157 | # Set seed 158 | set_seed(training_args.seed) 159 | 160 | # Load pretrained model and tokenizer 161 | # 162 | # Distributed training: 163 | # The .from_pretrained methods guarantee that only one local process can concurrently 164 | # download model & vocab. 165 | 166 | if model_args.config_name: 167 | config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir) 168 | elif model_args.model_name_or_path: 169 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 170 | else: 171 | config = CONFIG_MAPPING[model_args.model_type]() 172 | logger.warning("You are instantiating a new config instance from scratch.") 173 | 174 | if model_args.tokenizer_name: 175 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir) 176 | elif model_args.model_name_or_path: 177 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) 178 | else: 179 | raise ValueError( 180 | "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it," 181 | "and load it from here, using --tokenizer_name" 182 | ) 183 | 184 | if model_args.model_name_or_path: 185 | model = AutoModelWithLMHead.from_pretrained( 186 | model_args.model_name_or_path, 187 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 188 | config=config, 189 | cache_dir=model_args.cache_dir, 190 | ) 191 | else: 192 | logger.info("Training new model from scratch") 193 | model = AutoModelWithLMHead.from_config(config) 194 | 195 | model.resize_token_embeddings(len(tokenizer)) 196 | 197 | if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm: 198 | raise ValueError( 199 | "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the" 200 | "--mlm flag (masked language modeling)." 
201 | ) 202 | 203 | if data_args.block_size <= 0: 204 | print("tokenizer max len", tokenizer.max_len) 205 | data_args.block_size = tokenizer.max_len 206 | # Our input block size will be the max possible for the model 207 | else: 208 | data_args.block_size = min(data_args.block_size, tokenizer.max_len) 209 | 210 | # Get datasets 211 | 212 | train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None 213 | eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None 214 | if config.model_type == "xlnet": 215 | data_collator = DataCollatorForPermutationLanguageModeling( 216 | tokenizer=tokenizer, plm_probability=data_args.plm_probability, max_span_length=data_args.max_span_length, 217 | ) 218 | else: 219 | data_collator = DataCollatorForLanguageModeling( 220 | tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability 221 | ) 222 | 223 | # Initialize our Trainer 224 | trainer = Trainer( 225 | model=model, 226 | args=training_args, 227 | data_collator=data_collator, 228 | train_dataset=train_dataset, 229 | eval_dataset=eval_dataset, 230 | prediction_loss_only=True, 231 | ) 232 | 233 | # Training 234 | if training_args.do_train: 235 | model_path = ( 236 | model_args.model_name_or_path 237 | if model_args.model_name_or_path is not None and os.path.isdir(model_args.model_name_or_path) 238 | else None 239 | ) 240 | trainer.train(model_path=model_path) 241 | trainer.save_model() 242 | # For convenience, we also re-save the tokenizer to the same directory, 243 | # so that you can share your model easily on huggingface.co/models =) 244 | if trainer.is_world_master(): 245 | tokenizer.save_pretrained(training_args.output_dir) 246 | 247 | # Evaluation 248 | results = {} 249 | if training_args.do_eval: 250 | logger.info("*** Evaluate ***") 251 | 252 | eval_output = trainer.evaluate() 253 | 254 | perplexity = math.exp(eval_output["eval_loss"]) 255 | result = {"perplexity": perplexity} 256 | 257 | output_eval_file = os.path.join(training_args.output_dir, "eval_results_lm.txt") 258 | if trainer.is_world_master(): 259 | with open(output_eval_file, "w") as writer: 260 | logger.info("***** Eval results *****") 261 | for key in sorted(result.keys()): 262 | logger.info(" %s = %s", key, str(result[key])) 263 | writer.write("%s = %s\n" % (key, str(result[key]))) 264 | 265 | results.update(result) 266 | 267 | return results 268 | 269 | 270 | def _mp_fn(index): 271 | # For xla_spawn (TPUs) 272 | main() 273 | 274 | 275 | if __name__ == "__main__": 276 | main() 277 | -------------------------------------------------------------------------------- /lightning_base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from pathlib import Path 5 | from typing import Any, Dict 6 | 7 | import pytorch_lightning as pl 8 | from pytorch_lightning.utilities import rank_zero_info 9 | 10 | from transformers import ( 11 | AdamW, 12 | AutoConfig, 13 | AutoModel, 14 | AutoModelForPreTraining, 15 | AutoModelForQuestionAnswering, 16 | AutoModelForSeq2SeqLM, 17 | AutoModelForSequenceClassification, 18 | AutoModelForTokenClassification, 19 | AutoModelWithLMHead, 20 | AutoTokenizer, 21 | PretrainedConfig, 22 | PreTrainedTokenizer, 23 | get_linear_schedule_with_warmup, 24 | ) 25 | 26 | from transformers.optimization import ( 27 | get_cosine_schedule_with_warmup, 28 | get_cosine_with_hard_restarts_schedule_with_warmup, 29 | 
get_linear_schedule_with_warmup, 30 | 31 | ) 32 | 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | MODEL_MODES = { 38 | "base": AutoModel, 39 | "sequence-classification": AutoModelForSequenceClassification, 40 | "question-answering": AutoModelForQuestionAnswering, 41 | "pretraining": AutoModelForPreTraining, 42 | "token-classification": AutoModelForTokenClassification, 43 | "language-modeling": AutoModelWithLMHead, 44 | "summarization": AutoModelForSeq2SeqLM, 45 | "translation": AutoModelForSeq2SeqLM, 46 | } 47 | 48 | 49 | # update this and the import above to support new schedulers from transformers.optimization 50 | arg_to_scheduler = { 51 | "linear": get_linear_schedule_with_warmup, 52 | "cosine": get_cosine_schedule_with_warmup, 53 | "cosine_w_restarts": get_cosine_with_hard_restarts_schedule_with_warmup, 54 | 55 | # '': get_constant_schedule, # not supported for now 56 | # '': get_constant_schedule_with_warmup, # not supported for now 57 | } 58 | arg_to_scheduler_choices = sorted(arg_to_scheduler.keys()) 59 | arg_to_scheduler_metavar = "{" + ", ".join(arg_to_scheduler_choices) + "}" 60 | 61 | 62 | class BaseTransformer(pl.LightningModule): 63 | def __init__( 64 | self, 65 | hparams: argparse.Namespace, 66 | num_labels=None, 67 | mode="base", 68 | config=None, 69 | tokenizer=None, 70 | model=None, 71 | **config_kwargs 72 | ): 73 | """Initialize a model, tokenizer and config.""" 74 | super().__init__() 75 | # TODO: move to self.save_hyperparameters() 76 | # self.save_hyperparameters() 77 | # can also expand arguments into trainer signature for easier reading 78 | 79 | self.save_hyperparameters(hparams) 80 | self.step_count = 0 81 | self.output_dir = Path(self.hparams.output_dir) 82 | cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None 83 | if config is None: 84 | self.config = AutoConfig.from_pretrained( 85 | self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, 86 | **({"num_labels": num_labels} if num_labels is not None else {}), 87 | cache_dir=cache_dir, 88 | **config_kwargs, 89 | ) 90 | else: 91 | self.config: PretrainedConfig = config 92 | 93 | extra_model_params = ("encoder_layerdrop", "decoder_layerdrop", "dropout", "attention_dropout") 94 | for p in extra_model_params: 95 | if getattr(self.hparams, p, None): 96 | assert hasattr(self.config, p), f"model config doesn't have a `{p}` attribute" 97 | setattr(self.config, p, getattr(self.hparams, p)) 98 | 99 | if tokenizer is None: 100 | self.tokenizer = AutoTokenizer.from_pretrained( 101 | self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, 102 | cache_dir=cache_dir, 103 | ) 104 | else: 105 | self.tokenizer: PreTrainedTokenizer = tokenizer 106 | self.model_type = MODEL_MODES[mode] 107 | if model is None: 108 | self.model = self.model_type.from_pretrained( 109 | self.hparams.model_name_or_path, 110 | from_tf=bool(".ckpt" in self.hparams.model_name_or_path), 111 | config=self.config, 112 | cache_dir=cache_dir, 113 | ) 114 | else: 115 | self.model = model 116 | 117 | def load_hf_checkpoint(self, *args, **kwargs): 118 | self.model = self.model_type.from_pretrained(*args, **kwargs) 119 | 120 | def get_lr_scheduler(self): 121 | get_schedule_func = arg_to_scheduler[self.hparams.lr_scheduler] 122 | scheduler = get_schedule_func( 123 | self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.total_steps 124 | ) 125 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 126 | return 
scheduler 127 | 128 | def configure_optimizers(self): 129 | """Prepare optimizer and schedule (linear warmup and decay)""" 130 | model = self.model 131 | no_decay = ["bias", "LayerNorm.weight"] 132 | optimizer_grouped_parameters = [ 133 | { 134 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 135 | "weight_decay": self.hparams.weight_decay, 136 | }, 137 | { 138 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 139 | "weight_decay": 0.0, 140 | }, 141 | ] 142 | optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon) 143 | self.opt = optimizer 144 | 145 | scheduler = self.get_lr_scheduler() 146 | 147 | return [optimizer], [scheduler] 148 | 149 | def test_step(self, batch, batch_nb): 150 | return self.validation_step(batch, batch_nb) 151 | 152 | def test_epoch_end(self, outputs): 153 | return self.validation_end(outputs) 154 | 155 | def setup(self, step): 156 | train_batch_size = self.hparams.train_batch_size 157 | dataloader = self.get_dataloader("train", train_batch_size) 158 | self.train_loader = dataloader 159 | self.total_steps = ( 160 | (len(dataloader.dataset) // (train_batch_size * max(1, self.hparams.gpus))) 161 | // self.hparams.accumulate_grad_batches 162 | * float(self.hparams.max_epochs) 163 | ) 164 | 165 | def train_dataloader(self): 166 | return self.train_loader 167 | 168 | def val_dataloader(self): 169 | return self.get_dataloader("dev", self.hparams.eval_batch_size) 170 | 171 | def test_dataloader(self): 172 | return self.get_dataloader("test", self.hparams.eval_batch_size) 173 | 174 | def _feature_file(self, mode): 175 | return os.path.join( 176 | self.hparams.data_dir, 177 | "cached_{}_{}_{}".format( 178 | mode, 179 | list(filter(None, self.hparams.model_name_or_path.split("/"))).pop(), 180 | str(self.hparams.max_seq_length), 181 | ), 182 | ) 183 | 184 | @pl.utilities.rank_zero_only 185 | def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 186 | save_path = self.output_dir.joinpath("best_tfmr") 187 | self.model.config.save_step = self.step_count 188 | self.model.save_pretrained(save_path) 189 | self.tokenizer.save_pretrained(save_path) 190 | 191 | @staticmethod 192 | def add_model_specific_args(parser, root_dir): 193 | parser.add_argument( 194 | "--model_name_or_path", 195 | default=None, 196 | type=str, 197 | required=True, 198 | help="Path to pretrained model or model identifier from huggingface.co/models", 199 | ) 200 | parser.add_argument( 201 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 202 | ) 203 | parser.add_argument( 204 | "--tokenizer_name", 205 | default=None, 206 | type=str, 207 | help="Pretrained tokenizer name or path if not the same as model_name", 208 | ) 209 | parser.add_argument( 210 | "--cache_dir", 211 | default="", 212 | type=str, 213 | help="Where do you want to store the pre-trained models downloaded from s3", 214 | ) 215 | parser.add_argument( 216 | "--encoder_layerdrop", 217 | type=float, 218 | help="Encoder layer dropout probability (Optional). Goes into model.config", 219 | ) 220 | parser.add_argument( 221 | "--decoder_layerdrop", 222 | type=float, 223 | help="Decoder layer dropout probability (Optional). Goes into model.config", 224 | ) 225 | parser.add_argument( 226 | "--dropout", type=float, help="Dropout probability (Optional). 
Goes into model.config", 227 | ) 228 | parser.add_argument( 229 | "--attention_dropout", type=float, help="Attention dropout probability (Optional). Goes into model.config", 230 | ) 231 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 232 | parser.add_argument( 233 | "--lr_scheduler", 234 | default="linear", 235 | choices=arg_to_scheduler_choices, 236 | metavar=arg_to_scheduler_metavar, 237 | type=str, 238 | help="Learning rate scheduler", 239 | ) 240 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 241 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 242 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 243 | parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader") 244 | parser.add_argument("--num_train_epochs", dest="max_epochs", default=3, type=int) 245 | parser.add_argument("--train_batch_size", default=32, type=int) 246 | parser.add_argument("--eval_batch_size", default=32, type=int) 247 | 248 | 249 | class LoggingCallback(pl.Callback): 250 | def on_batch_end(self, trainer, pl_module): 251 | lr_scheduler = trainer.lr_schedulers[0]["scheduler"] 252 | lrs = {f"lr_group_{i}": lr for i, lr in enumerate(lr_scheduler.get_lr())} 253 | pl_module.logger.log_metrics(lrs) 254 | 255 | def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 256 | rank_zero_info("***** Validation results *****") 257 | metrics = trainer.callback_metrics 258 | # Log results 259 | for key in sorted(metrics): 260 | if key not in ["log", "progress_bar"]: 261 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 262 | 263 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 264 | rank_zero_info("***** Test results *****") 265 | metrics = trainer.callback_metrics 266 | # Log and save results to file 267 | output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt") 268 | with open(output_test_results_file, "w") as writer: 269 | for key in sorted(metrics): 270 | if key not in ["log", "progress_bar"]: 271 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 272 | writer.write("{} = {}\n".format(key, str(metrics[key]))) 273 | 274 | 275 | def add_generic_args(parser, root_dir) -> None: 276 | # TODO(SS): allow all pl args? parser = pl.Trainer.add_argparse_args(parser) 277 | parser.add_argument( 278 | "--output_dir", 279 | default=None, 280 | type=str, 281 | required=True, 282 | help="The output directory where the model predictions and checkpoints will be written.", 283 | ) 284 | parser.add_argument( 285 | "--fp16", 286 | action="store_true", 287 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit", 288 | ) 289 | 290 | parser.add_argument( 291 | "--fp16_opt_level", 292 | type=str, 293 | default="O2", 294 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 
295 | "See details at https://nvidia.github.io/apex/amp.html", 296 | ) 297 | parser.add_argument("--n_tpu_cores", dest="tpu_cores", type=int) 298 | parser.add_argument("--max_grad_norm", dest="gradient_clip_val", default=1.0, type=float, help="Max gradient norm") 299 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 300 | parser.add_argument("--do_predict", action="store_true", help="Whether to run predictions on the test set.") 301 | parser.add_argument( 302 | "--gradient_accumulation_steps", 303 | dest="accumulate_grad_batches", 304 | type=int, 305 | default=1, 306 | help="Number of updates steps to accumulate before performing a backward/update pass.", 307 | ) 308 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 309 | 310 | 311 | def generic_train( 312 | model: BaseTransformer, 313 | args: argparse.Namespace, 314 | early_stopping_callback=False, 315 | logger=True, # can pass WandbLogger() here 316 | extra_callbacks=[], 317 | checkpoint_callback=None, 318 | logging_callback=None, 319 | **extra_train_kwargs 320 | ): 321 | pl.seed_everything(args.seed) 322 | 323 | # init model 324 | odir = Path(model.hparams.output_dir) 325 | odir.mkdir(exist_ok=True) 326 | 327 | # add custom checkpoints 328 | if checkpoint_callback is None: 329 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 330 | filepath=args.output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=1 331 | ) 332 | if logging_callback is None: 333 | logging_callback = LoggingCallback() 334 | 335 | train_params = {} 336 | 337 | # TODO: remove with PyTorch 1.6 since pl uses native amp 338 | if args.fp16: 339 | train_params["precision"] = 16 340 | train_params["amp_level"] = args.fp16_opt_level 341 | 342 | if args.gpus > 1: 343 | train_params["distributed_backend"] = "ddp" 344 | 345 | trainer = pl.Trainer.from_argparse_args( 346 | args, 347 | weights_summary=None, 348 | callbacks=[logging_callback] + extra_callbacks, 349 | logger=logger, 350 | checkpoint_callback=checkpoint_callback, 351 | early_stop_callback=early_stopping_callback, 352 | **train_params, 353 | ) 354 | 355 | if args.do_train: 356 | trainer.fit(model) 357 | 358 | return trainer -------------------------------------------------------------------------------- /manual.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Heidelberg-NLP/LMs4Implicit-Knowledge-Generation/7834e31026d1080395bef0165f31d90d82246185/manual.pdf -------------------------------------------------------------------------------- /util_summarization.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import linecache 4 | import os 5 | import pickle 6 | import warnings 7 | from logging import getLogger 8 | from pathlib import Path 9 | from typing import Callable, Dict, Iterable, List 10 | 11 | import git 12 | import numpy as np 13 | import torch 14 | from rouge_score import rouge_scorer, scoring 15 | from sacrebleu import corpus_bleu 16 | from torch import nn 17 | from torch.utils.data import Dataset, Sampler 18 | 19 | from transformers import BartTokenizer 20 | 21 | 22 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): 23 | """From fairseq""" 24 | if target.dim() == lprobs.dim() - 1: 25 | target = target.unsqueeze(-1) 26 | nll_loss = -lprobs.gather(dim=-1, index=target) 27 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 28 | if 
ignore_index is not None: 29 | pad_mask = target.eq(ignore_index) 30 | nll_loss.masked_fill_(pad_mask, 0.0) 31 | smooth_loss.masked_fill_(pad_mask, 0.0) 32 | else: 33 | nll_loss = nll_loss.squeeze(-1) 34 | smooth_loss = smooth_loss.squeeze(-1) 35 | 36 | nll_loss = nll_loss.sum() # mean()? Scared to break other math. 37 | smooth_loss = smooth_loss.sum() 38 | eps_i = epsilon / lprobs.size(-1) 39 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 40 | return loss, nll_loss 41 | 42 | 43 | def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): 44 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 45 | return tokenizer( 46 | [line], 47 | max_length=max_length, 48 | padding="max_length" if pad_to_max_length else None, 49 | truncation=True, 50 | return_tensors=return_tensors, 51 | **extra_kw, 52 | ) 53 | 54 | 55 | def lmap(f: Callable, x: Iterable) -> List: 56 | """list(map(f, x))""" 57 | return list(map(f, x)) 58 | 59 | 60 | def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict: 61 | """Uses sacrebleu's corpus_bleu implementation.""" 62 | return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score} 63 | 64 | 65 | def trim_batch( 66 | input_ids, pad_token_id, attention_mask=None, 67 | ): 68 | """Remove columns that are populated exclusively by pad_token_id""" 69 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 70 | if attention_mask is None: 71 | return input_ids[:, keep_column_mask] 72 | else: 73 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 74 | 75 | 76 | class Seq2SeqDataset(Dataset): 77 | def __init__( 78 | self, 79 | tokenizer, 80 | data_dir, 81 | max_source_length, 82 | max_target_length, 83 | type_path="train", 84 | n_obs=None, 85 | src_lang=None, 86 | tgt_lang=None, 87 | prefix="", 88 | ): 89 | super().__init__() 90 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 91 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 92 | self.src_lens = self.get_char_lens(self.src_file) 93 | self.max_source_length = max_source_length 94 | self.max_target_length = max_target_length 95 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 96 | self.tokenizer = tokenizer 97 | self.prefix = prefix 98 | if n_obs is not None: 99 | self.src_lens = self.src_lens[:n_obs] 100 | self.pad_token_id = self.tokenizer.pad_token_id 101 | self.src_lang = src_lang 102 | self.tgt_lang = tgt_lang 103 | 104 | def __len__(self): 105 | return len(self.src_lens) 106 | 107 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 108 | index = index + 1 # linecache starts at 1 109 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 110 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 111 | assert source_line, f"empty source line for index {index}" 112 | assert tgt_line, f"empty tgt line for index {index}" 113 | source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length) 114 | target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length) 115 | 116 | source_ids = source_inputs["input_ids"].squeeze() 117 | target_ids = target_inputs["input_ids"].squeeze() 118 | src_mask = source_inputs["attention_mask"].squeeze() 119 | return { 120 | "input_ids": source_ids, 121 | "attention_mask": src_mask, 122 | "decoder_input_ids": target_ids, 123 | } 124 | 125 | @staticmethod 126 | def get_char_lens(data_file): 127 | return [len(x) for x in 
Path(data_file).open().readlines()] 128 | 129 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 130 | input_ids = torch.stack([x["input_ids"] for x in batch]) 131 | masks = torch.stack([x["attention_mask"] for x in batch]) 132 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) 133 | pad_token_id = self.pad_token_id 134 | y = trim_batch(target_ids, pad_token_id) 135 | source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) 136 | batch = { 137 | "input_ids": source_ids, 138 | "attention_mask": source_mask, 139 | "decoder_input_ids": y, 140 | } 141 | return batch 142 | 143 | def make_sortish_sampler(self, batch_size): 144 | return SortishSampler(self.src_lens, batch_size) 145 | 146 | 147 | class TranslationDataset(Seq2SeqDataset): 148 | """A dataset that calls prepare_seq2seq_batch.""" 149 | 150 | def __init__(self, *args, **kwargs): 151 | super().__init__(*args, **kwargs) 152 | if self.max_source_length != self.max_target_length: 153 | warnings.warn( 154 | f"Mbart is using sequence lengths {self.max_source_length}, {self.max_target_length}. " 155 | f"Imbalanced sequence lengths may be undesired for translation tasks" 156 | ) 157 | 158 | def __getitem__(self, index) -> Dict[str, str]: 159 | index = index + 1 # linecache starts at 1 160 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 161 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 162 | assert source_line, f"empty source line for index {index}" 163 | assert tgt_line, f"empty tgt line for index {index}" 164 | return { 165 | "tgt_texts": tgt_line, 166 | "src_texts": source_line, 167 | } 168 | 169 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 170 | batch_encoding = self.tokenizer.prepare_seq2seq_batch( 171 | [x["src_texts"] for x in batch], 172 | src_lang=self.src_lang, 173 | tgt_texts=[x["tgt_texts"] for x in batch], 174 | tgt_lang=self.tgt_lang, 175 | max_length=self.max_source_length, 176 | max_target_length=self.max_target_length, 177 | ) 178 | return batch_encoding.data 179 | 180 | 181 | class SortishSampler(Sampler): 182 | "Go through the text data by order of src length with a bit of randomness. From fastai repo." 183 | 184 | def __init__(self, data, batch_size): 185 | self.data, self.bs = data, batch_size 186 | 187 | def key(self, i): 188 | return self.data[i] 189 | 190 | def __len__(self) -> int: 191 | return len(self.data) 192 | 193 | def __iter__(self): 194 | idxs = np.random.permutation(len(self.data)) 195 | sz = self.bs * 50 196 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] 197 | sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx]) 198 | sz = self.bs 199 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] 200 | max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, 201 | ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. 
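        # The remaining batch-sized chunks below are visited in a random order, so each batch groups examples of similar source length while the batch order still changes from epoch to epoch.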
202 | sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) 203 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) 204 | return iter(sort_idx) 205 | 206 | 207 | logger = getLogger(__name__) 208 | 209 | 210 | def use_task_specific_params(model, task): 211 | """Update config with summarization specific params.""" 212 | task_specific_params = model.config.task_specific_params 213 | 214 | if task_specific_params is not None: 215 | pars = task_specific_params.get(task, {}) 216 | logger.info(f"using task specific params for {task}: {pars}") 217 | model.config.update(pars) 218 | 219 | 220 | def pickle_load(path): 221 | """pickle.load(path)""" 222 | with open(path, "rb") as f: 223 | return pickle.load(f) 224 | 225 | 226 | def pickle_save(obj, path): 227 | """pickle.dump(obj, path)""" 228 | with open(path, "wb") as f: 229 | return pickle.dump(obj, f) 230 | 231 | 232 | def flatten_list(summary_ids: List[List]): 233 | return [x for x in itertools.chain.from_iterable(summary_ids)] 234 | 235 | 236 | def save_git_info(folder_path: str) -> None: 237 | """Save git information to output_dir/git_log.json""" 238 | repo_infos = get_git_info() 239 | save_json(repo_infos, os.path.join(folder_path, "git_log.json")) 240 | 241 | 242 | def save_json(content, path): 243 | with open(path, "w") as f: 244 | json.dump(content, f, indent=4) 245 | 246 | 247 | def load_json(path): 248 | with open(path) as f: 249 | return json.load(f) 250 | 251 | 252 | def get_git_info(): 253 | repo = git.Repo(search_parent_directories=True) 254 | repo_infos = { 255 | "repo_id": str(repo), 256 | "repo_sha": str(repo.head.object.hexsha), 257 | "repo_branch": str(repo.active_branch), 258 | } 259 | return repo_infos 260 | 261 | 262 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] 263 | 264 | 265 | def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: 266 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) 267 | aggregator = scoring.BootstrapAggregator() 268 | 269 | for reference_ln, output_ln in zip(reference_lns, output_lns): 270 | scores = scorer.score(reference_ln, output_ln) 271 | aggregator.add_scores(scores) 272 | 273 | result = aggregator.aggregate() 274 | return {k: v.mid.fmeasure * 100 for k, v in result.items()} 275 | 276 | 277 | def freeze_params(model: nn.Module): 278 | for par in model.parameters(): 279 | par.requires_grad = False 280 | 281 | 282 | def grad_status(model: nn.Module) -> Iterable: 283 | return (par.requires_grad for par in model.parameters()) 284 | 285 | 286 | def any_requires_grad(model: nn.Module) -> bool: 287 | return any(grad_status(model)) 288 | 289 | 290 | def assert_all_frozen(model): 291 | model_grads: List[bool] = list(grad_status(model)) 292 | n_require_grad = sum(lmap(int, model_grads)) 293 | npars = len(model_grads) 294 | assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad" 295 | 296 | 297 | def assert_not_all_frozen(model): 298 | model_grads: List[bool] = list(grad_status(model)) 299 | npars = len(model_grads) 300 | assert any(model_grads), f"none of {npars} weights require grad" -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import linecache 4 | import os 5 | import pickle 6 | import warnings 7 | from logging import getLogger 8 | from pathlib import Path 9 | from typing import Callable, 
Dict, Iterable, List 10 | 11 | import git 12 | import numpy as np 13 | import torch 14 | from rouge_score import rouge_scorer, scoring 15 | from sacrebleu import corpus_bleu 16 | from torch import nn 17 | from torch.utils.data import Dataset, Sampler 18 | 19 | from transformers import BartTokenizer 20 | 21 | 22 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): 23 | """From fairseq""" 24 | if target.dim() == lprobs.dim() - 1: 25 | target = target.unsqueeze(-1) 26 | nll_loss = -lprobs.gather(dim=-1, index=target) 27 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 28 | if ignore_index is not None: 29 | pad_mask = target.eq(ignore_index) 30 | nll_loss.masked_fill_(pad_mask, 0.0) 31 | smooth_loss.masked_fill_(pad_mask, 0.0) 32 | bs = pad_mask.long().sum() 33 | else: 34 | nll_loss = nll_loss.squeeze(-1) 35 | smooth_loss = smooth_loss.squeeze(-1) 36 | bs = lprobs.shape[0] 37 | 38 | nll_loss = nll_loss.sum() # mean()? Scared to break other math. 39 | smooth_loss = smooth_loss.sum() 40 | eps_i = epsilon / lprobs.size(-1) 41 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 42 | return loss / bs, nll_loss / bs 43 | 44 | 45 | def encode_line(tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"): 46 | extra_kw = {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 47 | return tokenizer( 48 | [line], 49 | max_length=max_length, 50 | padding="max_length" if pad_to_max_length else None, 51 | truncation=True, 52 | return_tensors=return_tensors, 53 | **extra_kw, 54 | ) 55 | 56 | 57 | def lmap(f: Callable, x: Iterable) -> List: 58 | """list(map(f, x))""" 59 | return list(map(f, x)) 60 | 61 | 62 | def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict: 63 | """Uses sacrebleu's corpus_bleu implementation.""" 64 | return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score} 65 | 66 | 67 | def trim_batch( 68 | input_ids, pad_token_id, attention_mask=None, 69 | ): 70 | """Remove columns that are populated exclusively by pad_token_id""" 71 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 72 | if attention_mask is None: 73 | return input_ids[:, keep_column_mask] 74 | else: 75 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 76 | 77 | 78 | class Seq2SeqDataset(Dataset): 79 | def __init__( 80 | self, 81 | tokenizer, 82 | data_dir, 83 | max_source_length, 84 | max_target_length, 85 | type_path="train", 86 | n_obs=None, 87 | src_lang=None, 88 | tgt_lang=None, 89 | prefix="", 90 | ): 91 | super().__init__() 92 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 93 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 94 | self.src_lens = self.get_char_lens(self.src_file) 95 | self.max_source_length = max_source_length 96 | self.max_target_length = max_target_length 97 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 98 | self.tokenizer = tokenizer 99 | self.prefix = prefix 100 | if n_obs is not None: 101 | self.src_lens = self.src_lens[:n_obs] 102 | self.pad_token_id = self.tokenizer.pad_token_id 103 | self.src_lang = src_lang 104 | self.tgt_lang = tgt_lang 105 | 106 | def __len__(self): 107 | return len(self.src_lens) 108 | 109 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 110 | index = index + 1 # linecache starts at 1 111 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 112 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 113 | assert source_line, f"empty source 
line for index {index}" 114 | assert tgt_line, f"empty tgt line for index {index}" 115 | source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length) 116 | target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length) 117 | 118 | source_ids = source_inputs["input_ids"].squeeze() 119 | target_ids = target_inputs["input_ids"].squeeze() 120 | src_mask = source_inputs["attention_mask"].squeeze() 121 | return { 122 | "input_ids": source_ids, 123 | "attention_mask": src_mask, 124 | "decoder_input_ids": target_ids, 125 | } 126 | 127 | @staticmethod 128 | def get_char_lens(data_file): 129 | return [len(x) for x in Path(data_file).open().readlines()] 130 | 131 | @staticmethod 132 | def trim_seq2seq_batch(batch, pad_token_id) -> tuple: 133 | y = trim_batch(batch["decoder_input_ids"], pad_token_id) 134 | source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"]) 135 | return source_ids, source_mask, y 136 | 137 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 138 | input_ids = torch.stack([x["input_ids"] for x in batch]) 139 | masks = torch.stack([x["attention_mask"] for x in batch]) 140 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) 141 | pad_token_id = self.pad_token_id 142 | y = trim_batch(target_ids, pad_token_id) 143 | source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) 144 | batch = { 145 | "input_ids": source_ids, 146 | "attention_mask": source_mask, 147 | "decoder_input_ids": y, 148 | } 149 | return batch 150 | 151 | def make_sortish_sampler(self, batch_size): 152 | return SortishSampler(self.src_lens, batch_size) 153 | 154 | 155 | class MBartDataset(Seq2SeqDataset): 156 | def __init__(self, *args, **kwargs): 157 | super().__init__(*args, **kwargs) 158 | if self.max_source_length != self.max_target_length: 159 | warnings.warn( 160 | f"Mbart will ignore max_target_length = {self.max_target_length} and use {self.max_source_length} for both sides." 161 | ) 162 | 163 | def __getitem__(self, index) -> Dict[str, str]: 164 | index = index + 1 # linecache starts at 1 165 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip("\n") 166 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 167 | assert source_line, f"empty source line for index {index}" 168 | assert tgt_line, f"empty tgt line for index {index}" 169 | return { 170 | "tgt_texts": tgt_line, 171 | "src_texts": source_line, 172 | } 173 | 174 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 175 | batch_encoding = self.tokenizer.prepare_translation_batch( 176 | [x["src_texts"] for x in batch], 177 | src_lang=self.src_lang, 178 | tgt_texts=[x["tgt_texts"] for x in batch], 179 | tgt_lang=self.tgt_lang, 180 | max_length=self.max_source_length, 181 | ) 182 | return batch_encoding.data 183 | 184 | 185 | class SortishSampler(Sampler): 186 | "Go through the text data by order of src length with a bit of randomness. From fastai repo." 
187 | 188 | def __init__(self, data, batch_size): 189 | self.data, self.bs = data, batch_size 190 | 191 | def key(self, i): 192 | return self.data[i] 193 | 194 | def __len__(self) -> int: 195 | return len(self.data) 196 | 197 | def __iter__(self): 198 | idxs = np.random.permutation(len(self.data)) 199 | sz = self.bs * 50 200 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] 201 | sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx]) 202 | sz = self.bs 203 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] 204 | max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key, 205 | ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first. 206 | sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=np.int) 207 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) 208 | return iter(sort_idx) 209 | 210 | 211 | logger = getLogger(__name__) 212 | 213 | 214 | def use_task_specific_params(model, task): 215 | """Update config with summarization specific params.""" 216 | task_specific_params = model.config.task_specific_params 217 | 218 | if task_specific_params is not None: 219 | pars = task_specific_params.get(task, {}) 220 | logger.info(f"using task specific params for {task}: {pars}") 221 | model.config.update(pars) 222 | 223 | 224 | def pickle_load(path): 225 | """pickle.load(path)""" 226 | with open(path, "rb") as f: 227 | return pickle.load(f) 228 | 229 | 230 | def pickle_save(obj, path): 231 | """pickle.dump(obj, path)""" 232 | with open(path, "wb") as f: 233 | return pickle.dump(obj, f) 234 | 235 | 236 | def flatten_list(summary_ids: List[List]): 237 | return [x for x in itertools.chain.from_iterable(summary_ids)] 238 | 239 | 240 | def save_git_info(folder_path: str) -> None: 241 | """Save git information to output_dir/git_log.json""" 242 | repo_infos = get_git_info() 243 | save_json(repo_infos, os.path.join(folder_path, "git_log.json")) 244 | 245 | 246 | def save_json(content, path): 247 | with open(path, "w") as f: 248 | json.dump(content, f, indent=4) 249 | 250 | 251 | def load_json(path): 252 | with open(path) as f: 253 | return json.load(f) 254 | 255 | 256 | def get_git_info(): 257 | repo = git.Repo(search_parent_directories=True) 258 | repo_infos = { 259 | "repo_id": str(repo), 260 | "repo_sha": str(repo.head.object.hexsha), 261 | "repo_branch": str(repo.active_branch), 262 | } 263 | return repo_infos 264 | 265 | 266 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] 267 | 268 | 269 | def calculate_rouge(output_lns: List[str], reference_lns: List[str], use_stemmer=True) -> Dict: 270 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) 271 | aggregator = scoring.BootstrapAggregator() 272 | 273 | for reference_ln, output_ln in zip(reference_lns, output_lns): 274 | scores = scorer.score(reference_ln, output_ln) 275 | aggregator.add_scores(scores) 276 | 277 | result = aggregator.aggregate() 278 | return {k: v.mid.fmeasure for k, v in result.items()} 279 | 280 | 281 | def freeze_params(model: nn.Module): 282 | for par in model.parameters(): 283 | par.requires_grad = False 284 | 285 | 286 | def grad_status(model: nn.Module) -> Iterable: 287 | return (par.requires_grad for par in model.parameters()) 288 | 289 | 290 | def any_requires_grad(model: nn.Module) -> bool: 291 | return any(grad_status(model)) 292 | 293 | 294 | def assert_all_frozen(model): 295 | model_grads: List[bool] = list(grad_status(model)) 296 | 
n_require_grad = sum(lmap(int, model_grads)) 297 | npars = len(model_grads) 298 | assert not any(model_grads), f"{n_require_grad/npars:.1%} of {npars} weights require grad" 299 | 300 | 301 | def assert_not_all_frozen(model): 302 | model_grads: List[bool] = list(grad_status(model)) 303 | npars = len(model_grads) 304 | assert any(model_grads), f"none of {npars} weights require grad" 305 | --------------------------------------------------------------------------------
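
As a rough sketch of how the helpers in utils.py fit together, the following example loads a fine-tuned seq2seq checkpoint, reads the test split with Seq2SeqDataset, generates target sentences, and scores them with calculate_rouge. The checkpoint path, data directory, batch size, and length limits are placeholder values for illustration only, not files or settings shipped with this project.

    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    from utils import Seq2SeqDataset, calculate_rouge

    # Placeholder locations: point these at your own fine-tuned checkpoint
    # and at a directory containing test.source / test.target files
    # (one prepared source line and one target line per example).
    checkpoint_dir = "path/to/finetuned_model"
    data_dir = "path/to/prepared_data"

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir).eval()

    # Each line is tokenized and padded to max_source_length / max_target_length;
    # collate_fn then trims columns that consist entirely of padding.
    dataset = Seq2SeqDataset(tokenizer, data_dir, max_source_length=80, max_target_length=20, type_path="test")
    loader = DataLoader(dataset, batch_size=8, collate_fn=dataset.collate_fn)

    predictions, references = [], []
    with torch.no_grad():
        for batch in loader:
            generated = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                num_beams=4,
                max_length=20,
            )
            predictions += tokenizer.batch_decode(generated, skip_special_tokens=True)
            references += tokenizer.batch_decode(batch["decoder_input_ids"], skip_special_tokens=True)

    print(calculate_rouge(predictions, references))

If BLEU is preferred, calculate_bleu_score from the same module can be called on the same predictions and references lists.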